add ClipValue and ClipNorm #1133

Merged · 13 commits · May 15, 2020
1 change: 1 addition & 0 deletions Project.toml
@@ -10,6 +10,7 @@ Colors = "5ae59095-9a9b-59fe-a467-6f913c188581"
CuArrays = "3a865a2d-5b23-5a0f-bc46-62713ec82fae"
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
Juno = "e5e0dc1b-0480-54bc-9374-aad01c23163d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
6 changes: 4 additions & 2 deletions src/Flux.jl
@@ -3,7 +3,8 @@ module Flux
# Zero Flux Given

using Base: tail
using Zygote, MacroTools, Juno, Reexport, Statistics, Random
using Statistics, Random, LinearAlgebra
using Zygote, MacroTools, Juno, Reexport
using MacroTools: @forward
@reexport using NNlib
using Zygote: Params, @adjoint, gradient, pullback, @nograd
@@ -20,7 +21,8 @@ using .Optimise
using .Optimise: @epochs
export Descent, ADAM, Momentum, Nesterov, RMSProp,
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM,
ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay
ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay,
ClipValue, ClipNorm


using CuArrays
10 changes: 6 additions & 4 deletions src/optimise/Optimise.jl
@@ -1,9 +1,11 @@
module Optimise

export train!, update!,
Descent, ADAM, Momentum, Nesterov, RMSProp,
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,RADAM,
InvDecay, ExpDecay, WeightDecay, stop, Optimiser
using LinearAlgebra

export train!, update!, stop, Optimiser,
Descent, ADAM, Momentum, Nesterov, RMSProp,
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, RADAM,
InvDecay, ExpDecay, WeightDecay, ClipValue, ClipNorm

include("optimisers.jl")
include("train.jl")
28 changes: 28 additions & 0 deletions src/optimise/optimisers.jl
@@ -533,3 +533,31 @@ function apply!(o::WeightDecay, x, Δ)
wd = o.wd
@. Δ += wd * x
end

"""
ClipValue(thresh)

Clip gradients when their absolute value exceeds `thresh`.
"""
mutable struct ClipValue{T}
thresh::T
end

apply!(o::ClipValue, x, Δ) = clamp!(Δ, -o.thresh, o.thresh)

"""
ClipNorm(thresh)

Clip gradients when their L2 norm exceeds `thresh`.
"""
mutable struct ClipNorm{T}
thresh::T
end

function apply!(o::ClipNorm, x, Δ)
Δnrm = norm(Δ, 2)
if Δnrm > o.thresh
rmul!(Δ, o.thresh / Δnrm)
end
return Δ
end
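
In practice the two new rules are meant to be composed with a learning-rate rule via `Optimiser`, with the clipper placed first so the raw gradient is bounded before the optimiser's update is computed. A minimal sketch; the model, data, and thresholds below are illustrative assumptions, not part of this PR:

using Flux
using Flux.Optimise: Optimiser, ClipValue, ClipNorm

# Hypothetical model and data, just to illustrate composition.
m = Dense(10, 1)
data = [(rand(10, 16), rand(1, 16))]
loss(x, y) = Flux.mse(m(x), y)

# The clipper runs first, so the raw gradient is bounded before ADAM sees it.
opt = Optimiser(ClipValue(0.5), ADAM(1e-3))    # elementwise clipping
# opt = Optimiser(ClipNorm(1.0), ADAM(1e-3))   # or whole-array norm clipping

Flux.train!(loss, Flux.params(m), data, opt)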
12 changes: 12 additions & 0 deletions test/optimise.jl
@@ -89,3 +89,15 @@ end
@test decay_steps == ground_truth
@test o.eta == o.clip
end

@testset "Clipping" begin
w = randn(10, 10)
loss(x) = sum(w * x)
θ = Params([w])
x = 1000 * randn(10)
w̄ = gradient(() -> loss(x), θ)[w]
w̄_value = Optimise.apply!(ClipValue(1.0), w, copy(w̄))
@test all(abs.(w̄_value) .<= 1)
w̄_norm = Optimise.apply!(ClipNorm(1.0), w, copy(w̄))
@test norm(w̄_norm) <= 1
end
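
For reference, a quick sketch of the semantics these tests check; `nothing` stands in for the parameter argument, which both clippers ignore, and the numbers are chosen only for illustration:

using Flux.Optimise: ClipValue, ClipNorm, apply!

Δ = [3.0, -4.0]                            # L2 norm is 5
apply!(ClipValue(1.0), nothing, copy(Δ))   # entries clamped to [-1, 1] -> [1.0, -1.0]
apply!(ClipNorm(1.0), nothing, copy(Δ))    # rescaled to unit norm      -> [0.6, -0.8]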