Merge pull request #1819 from cossio/eps
make eps a parameter of optimisers
CarloLucibello authored Dec 29, 2021
2 parents 2399588 + 43279cc commit 7f375aa
Showing 1 changed file with 32 additions and 22 deletions.
54 changes: 32 additions & 22 deletions src/optimise/optimisers.jl
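
Context for the change: each optimiser now carries its own epsilon field instead of the apply! methods reading a shared constant, and the constructors accept it as a trailing argument. A minimal usage sketch, assuming this branch of Flux; the specific values are illustrative:

```julia
using Flux

opt1 = ADAM()                             # epsilon keeps its former default
opt2 = ADAM(0.001, (0.9, 0.999), 1e-4)    # larger ϵ for extra numerical damping
opt3 = RMSProp(0.002, 0.95, 1e-7)

opt2.epsilon   # 1.0e-4 — now a per-optimiser field rather than a shared global
```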
@@ -133,16 +133,17 @@ opt = RMSProp(0.002, 0.95)
mutable struct RMSProp <: AbstractOptimiser
eta::Float64
rho::Float64
epsilon::Float64
acc::IdDict
end

RMSProp(η = 0.001, ρ = 0.9) = RMSProp(η, ρ, IdDict())
RMSProp(η = 0.001, ρ = 0.9, ϵ = ϵ) = RMSProp(η, ρ, ϵ, IdDict())

function apply!(o::RMSProp, x, Δ)
η, ρ = o.eta, o.rho
acc = get!(() -> zero(x), o.acc, x)::typeof(x)
@. acc = ρ * acc + (1 - ρ) * Δ * conj(Δ)
@. Δ *= η / (√acc + ϵ)
@. Δ *= η / (√acc + o.epsilon)
end
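
As a reading aid, here is the update that apply! performs above, restated through a hypothetical rmsprop_step! helper on plain arrays; this is an illustrative sketch, not part of the diff or the Flux API:

```julia
# Hypothetical helper (not Flux API): one RMSProp step on plain arrays.
function rmsprop_step!(Δ, acc; η = 0.001, ρ = 0.9, ϵ = 1e-8)
    @. acc = ρ * acc + (1 - ρ) * abs2(Δ)   # running mean of squared gradients
    @. Δ *= η / (√acc + ϵ)                 # ϵ guards against division by zero
    return Δ
end
```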

"""
@@ -166,10 +167,11 @@ opt = ADAM(0.001, (0.9, 0.8))
mutable struct ADAM <: AbstractOptimiser
eta::Float64
beta::Tuple{Float64,Float64}
epsilon::Float64
state::IdDict
end

ADAM(η = 0.001, β = (0.9, 0.999)) = ADAM(η, β, IdDict())
ADAM(η = 0.001, β = (0.9, 0.999), ϵ = ϵ) = ADAM(η, β, ϵ, IdDict())

function apply!(o::ADAM, x, Δ)
η, β = o.eta, o.beta
@@ -180,7 +182,7 @@ function apply!(o::ADAM, x, Δ)

@. mt = β[1] * mt + (1 - β[1]) * Δ
@. vt = β[2] * vt + (1 - β[2]) * Δ * conj(Δ)
@. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η
@. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + o.epsilon) * η
βp .= βp .* β

return Δ
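
The βp entries hold β₁ᵗ and β₂ᵗ, so dividing by (1 - βp[i]) is the usual bias correction of the running moments. A rough standalone sketch of one such step, using a hypothetical adam_step! helper rather than the Flux API:

```julia
# Hypothetical helper (not Flux API): one bias-corrected ADAM step at step count t.
function adam_step!(Δ, mt, vt, t; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-8)
    @. mt = β1 * mt + (1 - β1) * Δ             # first-moment estimate
    @. vt = β2 * vt + (1 - β2) * abs2(Δ)       # second-moment estimate
    c1, c2 = 1 - β1^t, 1 - β2^t                # bias corrections
    @. Δ = η * (mt / c1) / (√(vt / c2) + ϵ)
    return Δ
end
```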
@@ -207,10 +209,11 @@ opt = RADAM(0.001, (0.9, 0.8))
mutable struct RADAM <: AbstractOptimiser
eta::Float64
beta::Tuple{Float64,Float64}
epsilon::Float64
state::IdDict
end

RADAM(η = 0.001, β = (0.9, 0.999)) = RADAM(η, β, IdDict())
RADAM(η = 0.001, β = (0.9, 0.999), ϵ = ϵ) = RADAM(η, β, ϵ, IdDict())

function apply!(o::RADAM, x, Δ)
η, β = o.eta, o.beta
@@ -225,7 +228,7 @@ function apply!(o::RADAM, x, Δ)
ρ = ρ∞ - 2t[] * βp[2] / (1 - βp[2])
if ρ > 4
r = sqrt((ρ-4)*(ρ-2)*ρ∞/((ρ∞-4)*(ρ∞-2)*ρ))
@. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ) * η * r
@. Δ = mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + o.epsilon) * η * r
else
@. Δ = mt / (1 - βp[1]) * η
end
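
For reference, the rectification factor computed above corresponds to the following quantities from the RAdam paper (ρ∞ is computed in the collapsed context of this apply! method):

```latex
\rho_\infty = \frac{2}{1-\beta_2} - 1,\qquad
\rho_t = \rho_\infty - \frac{2\,t\,\beta_2^{\,t}}{1-\beta_2^{\,t}},\qquad
r_t = \sqrt{\frac{(\rho_t - 4)(\rho_t - 2)\,\rho_\infty}{(\rho_\infty - 4)(\rho_\infty - 2)\,\rho_t}}
\quad\text{applied only when } \rho_t > 4.
```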
@@ -256,10 +259,11 @@ opt = AdaMax(0.001, (0.9, 0.995))
mutable struct AdaMax <: AbstractOptimiser
eta::Float64
beta::Tuple{Float64,Float64}
epsilon::Float64
state::IdDict
end

AdaMax(η = 0.001, β = (0.9, 0.999)) = AdaMax(η, β, IdDict())
AdaMax(η = 0.001, β = (0.9, 0.999), ϵ = ϵ) = AdaMax(η, β, ϵ, IdDict())

function apply!(o::AdaMax, x, Δ)
η, β = o.eta, o.beta
@@ -270,7 +274,7 @@ function apply!(o::AdaMax, x, Δ)

@. mt = β[1] * mt + (1 - β[1]) * Δ
@. ut = max(β[2] * ut, abs(Δ))
@. Δ = (η/(1 - βp[1])) * mt/(ut + ϵ)
@. Δ = (η/(1 - βp[1])) * mt/(ut + o.epsilon)
βp .= βp .* β

return Δ
@@ -298,10 +302,11 @@ opt = OADAM(0.001, (0.9, 0.995))
mutable struct OADAM <: AbstractOptimiser
eta::Float64
beta::Tuple{Float64,Float64}
epsilon::Float64
state::IdDict
end

OADAM(η = 0.001, β = (0.5, 0.9)) = OADAM(η, β, IdDict())
OADAM(η = 0.001, β = (0.5, 0.9), ϵ = ϵ) = OADAM(η, β, ϵ, IdDict())

function apply!(o::OADAM, x, Δ)
η, β = o.eta, o.beta
@@ -313,7 +318,7 @@ function apply!(o::OADAM, x, Δ)
@. mt = β[1] * mt + (1 - β[1]) * Δ
@. vt = β[2] * vt + (1 - β[2]) * Δ * conj(Δ)
@. Δ = -Δ_
@. Δ_ = η * mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ)
@. Δ_ = η * mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + o.epsilon)
@. Δ += 2Δ_
βp .= βp .* β
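
The three Δ lines above implement the "optimistic" part: Δ_ holds the step from the previous call, and the update actually applied is the extrapolation

```latex
\Delta = 2\,\delta_t - \delta_{t-1},\qquad
\delta_t = \frac{\eta\, m_t/(1-\beta_1^{\,t})}{\sqrt{v_t/(1-\beta_2^{\,t})} + \epsilon},
```

with δ_t then saved back into Δ_ for the next call.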

@@ -340,16 +345,17 @@ opt = ADAGrad(0.001)
"""
mutable struct ADAGrad <: AbstractOptimiser
eta::Float64
epsilon::Float64
acc::IdDict
end

ADAGrad(η = 0.1) = ADAGrad(η, IdDict())
ADAGrad(η = 0.1, ϵ = ϵ) = ADAGrad(η, ϵ, IdDict())

function apply!(o::ADAGrad, x, Δ)
η = o.eta
acc = get!(() -> fill!(similar(x), ϵ), o.acc, x)::typeof(x)
acc = get!(() -> fill!(similar(x), o.epsilon), o.acc, x)::typeof(x)
@. acc += Δ * conj(Δ)
@. Δ *= η / (√acc + ϵ)
@. Δ *= η / (√acc + o.epsilon)
end

"""
@@ -371,18 +377,19 @@ opt = ADADelta(0.89)
"""
mutable struct ADADelta <: AbstractOptimiser
rho::Float64
epsilon::Float64
state::IdDict
end

ADADelta(ρ = 0.9) = ADADelta(ρ, IdDict())
ADADelta(ρ = 0.9, ϵ = ϵ) = ADADelta(ρ, ϵ, IdDict())

function apply!(o::ADADelta, x, Δ)
ρ = o.rho
acc, Δacc = get!(() -> (zero(x), zero(x)), o.state, x)::NTuple{2,typeof(x)}
@. acc = ρ * acc + (1 - ρ) * Δ * conj(Δ)
# DON'T remove epsilon from numerator
# or even out of the square roots
@. Δ *= √(Δacc + ϵ) / √(acc + ϵ)
@. Δ *= √(Δacc + o.epsilon) / √(acc + o.epsilon)
@. Δacc = ρ * Δacc + (1 - ρ) * Δ * conj(Δ)
return Δ
end
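
The warning in that comment is easy to verify: Δacc is all zeros on the first call, so without ϵ in the numerator (and inside the square roots) the very first update would be exactly zero. A small illustrative check with made-up numbers, not part of the diff:

```julia
# First ADADelta update for one coordinate with gradient g = 0.1 and ρ = 0.9.
ϵ, ρ, g = 1e-8, 0.9, 0.1
acc  = (1 - ρ) * g^2                  # accumulator after the first gradient
Δacc = 0.0                            # no updates accumulated yet
scale = √(Δacc + ϵ) / √(acc + ϵ)      # ≈ 3.2e-3: small but nonzero first step
```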
@@ -409,22 +416,23 @@ opt = AMSGrad(0.001, (0.89, 0.995))
mutable struct AMSGrad <: AbstractOptimiser
eta::Float64
beta::Tuple{Float64, Float64}
epsilon::Float64
state::IdDict
end

AMSGrad(η = 0.001, β = (0.9, 0.999)) = AMSGrad(η, β, IdDict())
AMSGrad(η = 0.001, β = (0.9, 0.999), ϵ = ϵ) = AMSGrad(η, β, ϵ, IdDict())

function apply!(o::AMSGrad, x, Δ)
η, β = o.eta, o.beta

mt, vt, v̂t = get!(o.state, x) do
(fill!(similar(x), ϵ), fill!(similar(x), ϵ), fill!(similar(x), ϵ))
(fill!(similar(x), o.epsilon), fill!(similar(x), o.epsilon), fill!(similar(x), o.epsilon))
end :: NTuple{3,typeof(x)}

@. mt = β[1] * mt + (1 - β[1]) * Δ
@. vt = β[2] * vt + (1 - β[2]) * Δ ^ 2
@. v̂t = max(v̂t, vt)
@. Δ = η * mt / (√v̂t + ϵ)
@. Δ = η * mt / (√v̂t + o.epsilon)
end

"""
@@ -449,10 +457,11 @@ opt = NADAM(0.002, (0.89, 0.995))
mutable struct NADAM <: AbstractOptimiser
eta::Float64
beta::Tuple{Float64, Float64}
epsilon::Float64
state::IdDict
end

NADAM(η = 0.001, β = (0.9, 0.999)) = NADAM(η, β, IdDict())
NADAM(η = 0.001, β = (0.9, 0.999), ϵ = ϵ) = NADAM(η, β, ϵ, IdDict())

function apply!(o::NADAM, x, Δ)
η, β = o.eta, o.beta
@@ -464,7 +473,7 @@ function apply!(o::NADAM, x, Δ)

@. mt = β[1] * mt + (1 - β[1]) * Δ
@. vt = β[2] * vt + (1 - β[2]) * Δ * conj(Δ)
@. Δ = (β[1] * mt / (1 - β[1] * β1p) + (1 - β[1]) * Δ / (1 - β1p)) / (√(vt * β[2] / (1 - β2p)) + ϵ) * η
@. Δ = (β[1] * mt / (1 - β[1] * β1p) + (1 - β[1]) * Δ / (1 - β1p)) / (√(vt * β[2] / (1 - β2p)) + o.epsilon) * η
βp .= βp .* β

return Δ
@@ -515,17 +524,18 @@ opt = AdaBelief(0.001, (0.9, 0.8))
mutable struct AdaBelief
eta::Float64
beta::Tuple{Float64,Float64}
epsilon::Float64
state::IdDict
end

AdaBelief(η = 0.001, β = (0.9, 0.999)) = AdaBelief(η, β, IdDict())
AdaBelief(η = 0.001, β = (0.9, 0.999), ϵ = ϵ) = AdaBelief(η, β, ϵ, IdDict())

function apply!(o::AdaBelief, x, Δ)
η, β = o.eta, o.beta
mt, st = get!(() -> (zero(x), zero(x)), o.state, x)::NTuple{2,typeof(x)}
@. mt = β[1] * mt + (1 - β[1]) * Δ
@. st = β[2] * st + (1 - β[2]) * (Δ - mt) * conj(Δ - mt)
@. Δ = η * mt / (√(st) + ϵ)
@. Δ = η * mt / (√(st) + o.epsilon)
return Δ
end
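
AdaBelief differs from ADAM only in the second-moment term: st tracks the squared deviation of the gradient from its running mean rather than the squared gradient itself. A standalone sketch of the step, using a hypothetical adabelief_step! helper rather than the Flux API:

```julia
# Hypothetical helper (not Flux API): one AdaBelief step on plain arrays.
function adabelief_step!(Δ, mt, st; η = 0.001, β1 = 0.9, β2 = 0.999, ϵ = 1e-8)
    @. mt = β1 * mt + (1 - β1) * Δ
    @. st = β2 * st + (1 - β2) * abs2(Δ - mt)   # "belief" in the current gradient
    @. Δ = η * mt / (√st + ϵ)
    return Δ
end
```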

