@@ -149,7 +149,7 @@ function update!(o::ADAGrad, x, Δ)
149149 η = o. eta
150150 acc = get! (o. acc, x, fill (ϵ, size (x))):: typeof (x)
151151 @. acc += Δ^ 2
152- @. Δ *= η / √ ( acc + ϵ)
152+ @. Δ *= η / ( √ acc + ϵ)
153153end
154154
155155"""
@@ -169,7 +169,7 @@ function update!(o::ADADelta, x, Δ)
169169 ρ = o. rho
170170 acc, Δacc = get! (o. state, x, (zero (x), zero (x)))
171171 @. acc = ρ * acc + (1 - ρ) * Δ^ 2
172- @. Δ *= √ ( Δacc + ϵ) / √ ( acc + ϵ)
172+ @. Δ *= √ Δacc/ ( √ acc + ϵ)
173173 @. Δacc = ρ * Δacc + (1 - ρ) * Δ^ 2
174174 return Δ
175175end
@@ -194,7 +194,7 @@ function update!(o::AMSGrad, x, Δ)
194194 @. mt = β[1 ] * mt + (1 - β[1 ]) * Δ
195195 @. vt = β[2 ] * vt + (1 - β[2 ]) * Δ ^ 2
196196 @. v̂t = max .(v̂t, vt)
197- @. Δ = η * mt / √ v̂t
197+ @. Δ = η * mt / ( √ v̂t + ϵ)
198198end
199199
200200"""
@@ -217,7 +217,7 @@ function update!(o::NADAM, x, Δ)
217217 mt, vt = get! (o. state, x, (zero (x), zero (x)))
218218 @. mt = β[1 ] * mt + (1 - β[1 ]) * Δ
219219 @. vt = β[2 ] * vt + (1 - β[2 ]) * Δ^ 2
220- @. Δ = (β[1 ] * mt / (1 - β[1 ] * β1p) + (1 - β[1 ]) * Δ / (1 - β1p)) / √ (vt * β[2 ] / (1 - β2p) + ϵ) * η
220+ @. Δ = (β[1 ] * mt / (1 - β[1 ] * β1p) + (1 - β[1 ]) * Δ / (1 - β1p)) / ( √ (vt * β[2 ] / (1 - β2p) ) + ϵ) * η
221221 o. state[x] = (mt, vt, (β1p * β[1 ], β2p * β[2 ]))
222222 return Δ
223223end
0 commit comments