From 148326945acb9ae8afac3507b4ed051d1b91a80a Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Tue, 1 Feb 2022 23:04:27 -0500
Subject: [PATCH] found the bug, fixed

---
 test/rules.jl | 34 ++++++++++++++++------------------
 1 file changed, 16 insertions(+), 18 deletions(-)

diff --git a/test/rules.jl b/test/rules.jl
index fd6680a2..0b272e5d 100644
--- a/test/rules.jl
+++ b/test/rules.jl
@@ -201,11 +201,13 @@ end
   end
 end
 
-@testset verbose=true "with complex numebers: Flux#1776" begin
+@testset "with complex numebers: Flux#1776" begin
   empty!(LOG)
-  @testset "$(name(f(1e-2)))" for f in [
-    ADAM, RMSProp, RADAM, OADAM, ADAGrad, ADADelta, NADAM, AdaBelief,
-    Descent, Momentum, Nesterov, ADAMW, # not in Flux PR
+  @testset "$(name(opt))" for opt in [
+    # The Flux PR had 1e-2 for all. But ADADelta(ρ) needs ρ≈0.9 not small. And it helps to make ε not too small too:
+    ADAM(1e-2), RMSProp(1e-2), RADAM(1e-2), OADAM(1e-2), ADAGrad(1e-2), ADADelta(0.9, 1e-5), NADAM(1e-2), AdaBelief(1e-2),
+    # These weren't in Flux PR:
+    Descent(1e-2), Momentum(1e-2), Nesterov(1e-2), ADAMW(1e-2),
     ]
     # Our "model" is just a complex number
     model = (w = zeros(ComplexF64, 1),)
@@ -218,8 +220,8 @@ end
       # on older Flux builds that don't have a fixed `mse()`
       return sum(abs2.(m.w * x .- conj(x)))
     end
+    @test loss(model) ≈ 2.0
 
-    opt = f(1e-2)
     state = Optimisers.setup(opt, model)
 
     # Train for 10 iterations, enforcing that loss is monotonically decreasing
@@ -227,16 +229,14 @@ end
     for idx in 1:10
      grads = loggradient(opt)(loss, model)
      state, model = Optimisers.update!(state, model, grads...)
-      if opt isa Union{Momentum, Nesterov} && idx > 8 # these are very flat at the end
-        @test_skip loss(model) < last_loss
-      else
-        @test loss(model) < last_loss
-      end
+      opt isa Union{Momentum, Nesterov} && idx > 8 && continue # these are very flat at the end
+      @test loss(model) < last_loss
      last_loss = loss(model)
     end
+    @test loss(model) < 1.9
 
     # Repeat with StaticArrays
-    static_model = (w = SA[1.0 + 0im],)
+    static_model = (w = SA[0.0 + 0im],)
     static_state = Optimisers.setup(opt, static_model)
     function static_loss(m)
      x = hcat(SA[1.0 + im])
@@ -246,13 +246,11 @@ end
     for idx in 1:10
      grads = gradient(static_loss, static_model)
      static_state, static_model = Optimisers.update!(static_state, static_model, grads...)
-      if opt isa Union{Momentum, Nesterov} && idx > 8
-        @test_skip loss(static_model) < last_loss
-      else
-        @test loss(static_model) < last_loss
-      end
-      last_loss = loss(static_model)
+      opt isa Union{Momentum, Nesterov} && idx > 8 && continue
+      @test static_loss(static_model) < last_loss
+      last_loss = static_loss(static_model)
     end
+    @test static_loss(static_model) < 1.9
   end
 end
 
@@ -264,7 +262,7 @@ _plot!(:OADAM)
 _plot!(:RMSProp)
 _plot!(:NADAM)
-_plot!(:ADADelta) # barely declines
+_plot!(:ADADelta)
 
 _plot(:Momentum)
 _plot!(:Nesterov) # very flat at the end
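
Note on the hyperparameter change above, with a standalone sketch that is not part of the patch: ADADelta's first positional argument is the decay ρ, not a learning rate, so ADADelta(1e-2) takes vanishingly small steps, while ADADelta(0.9, 1e-5) trains as intended. The sketch assumes Optimisers.jl with the rule names used in this patch (ADADelta; newer releases spell it AdaDelta) and Zygote for gradients, as the test suite does.

# Hypothetical standalone script, illustrating only the ADADelta(ρ, ε) point.
using Optimisers, Zygote

# One complex weight; the target is w*x ≈ conj(x) for x = 1 + im, as in the test.
model = (w = zeros(ComplexF64, 1),)
loss(m) = sum(abs2.(m.w .* (1.0 + im) .- (1.0 - im)))   # starts at 2.0

for rule in (Optimisers.ADADelta(1e-2),       # 1e-2 here is ρ, not a step size: barely moves
             Optimisers.ADADelta(0.9, 1e-5))  # ρ ≈ 0.9 and a not-too-small ε
    m = deepcopy(model)
    state = Optimisers.setup(rule, m)
    for _ in 1:10
        grads = Zygote.gradient(loss, m)
        state, m = Optimisers.update!(state, m, grads...)
    end
    println(rule, "  =>  final loss ", loss(m))
end

With the first parameterisation the loss should stay essentially at 2.0 after 10 steps; with the second it should drop below 1.9, which is what the added `@test loss(model) < 1.9` line checks.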