complex numbers alla Flux 1776
mcabbott committed Jan 30, 2022
1 parent 39cca65 commit 8e3a8a9
Showing 2 changed files with 72 additions and 11 deletions.
18 changes: 9 additions & 9 deletions src/rules.jl
@@ -101,7 +101,7 @@ init(o::RMSProp, x::AbstractArray) = zero(x)

function apply!(o::RMSProp, state, x, dx)
η, ρ, ϵ, acc = o.eta, o.rho, o.epsilon, state
- acc′ = @.. acc = ρ * acc + (1 - ρ) * dx^2
+ acc′ = @.. acc = ρ * acc + (1 - ρ) * abs2(dx)
dx′ = @.. dx * (η / (sqrt(acc) + ϵ))

return acc′, dx′
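
Aside (not part of the diff): the reason dx^2 becomes abs2(dx) throughout is that squaring a complex gradient gives another complex number, while abs2(z) = z * conj(z) is real and nonnegative, so the accumulator fed into sqrt stays real. A minimal sketch:

z = 1.0 + 2.0im
z^2                      # -3.0 + 4.0im -- complex, so sqrt(acc) + ϵ would go complex too
abs2(z)                  # 5.0 -- real and nonnegative
abs2(z) == z * conj(z)   # true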
@@ -134,7 +134,7 @@ function apply!(o::ADAM{T}, state, x, dx) where T
mt, vt, βt = state

mt′ = @.. mt = β[1] * mt + (one(T) - β[1]) * dx
- vt′ = @.. vt = β[2] * vt + (one(T) - β[2]) * dx ^ 2
+ vt′ = @.. vt = β[2] * vt + (one(T) - β[2]) * abs2(dx)
dx′ = @.. mt / (one(T) - βt[1]) / (sqrt(vt / (one(T) - βt[2])) + ϵ) * η

return (mt′, vt′, βt .* β), dx′
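
Aside (not part of the diff): the third state entry is returned as βt .* β, so, assuming it starts out equal to β, it holds β .^ t at step t, which is what the (one(T) - βt[1]) and (one(T) - βt[2]) bias corrections rely on. A sketch of that bookkeeping:

β  = (0.9, 0.999)
βt = β               # assumed initial state, used at step 1
βt = βt .* β         # state returned after step 1, used at step 2
βt == β .^ 2         # true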
@@ -169,7 +169,7 @@ function apply!(o::RADAM, state, x, dx)
mt, vt, βt, t = state

mt′ = @.. mt = β[1] * mt + (1 - β[1]) * dx
- vt′ = @.. vt = β[2] * vt + (1 - β[2]) * dx^2
+ vt′ = @.. vt = β[2] * vt + (1 - β[2]) * abs2(dx)
ρ = ρ∞ - 2*t * βt[2] / (1 - βt[2])
if ρ > 4
r = sqrt((ρ - 4) * (ρ - 2) * ρ∞/((ρ∞ - 4) * (ρ∞ - 2) * ρ))
@@ -244,7 +244,7 @@ function apply!(o::OADAM, state, x, dx)
mt, vt, βt, dx_ = state

mt′ = @.. mt = β[1] * mt + (1 - β[1]) * dx
- vt′ = @.. vt = β[2] * vt + (1 - β[2]) * dx^2
+ vt′ = @.. vt = β[2] * vt + (1 - β[2]) * abs2(dx)
dx = @.. -dx_
dx_′ = @.. dx_ = η * mt / (1 - βt[1]) / (sqrt(vt / (1 - βt[2])) + ϵ)
dx′ = @.. dx + 2*dx_
@@ -277,7 +277,7 @@ function apply!(o::ADAGrad, state, x, dx)
η, ϵ = o.eta, o.epsilon
acc = state

- acc′ = @.. acc = acc + dx^2
+ acc′ = @.. acc = acc + abs2(dx)
dx′ = @.. dx * η / (sqrt(acc) + ϵ)

return acc′, dx′
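
Aside (not part of the diff): with illustrative numbers, ADAGrad's accumulated abs2(dx) shrinks the effective step even when the gradient itself stays the same:

η, ϵ = 0.1, 1e-8
acc1 = abs2(2.0)                 # 4.0
acc3 = acc1 + abs2(2.0) + abs2(2.0)   # 12.0 after three identical gradients
2.0 * η / (sqrt(acc1) + ϵ)       # ≈ 0.100 first step
2.0 * η / (sqrt(acc3) + ϵ)       # ≈ 0.058 -- same gradient, smaller step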
@@ -307,11 +307,11 @@ function apply!(o::ADADelta, state, x, dx)
ρ, ϵ = o.rho, o.epsilon
acc, Δacc = state

- acc′ = @.. acc = ρ * acc + (1 - ρ) * dx^2
+ acc′ = @.. acc = ρ * acc + (1 - ρ) * abs2(dx)
# DON'T remove epsilon from numerator
# or even out of the square roots
dx′ = @.. dx * sqrt(Δacc + ϵ) / sqrt(acc + ϵ)
- Δacc′ = @.. Δacc = ρ * Δacc + (1 - ρ) * dx^2
+ Δacc′ = @.. Δacc = ρ * Δacc + (1 - ρ) * abs2(dx)

return (acc′, Δacc′), dx′
end
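
Aside (not part of the diff): the warning above can be checked with one step of arithmetic. Assuming both accumulators start at zero, dropping ϵ from the numerator's square root makes the very first update exactly zero, and since Δacc then never grows, every later update is zero too:

ρ, ϵ = 0.9, 1e-8
acc, Δacc, dx = 0.0, 0.0, 1.0
acc = ρ * acc + (1 - ρ) * abs2(dx)     # ≈ 0.1
dx * sqrt(Δacc) / sqrt(acc + ϵ)        # 0.0       -- without ϵ in the numerator: stuck forever
dx * sqrt(Δacc + ϵ) / sqrt(acc + ϵ)    # ≈ 3.2e-4  -- with it: small but nonzero, training can start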
@@ -382,7 +382,7 @@ function apply!(o::NADAM, state, x, dx)
mt, vt, βt = state

mt′ = @.. mt = β[1] * mt + (1 - β[1]) * dx
- vt′ = @.. vt = β[2] * vt + (1 - β[2]) * dx^2
+ vt′ = @.. vt = β[2] * vt + (1 - β[2]) * abs2(dx)
dx′ = @.. (β[1] * mt / (1 - β[1] * βt[1]) + (1 - β[1]) * dx / (1 - βt[1])) /
(sqrt(vt * β[2] / (1 - βt[2])) + ϵ) * η

@@ -435,7 +435,7 @@ function apply!(o::AdaBelief, state, x, dx)
mt, st = state

mt′ = @.. mt = β[1] * mt + (1 - β[1]) * dx
- st′ = @.. st = β[2] * st + (1 - β[2]) * (dx - mt)^2
+ st′ = @.. st = β[2] * st + (1 - β[2]) * abs2(dx - mt)
dx′ = @.. η * mt / (sqrt(st) + ϵ)

return (mt′, st′), dx′
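
Aside (not part of the diff): unlike ADAM, AdaBelief centres its second moment on the running mean mt, and abs2 keeps that deviation real when the gradient is complex. Illustrative values only:

mt = 1.0
abs2(1.0 - mt)             # 0.0 -- gradient equals its running mean: small st, confident step
abs2((1.0 + 2.0im) - mt)   # 4.0 -- complex deviation, still real and nonnegative thanks to abs2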
65 changes: 63 additions & 2 deletions test/rules.jl
@@ -63,10 +63,14 @@ end
end

#=
- plot(LOG[:ADAGrad]) # decline
+ using Plots
+ _plot(s; kw...) = (plot(); _plot!(s; kw...))
+ _plot!(s; kw...) = plot!(LOG[s]; label=string(s), yguide="loss", xguide="iter", kw...)
+ _plot(:ADAGrad) # decline
LOG[:ADAGrad][end] # 3869.4075f0
- plot(LOG[:AMSGrad]) # decline
+ _plot!(:AMSGrad) # decline
LOG[:AMSGrad][end] # 2742.004f0
findfirst(isnan, LOG[:ADADelta]) # 182
@@ -185,3 +189,60 @@ end
end
end

@testset "with complex numbers: Flux#1776" begin
empty!(LOG)
@testset "$(name(f(1e-2)))" for f in [
ADAM, RMSProp, RADAM, OADAM, ADAGrad, ADADelta, NADAM, AdaBelief,
Descent, Momentum, Nesterov, ADAMW, # not in Flux PR
]
# Our "model" is just a complex number
model = (w = zeros(ComplexF64, 1),)

# Our model attempts to learn `f(x) = conj(x)` where `f(x) = w*x`
function loss(m)
# Deterministic training data is the best training data
x = ones(1, 1) + 1im*ones(1, 1)
# Manually implement `mse()` to allow demonstration of brokenness
# on older Flux builds that don't have a fixed `mse()`
return sum(abs2.(m.w * x .- conj(x)))
end

opt = f(1e-2)
state = Optimisers.setup(opt, model)

# Train for 10 iterations, enforcing that loss is monotonically decreasing
last_loss = Inf
for idx in 1:10
grads = loggradient(opt)(loss, model)
state, model = Optimisers.update!(state, model, grads...)
@test loss(model) < last_loss
last_loss = loss(model)
end

# Repeat with StaticArrays
static_model = (w = SA[1.0 + 0im],)
static_state = Optimisers.setup(opt, static_model)
function static_loss(m)
x = hcat(SA[1.0 + im])
sum(abs2.(m.w * x .- conj(x)))
end
last_loss = Inf
for idx in 1:10
grads = gradient(static_loss, static_model)
static_state, static_model = Optimisers.update!(static_state, static_model, grads...)
@test loss(static_model) < last_loss
last_loss = loss(static_model)
end
end
end
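
Aside (not part of the test file): for the fixed input x = 1 + 1im, the weight that makes w * x equal conj(x) works out to w = -im, so that is the value a correctly behaving optimiser should drive model.w towards. Quick check:

x = 1.0 + 1.0im
w = conj(x) / x                 # 0.0 - 1.0im
abs2(w * x - conj(x)) ≈ 0.0     # true -- zero loss at the optimum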

#=
_plot(:ADAM)
_plot!(:RADAM)
_plot!(:OADAM) # stays at 2
_plot(:RMSProp)
_plot(:ADADelta, yaxis=:log10) # exp growth
=#
