Skip to content

Commit 2ea0d35

Browse files
bors[bot]scimas
andauthored
Merge #1299
1299: Fix ADADelta calculations and broken tests not catching the problems r=DhairyaLGandhi a=scimas 29832ac broke ADADelta, simply reversing its change and adding comments to not move around the epsilons. Fixes #1158 The testset probably needs more work than specifically for this case, so I'm not adding it here. ### PR Checklist - [x] Tests are ~added~ fixed - [ ] Entry in NEWS.md - [x] Documentation, if applicable - [ ] Final review from `@dhairyagandhi96` (for API changes). Co-authored-by: Mihir Gadgil <16473290+scimas@users.noreply.github.com>
2 parents 1dc9023 + 50092c8 commit 2ea0d35

File tree

2 files changed

+12
-3
lines changed

2 files changed

+12
-3
lines changed

src/optimise/optimisers.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ mutable struct OADAM
283283
state::IdDict
284284
end
285285

286-
OADAM= 0.0001, β = (0.5, 0.9)) = OADAM(η, β, IdDict())
286+
OADAM= 0.001, β = (0.5, 0.9)) = OADAM(η, β, IdDict())
287287

288288
function apply!(o::OADAM, x, Δ)
289289
η, β = o.eta, o.beta
@@ -357,7 +357,9 @@ function apply!(o::ADADelta, x, Δ)
357357
ρ = o.rho
358358
acc, Δacc = get!(o.state, x, (zero(x), zero(x)))
359359
@. acc = ρ * acc + (1 - ρ) * Δ^2
360-
@. Δ *= Δacc/ (acc + ϵ)
360+
# DON'T remove epsilon from numerator
361+
# or even out of the square roots
362+
@. Δ *= (Δacc + ϵ) / (acc + ϵ)
361363
@. Δacc = ρ * Δacc + (1 - ρ) * Δ^2
362364
return Δ
363365
end
@@ -599,4 +601,4 @@ function apply!(o::ClipNorm, x, Δ)
599601
rmul!(Δ, o.thresh / Δnrm)
600602
end
601603
return Δ
602-
end
604+
end

test/optimise.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,17 @@ using Flux.Optimise
22
using Flux.Optimise: runall
33
using Flux: Params, gradient
44
using Test
5+
using Random
56

67
@testset "Optimise" begin
8+
# Ensure rng has different state inside and outside the inner @testset
9+
# so that w and w' are different
10+
Random.seed!(84)
711
w = randn(10, 10)
812
@testset for opt in [ADAMW(), ADAGrad(0.1), AdaMax(), ADADelta(0.9), AMSGrad(),
913
NADAM(), RADAM(), Descent(0.1), ADAM(), OADAM(), Nesterov(), RMSProp(),
1014
Momentum()]
15+
Random.seed!(42)
1116
w′ = randn(10, 10)
1217
loss(x) = Flux.Losses.mse(w*x, w′*x)
1318
for t = 1: 10^5
@@ -21,8 +26,10 @@ using Test
2126
end
2227

2328
@testset "Optimiser" begin
29+
Random.seed!(84)
2430
w = randn(10, 10)
2531
@testset for Opt in [InvDecay, WeightDecay, ExpDecay]
32+
Random.seed!(42)
2633
w′ = randn(10, 10)
2734
loss(x) = Flux.Losses.mse(w*x, w′*x)
2835
opt = Optimiser(Opt(), ADAM(0.001))

0 commit comments

Comments
 (0)