found the bug, fixed
mcabbott committed Feb 2, 2022
1 parent 309e606 commit 1483269
Showing 1 changed file with 16 additions and 18 deletions.

test/rules.jl (34 changes: 16 additions & 18 deletions)
@@ -201,11 +201,13 @@ end
 end
 end

-@testset verbose=true "with complex numbers: Flux#1776" begin
+@testset "with complex numbers: Flux#1776" begin
   empty!(LOG)
-  @testset "$(name(f(1e-2)))" for f in [
-    ADAM, RMSProp, RADAM, OADAM, ADAGrad, ADADelta, NADAM, AdaBelief,
-    Descent, Momentum, Nesterov, ADAMW, # not in Flux PR
+  @testset "$(name(opt))" for opt in [
+    # The Flux PR had 1e-2 for all. But ADADelta(ρ) needs ρ≈0.9 not small. And it helps to make ε not too small too:
+    ADAM(1e-2), RMSProp(1e-2), RADAM(1e-2), OADAM(1e-2), ADAGrad(1e-2), ADADelta(0.9, 1e-5), NADAM(1e-2), AdaBelief(1e-2),
+    # These weren't in Flux PR:
+    Descent(1e-2), Momentum(1e-2), Nesterov(1e-2), ADAMW(1e-2),
     ]
     # Our "model" is just a complex number
     model = (w = zeros(ComplexF64, 1),)
@@ -218,25 +220,23 @@ end
       # on older Flux builds that don't have a fixed `mse()`
       return sum(abs2.(m.w * x .- conj(x)))
     end
+    @test loss(model) ≈ 2.0

-    opt = f(1e-2)
     state = Optimisers.setup(opt, model)

     # Train for 10 iterations, enforcing that loss is monotonically decreasing
     last_loss = Inf
     for idx in 1:10
       grads = loggradient(opt)(loss, model)
       state, model = Optimisers.update!(state, model, grads...)
-      if opt isa Union{Momentum, Nesterov} && idx > 8 # these are very flat at the end
-        @test_skip loss(model) < last_loss
-      else
-        @test loss(model) < last_loss
-      end
+      opt isa Union{Momentum, Nesterov} && idx > 8 && continue # these are very flat at the end
+      @test loss(model) < last_loss
       last_loss = loss(model)
     end
+    @test loss(model) < 1.9

     # Repeat with StaticArrays
-    static_model = (w = SA[1.0 + 0im],)
+    static_model = (w = SA[0.0 + 0im],)
     static_state = Optimisers.setup(opt, static_model)
     function static_loss(m)
       x = hcat(SA[1.0 + im])
@@ -246,13 +246,11 @@ end
     for idx in 1:10
       grads = gradient(static_loss, static_model)
       static_state, static_model = Optimisers.update!(static_state, static_model, grads...)
-      if opt isa Union{Momentum, Nesterov} && idx > 8
-        @test_skip loss(static_model) < last_loss
-      else
-        @test loss(static_model) < last_loss
-      end
-      last_loss = loss(static_model)
+      opt isa Union{Momentum, Nesterov} && idx > 8 && continue
+      @test static_loss(static_model) < last_loss
+      last_loss = static_loss(static_model)
     end
+    @test static_loss(static_model) < 1.9
   end
 end

@@ -264,7 +262,7 @@ _plot!(:OADAM)
 _plot!(:RMSProp)
 _plot!(:NADAM)
-_plot!(:ADADelta) # barely declines
+_plot!(:ADADelta)
 _plot(:Momentum)
 _plot!(:Nesterov) # very flat at the end
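
A note on the ADADelta change above: the first positional argument of ADADelta is the decay ρ, not a learning rate, so the old f(1e-2) construction handed ADADelta ρ = 1e-2 and made it forget its running averages almost completely at every step. The sketch below is the textbook ADADelta update for a single parameter, written only to show where ρ and ε enter; it is an illustration, not Optimisers.jl's own implementation, and the helper name adadelta_step is invented here.

# Hypothetical stand-alone sketch of one ADADelta step (Julia), not package code.
function adadelta_step(w, g, acc, Δacc; ρ = 0.9, ε = 1e-5)
  acc  = ρ * acc  + (1 - ρ) * abs2(g)         # running average of squared gradients
  Δw   = g * sqrt(Δacc + ε) / sqrt(acc + ε)   # step scaled by a ratio of RMS values
  Δacc = ρ * Δacc + (1 - ρ) * abs2(Δw)        # running average of squared steps
  return w - Δw, acc, Δacc
end

With ρ = 1e-2 both averages track only the latest value, and because Δacc starts at zero the step size stays pinned near sqrt(ε), far too small to move the loss within the 10 iterations the test allows; ρ ≈ 0.9 together with a not-too-small ε = 1e-5 is what lets ADADelta make visible progress here.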

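The new @test loss(model) ≈ 2.0 line also has a simple closed-form justification: the model starts at w = 0 and the data is x = 1 + im, so the loss is |0*x - conj(x)|^2 = |1 - im|^2 = 2, and the exact minimiser is w = conj(x)/x = -im, where the loss vanishes. A scalar analogue of the testset's loss, runnable on its own without Optimisers.jl, makes this easy to check (scalar_loss is a name invented for this note):

x = 1.0 + im
scalar_loss(w) = abs2(w * x - conj(x))   # scalar version of the testset's loss
scalar_loss(0.0 + 0im)     # 2.0, the starting value asserted by the new test
scalar_loss(conj(x) / x)   # 0.0, the optimum the optimisers should approach
conj(x) / x                # ≈ -im

The follow-up @test loss(model) < 1.9 then only asks each rule to make some net progress from that starting value within its 10 update steps.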