Merge pull request #224 from tejank10/nadam-opt
NADAM optimizer
MikeInnes authored Jun 8, 2018
2 parents 4915b0c + 4a24b69 commit 9345607
Showing 5 changed files with 29 additions and 6 deletions.
2 changes: 1 addition & 1 deletion src/Flux.jl
@@ -23,7 +23,7 @@ include("optimise/Optimise.jl")
using .Optimise
using .Optimise: @epochs
export SGD, ADAM, AdaMax, Momentum, Nesterov,
RMSProp, ADAGrad, ADADelta, AMSGrad
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM

include("utils.jl")
include("onehot.jl")
3 changes: 2 additions & 1 deletion src/optimise/Optimise.jl
@@ -1,7 +1,8 @@
module Optimise

export train!,
SGD, ADAM, AdaMax, Momentum, Nesterov, RMSProp, ADAGrad, ADADelta, AMSGrad
SGD, ADAM, AdaMax, Momentum, Nesterov,
RMSProp, ADAGrad, ADADelta, AMSGrad, NADAM

struct Param{T}
x::T
9 changes: 9 additions & 0 deletions src/optimise/interface.jl
@@ -91,3 +91,12 @@ tuning.
"""
AMSGrad(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p -> amsgrad(p; η = η, β1 = β1, β2 = β2, ϵ = ϵ), p -> invdecay(p, decay), p -> descent(p, 1))

"""
NADAM(params, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0)
[NADAM](https://openreview.net/pdf?id=OM0jvwB8jIp57ZJjtNEZ) optimiser. Parameters other
than learning rate don't need tuning.
"""
NADAM(ps, η = 0.001; β1 = 0.9, β2 = 0.999, ϵ = 1e-08, decay = 0) =
optimiser(ps, p->nadam(p; η=η, β1=β1, β2=β2, ϵ=ϵ), p->invdecay(p,decay), p->descent(p,1))
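
To make the new interface entry concrete, here is a minimal usage sketch. It is not part of the commit: it assumes the params/train! API Flux exposed at the time, and the Dense model, mse loss, and random batch below are purely illustrative.

using Flux

m = Dense(10, 2)                      # toy model, illustrative only
loss(x, y) = Flux.mse(m(x), y)        # mean-squared-error objective
data = [(rand(10, 8), rand(2, 8))]    # one random batch of 8 samples

# Build the optimiser over the model's parameters; β1, β2, ϵ and decay
# keep the defaults listed in the docstring above.
opt = NADAM(params(m), 0.001)

Flux.train!(loss, data, opt)          # one pass over `data` using NADAM

As with the other constructors in this file, NADAM composes the low-level nadam closure with invdecay and descent, so the decay keyword behaves the same way it does for ADAM or AMSGrad.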
19 changes: 16 additions & 3 deletions src/optimise/optimisers.jl
@@ -27,15 +27,15 @@ function rmsprop(p::Param; η::Real = 0.001, ρ::Real = 0.9, ϵ::Real = 1e-8)
acc = zeros(p.x)
function ()
@. acc = ρ * acc + (1 - ρ) * p.Δ^2
@. p.Δ *= η / (√acc + ϵ)
@. p.Δ *= η / √(acc + ϵ)
end
end

function adagrad(p::Param; η::Real = 0.01, ϵ::Real = 1e-8)
acc = zeros(p.x) .+ ϵ
function ()
@. acc += p.Δ^2
@. p.Δ *= η / √acc
@. p.Δ *= η / √(acc + ϵ)
end
end

@@ -56,7 +56,7 @@ function adam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ
function ()
@. mt = β1 * mt + (1 - β1) * p.Δ
@. vt = β2 * vt + (1 - β2) * p.Δ^2
@. p.Δ = mt / (1 - β1p) / (√(vt / (1 - β2p)) + ϵ) * η
@. p.Δ = mt / (1 - β1p) / √(vt / (1 - β2p) + ϵ) * η
β1p *= β1
β2p *= β2
end
@@ -86,6 +86,19 @@ function amsgrad(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999,
end
end

function nadam(p::Param; η::Real = 0.001, β1::Real = 0.9, β2::Real = 0.999, ϵ::Real = 1e-8)
mt = zeros(p.x)
vt = zeros(p.x)
β1p, β2p = β1, β2
function ()
@. mt = β1 * mt + (1 - β1) * p.Δ
@. vt = β2 * vt + (1 - β2) * p.Δ^2
@. p.Δ = (β1 * mt / (1 - β1 * β1p) + (1 - β1) * p.Δ / (1 - β1p)) / √(vt * β2 / (1 - β2p) + ϵ) * η
β1p *= β1
β2p *= β2
end
end

clip(p::Param, thresh::Real) = () -> clamp!(p.Δ, -thresh, thresh)

function expdecay(p::Param, γ::Real)
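For readers checking the nadam closure added above against the NADAM paper linked in the docstring, the update it performs can be written out as follows. This is an editorial transcription of the code, not text from the commit; g_t denotes the current gradient p.Δ, and β1p = β1^t, β2p = β2^t at step t.

\[
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1 - \beta_1)\, g_t, \qquad
v_t = \beta_2 v_{t-1} + (1 - \beta_2)\, g_t^2, \\
\Delta_t &= \eta \cdot
  \frac{\beta_1 m_t / (1 - \beta_1^{t+1}) + (1 - \beta_1)\, g_t / (1 - \beta_1^{t})}
       {\sqrt{\beta_2 v_t / (1 - \beta_2^{t}) + \epsilon}}.
\end{aligned}
\]

The numerator blends the bias-corrected first moment with the bias-corrected current gradient (the Nesterov look-ahead); p.Δ is overwritten with Δ_t so that the trailing descent(p, 1) in the NADAM constructor can apply it to the parameters.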
2 changes: 1 addition & 1 deletion test/optimise.jl
@@ -3,7 +3,7 @@ using Flux.Tracker

@testset "Optimise" begin
w = randn(10, 10)
@testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad]
@testset for Opt in [SGD, Nesterov, Momentum, ADAM, AdaMax, RMSProp, ps -> ADAGrad(ps, 0.1), ADADelta, AMSGrad, NADAM]
w′ = param(randn(10, 10))
loss(x) = Flux.mse(w*x, w′*x)
opt = Opt([w′])
