Complex numbers alla Flux 1776 #47

Merged: 10 commits, Feb 5, 2022
Changes from 9 commits
20 changes: 10 additions & 10 deletions src/rules.jl
@@ -103,7 +103,7 @@ init(o::RMSProp, x::AbstractArray) = zero(x)
function apply!(o::RMSProp, state, x, dx)
η, ρ, ϵ, acc = o.eta, o.rho, o.epsilon, state

@.. acc = ρ * acc + (1 - ρ) * dx^2
@.. acc = ρ * acc + (1 - ρ) * abs2(dx)
dx′ = @lazy dx * (η / (sqrt(acc) + ϵ))

return acc, dx′
@@ -136,7 +136,7 @@ function apply!(o::ADAM, state, x, dx)
mt, vt, βt = state

@.. mt = β[1] * mt + (1 - β[1]) * dx
@.. vt = β[2] * vt + (1 - β[2]) * dx ^ 2
@.. vt = β[2] * vt + (1 - β[2]) * abs2(dx)
dx′ = @lazy mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ) * η

return (mt, vt, βt .* β), dx′
@@ -171,7 +171,7 @@ function apply!(o::RADAM, state, x, dx)
mt, vt, βt, t = state

@.. mt = β[1] * mt + (1 - β[1]) * dx
@.. vt = β[2] * vt + (1 - β[2]) * dx^2
@.. vt = β[2] * vt + (1 - β[2]) * abs2(dx)
ρ = ρ∞ - 2*t * βt[2] / (1 - βt[2])
if ρ > 4
r = sqrt((ρ - 4) * (ρ - 2) * ρ∞/((ρ∞ - 4) * (ρ∞ - 2) * ρ))
@@ -244,7 +244,7 @@ function apply!(o::OADAM, state, x, dx)
mt, vt, βt, term = state

@.. mt = β[1] * mt + (1 - β[1]) * dx
@.. vt = β[2] * vt + (1 - β[2]) * dx^2
@.. vt = β[2] * vt + (1 - β[2]) * abs2(dx)
prev = copy(term)
@.. term = η * mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ)
dx′ = @lazy 2 * term - prev
@@ -277,7 +277,7 @@ function apply!(o::ADAGrad, state, x, dx)
η, ϵ = o.eta, o.epsilon
acc = state

@.. acc = acc + dx^2
@.. acc = acc + abs2(dx)
dx′ = @lazy dx * η / (sqrt(acc) + ϵ)

return acc, dx′
@@ -307,10 +307,10 @@ function apply!(o::ADADelta, state, x, dx)
ρ, ϵ = o.rho, o.epsilon
acc, Δacc = state

@.. acc = ρ * acc + (1 - ρ) * dx^2
@.. acc = ρ * acc + (1 - ρ) * abs2(dx)
# DON'T remove epsilon from numerator or even out of the square roots!
dx′ = @. dx * sqrt(Δacc + ϵ) / sqrt(acc + ϵ) # Cannot be lazy as this needs the old Δacc
@.. Δacc = ρ * Δacc + (1 - ρ) * dx′^2
@.. Δacc = ρ * Δacc + (1 - ρ) * abs2(dx′)

return (acc, Δacc), dx′
end
@@ -344,7 +344,7 @@ function apply!(o::AMSGrad, state, x, dx)
mt, vt, v̂t = state

@.. mt = β[1] * mt + (1 - β[1]) * dx
@.. vt = β[2] * vt + (1 - β[2]) * dx ^ 2
@.. vt = β[2] * vt + (1 - β[2]) * abs2(dx)
@.. v̂t = max(v̂t, vt)
dx′ = @lazy η * mt / (sqrt(v̂t) + ϵ)

@@ -380,7 +380,7 @@ function apply!(o::NADAM, state, x, dx)
mt, vt, βt = state

@.. mt = β[1] * mt + (1 - β[1]) * dx
@.. vt = β[2] * vt + (1 - β[2]) * dx^2
@.. vt = β[2] * vt + (1 - β[2]) * abs2(dx)
dx′ = @lazy (β[1] * mt / (1 - β[1] * βt[1]) + (1 - β[1]) * dx / (1 - βt[1])) /
(sqrt(vt * β[2] / (1 - βt[2])) + ϵ) * η

@@ -433,7 +433,7 @@ function apply!(o::AdaBelief, state, x, dx)
mt, st = state

@.. mt = β[1] * mt + (1 - β[1]) * dx
@.. st = β[2] * st + (1 - β[2]) * (dx - mt)^2
@.. st = β[2] * st + (1 - β[2]) * abs2(dx - mt)
dx′ = @lazy η * mt / (sqrt(st) + ϵ)

return (mt, st), dx′
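
The common thread in these src/rules.jl changes: for a complex gradient dx, dx^2 is itself complex, so the second-moment accumulators (acc, vt, st, Δacc) would stop being real and non-negative, and the per-parameter scale sqrt(acc) would no longer be a real step size. abs2(dx) is |dx|^2, which keeps the accumulators real; for real dx the two expressions agree, so real-valued optimisation is unchanged. A quick REPL check (illustrative only, not part of the diff):

julia> dx = 3.0 + 4.0im;

julia> dx^2              # complex, not usable as a variance-style accumulator
-7.0 + 24.0im

julia> abs2(dx)          # |dx|^2, real and non-negative
25.0

julia> sqrt(abs2(dx))    # so sqrt(acc) in the updates stays real
5.0
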
76 changes: 70 additions & 6 deletions test/rules.jl
@@ -19,7 +19,16 @@ RULES = [
name(o) = typeof(o).name.name
name(o::OptimiserChain) = join(name.(o.opts), " → ")

LOG = Dict()

loggradient(o) = (f, xs...) -> begin
y, dxs = Zygote.withgradient(f, xs...)
push!(get!(() -> Float32[], LOG, name(o)), y)
dxs # save the loss, return the gradient
end
Member:
This (LOG) doesn't seem to be in use anywhere, is it still necessary?

Member Author:
It just makes debugging easier if you can plot things from the tests you just ran. It's not strictly necessary but also doesn't really get in the way, I think.

Member:
Fair enough. A comment with what you just described would help then.
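
For example, after a test run the recorded losses can be inspected along these lines (a sketch only; the UnicodePlots dependency and the loop are illustrative, not part of the test suite):

using UnicodePlots

for (rule, losses) in LOG
    # one loss curve per optimisation rule, as recorded by loggradient above
    display(lineplot(losses; title = string(rule), xlabel = "step", ylabel = "loss"))
end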


@testset "independence" begin
empty!(LOG)
@testset "$(name(o))" for o in RULES
w = randn(10, 10)
w′ = randn(10, 10)
@@ -28,22 +37,23 @@ name(o::OptimiserChain) = join(name.(o.opts), " → ")
st = Optimisers.setup(o, w)
for t = 1:10^5
x = rand(10)
gs = gradient(w -> iloss(x, w, w′), w)
gs = loggradient(o)(w -> iloss(x, w, w′), w)
st, w = Optimisers.update!(st, w, gs...)
end
@test iloss(rand(10, 10), w, w′) < 0.01
end
end

@testset verbose=true "simple sum" begin
empty!(LOG)
@testset "$(name(o))" for o in RULES
m = shuffle!(reshape(1:64, 8, 8) .+ 0.0)
s = Optimisers.setup(o, m)
for _ in 1:10^5
g = gradient(x -> sum(abs2, x + x'), m)[1]
g = loggradient(o)(x -> sum(abs2, x + x'), m)[1]
s, m = Optimisers.update!(s, m, g)
end
# @test sum(m) < sum(1:64)
@test sum(m) < sum(1:64)
if sum(m) < 1
@test sum(m) < 1
else
@@ -54,21 +64,23 @@ end
end

@testset "original" begin
empty!(LOG)
@testset "$(name(o))" for o in RULES
w′ = (α = rand(3, 3), β = rand(3, 3))
w = (α = 5rand(3, 3), β = rand(3, 3))
st = Optimisers.setup(o, w)
loss(x, y) = mean((x.α .* x.β .- y.α .* y.β) .^ 2)
@test loss(w, w′) > 1
for i = 1:10^4
gs = gradient(x -> loss(x, w′), w)
gs = loggradient(o)(x -> loss(x, w′), w)
st, w = Optimisers.update(st, w, gs...)
end
@test loss(w, w′) < 0.001
end
end

@testset verbose=true "StaticArrays" begin
empty!(LOG)
@testset "$(name(o))" for o in RULES
W1 = @SMatrix randn(10, 10)
b1 = @SVector randn(10)
@@ -82,7 +94,7 @@ end
@test s_loss(model, x, y) > 10
state = Optimisers.setup(o, model)
for t = 1:10^3
g = gradient(m -> s_loss(m, x, y), model)[1]
g = loggradient(o)(m -> s_loss(m, x, y), model)[1]
state, model = Optimisers.update!(state, model, g)
end
if o isa Descent
@@ -94,7 +106,7 @@ end
end
end

@testset verbose=true "element types" begin
@testset "element types" begin
@testset "$(name(o))" for o in RULES
marray = (Float32[1,2], Float64[3,4], Float16[5,6])
types = map(eltype, marray)
@@ -166,3 +178,55 @@ end
end
end

@testset "with complex numbers: Flux#1776" begin
empty!(LOG)
@testset "$(name(opt))" for opt in [
# The Flux PR had 1e-2 for all. But ADADelta(ρ) needs ρ≈0.9 not small. And it helps to make ε not too small too:
ADAM(1e-2), RMSProp(1e-2), RADAM(1e-2), OADAM(1e-2), ADAGrad(1e-2), ADADelta(0.9, 1e-5), NADAM(1e-2), AdaBelief(1e-2),
Member Author (on lines +184 to +185):
This is the fix, BTW.

# These weren't in Flux PR:
Descent(1e-2), Momentum(1e-2), Nesterov(1e-2), ADAMW(1e-2),
]
# Our "model" is just a complex number
model = (w = zeros(ComplexF64, 1),)

# Our model attempts to learn `f(x) = conj(x)` where `f(x) = w*x`
function loss(m)
# Deterministic training data is the best training data
x = ones(1, 1) + 1im*ones(1, 1)
# Manually implement `mse()` to allow demonstration of brokenness
# on older Flux builds that don't have a fixed `mse()`
return sum(abs2.(m.w * x .- conj(x)))
end
@test loss(model) ≈ 2.0

state = Optimisers.setup(opt, model)

# Train for 10 iterations, enforcing that loss is monotonically decreasing
last_loss = Inf
for idx in 1:10
grads = loggradient(opt)(loss, model)
state, model = Optimisers.update!(state, model, grads...)
opt isa Union{Momentum, Nesterov} && idx > 8 && continue # these are very flat at the end
@test loss(model) < last_loss
last_loss = loss(model)
end
@test loss(model) < 1.9

# Repeat with StaticArrays
static_model = (w = SA[0.0 + 0im],)
static_state = Optimisers.setup(opt, static_model)
function static_loss(m)
x = hcat(SA[1.0 + im])
sum(abs2.(m.w * x .- conj(x)))
end
last_loss = Inf
for idx in 1:10
grads = gradient(static_loss, static_model)
static_state, static_model = Optimisers.update!(static_state, static_model, grads...)
opt isa Union{Momentum, Nesterov} && idx > 8 && continue
@test static_loss(static_model) < last_loss
last_loss = static_loss(static_model)
end
@test static_loss(static_model) < 1.9
end
end
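
A sanity check on the arithmetic in this new test (not part of the diff): with the zero-initialised w and x = 1 + im, the loss is abs2(0 - conj(x)) = 1^2 + 1^2 = 2, which is exactly the first assertion, and the exact minimiser of abs2(w*x - conj(x)) is w = conj(x)/x = -im, so any rule that handles complex state correctly should drive w toward -im:

x = 1.0 + 1.0im
w_star = conj(x) / x           # (1 - im)/(1 + im) == -im, the exact minimiser
abs2(w_star * x - conj(x))     # 0.0: the loss vanishes at w_star
abs2(0.0 * x - conj(x))        # 2.0: the loss at the zero-initialised model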