Optimistic ADAM #1246

Merged · 4 commits · Jul 1, 2020
1 change: 1 addition & 0 deletions NEWS.md
@@ -4,6 +4,7 @@
* Use `DataLoader` with `NamedTuple`s, so that tensors can be accessed by name [https://github.com/FluxML/Flux.jl/pull/1221].
* Error if Dense layers weights and biases are not arrays [https://github.com/FluxML/Flux.jl/pull/1218].
* Add `Adaptive Pooling` in Flux layers [https://github.com/FluxML/Flux.jl/pull/1239].
* Optimistic ADAM (OADAM) optimizer for adversarial training [https://github.com/FluxML/Flux.jl/pull/1246].

# v0.10.5
* Add option for [same padding](https://github.com/FluxML/Flux.jl/pull/901) to conv and pooling layers by setting `pad=SamePad()`.
2 changes: 1 addition & 1 deletion src/optimise/Optimise.jl
@@ -4,7 +4,7 @@ using LinearAlgebra

export train!, update!,
Descent, ADAM, Momentum, Nesterov, RMSProp,
- ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,RADAM,
+ ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,RADAM, OADAM,
InvDecay, ExpDecay, WeightDecay, stop, Optimiser,
ClipValue, ClipNorm

39 changes: 39 additions & 0 deletions src/optimise/optimisers.jl
@@ -258,6 +258,45 @@ function apply!(o::AdaMax, x, Δ)
return Δ
end

"""
OADAM(η = 0.0001, β::Tuple = (0.5, 0.9))

[OADAM](https://arxiv.org/abs/1711.00141) (Optimistic ADAM)
is a variant of ADAM that adds an "optimistic" term, making it suitable for adversarial training.

# Parameters
- Learning rate (`η`): Amount by which gradients are discounted before updating
the weights.
- Decay of momentums (`β::Tuple`): Exponential decay for the first (β1) and the
  second (β2) momentum estimates.

# Examples
```julia
opt = OADAM()

opt = OADAM(0.001, (0.9, 0.995))
```
"""
mutable struct OADAM
  eta::Float64
  beta::Tuple{Float64,Float64}
  state::IdDict
end

OADAM(η = 0.0001, β = (0.5, 0.9)) = OADAM(η, β, IdDict())

function apply!(o::OADAM, x, Δ)
  η, β = o.eta, o.beta
  # mt, vt: biased first/second moment estimates; Δ_: previous ADAM step; βp: running powers of β for bias correction
  mt, vt, Δ_, βp = get!(o.state, x, (zero(x), zero(x), zero(x), β))
  @. mt = β[1] * mt + (1 - β[1]) * Δ
  @. vt = β[2] * vt + (1 - β[2]) * Δ^2
  # Optimistic step: return 2 * (current ADAM step) - (previous ADAM step)
  @. Δ = -Δ_
  @. Δ_ = η * mt / (1 - βp[1]) / (√(vt / (1 - βp[2])) + ϵ)
  @. Δ += 2Δ_
  o.state[x] = (mt, vt, Δ_, βp .* β)
  return Δ
end

"""
ADAGrad(η = 0.1)

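The update returned by `apply!` is the optimistic step 2·step_t − step_{t−1}, where each step is a bias-corrected ADAM step; extrapolating past the most recent step is what the cited paper proposes to stabilise min-max (GAN-style) training. Below is a minimal usage sketch, not part of the PR: the model, data, and hyperparameters are made up purely to show how `OADAM` plugs into the standard Flux training step.

```julia
using Flux

# Hypothetical toy setup for illustration only.
model = Dense(10, 1)
ps = Flux.params(model)
opt = OADAM(0.0001, (0.5, 0.9))      # the defaults, written out explicitly

x, y = randn(Float32, 10, 32), randn(Float32, 1, 32)
loss() = Flux.mse(model(x), y)

gs = gradient(() -> loss(), ps)      # gradients w.r.t. the tracked parameters
Flux.Optimise.update!(opt, ps, gs)   # applies the optimistic ADAM update in place
```

In an adversarial setup one would typically construct a separate `OADAM` instance for the generator and the discriminator, since each optimiser keeps its own per-parameter state (`mt`, `vt`, and the previous step `Δ_`) in an `IdDict`.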
2 changes: 1 addition & 1 deletion test/optimise.jl
@@ -6,7 +6,7 @@ using Test
@testset "Optimise" begin
w = randn(10, 10)
@testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(),
- NADAM(), RADAM(), Descent(0.1), ADAM(), Nesterov(), RMSProp(),
+ NADAM(), RADAM(), Descent(0.1), ADAM(), OADAM(), Nesterov(), RMSProp(),
Momentum()]
w′ = randn(10, 10)
loss(x) = Flux.mse(w*x, w′*x)