
Commit

Merge #1232
1232: add Flux.skip() r=DhairyaLGandhi a=Moelf

per #821

### PR Checklist

- [x] Tests are added
- [ ] Entry in NEWS.md
- [x] Documentation, if applicable
- [ ] Final review from `@MikeInnes` or `@dhairyagandhi96` (for API changes).


Co-authored-by: Moelf <jerryling315@gmail.com>
Co-authored-by: Moelf <proton@jling.dev>
Co-authored-by: Jerry Ling <proton@jling.dev>
3 people authored on Oct 9, 2020
Parents: 5d9b2ca + 56fafa6 · Commit: 9ed04bb
Showing 5 changed files with 44 additions and 2 deletions.
docs/src/utilities.md (1 addition, 0 deletions)
@@ -46,4 +46,5 @@ Flux.destructure
```@docs
Flux.throttle
Flux.stop
Flux.skip
```
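For context, and not part of this commit: a minimal end-to-end sketch of how the newly documented `Flux.skip` is meant to be used, following the docstring added in `src/optimise/train.jl` below. The model, data, and loss threshold here are hypothetical, chosen only to illustrate the callback form.

```julia
using Flux

# Hypothetical model and data, for illustration only.
model = Dense(10, 1)
ps    = Flux.params(model)
loss(x, y) = sum(abs2, model(x) .- y)
data  = [(randn(Float32, 10), randn(Float32, 1)) for _ in 1:100]
opt   = Descent(0.1)

# Callback in the style of the new docstring: when the condition is met,
# Flux.skip() tells the train loop to move on without applying the update.
cb = function ()
  loss(first(data)...) > 1e7 && Flux.skip()
end

Flux.train!(loss, ps, data, opt; cb = cb)
```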
src/Flux.jl (1 addition, 0 deletions)
@@ -20,6 +20,7 @@ export Chain, Dense, Maxout, RNN, LSTM, GRU, SamePad, Conv, CrossCor, ConvTransp
include("optimise/Optimise.jl")
using .Optimise
using .Optimise: @epochs
using .Optimise: skip
export Descent, ADAM, Momentum, Nesterov, RMSProp,
  ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, OADAM,
  ADAMW, RADAM, InvDecay, ExpDecay, WeightDecay,
src/optimise/Optimise.jl (1 addition, 1 deletion)
@@ -5,7 +5,7 @@ using LinearAlgebra
export train!, update!,
  Descent, ADAM, Momentum, Nesterov, RMSProp,
  ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,RADAM, OADAM,
-  InvDecay, ExpDecay, WeightDecay, stop, Optimiser,
+  InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser,
  ClipValue, ClipNorm

include("optimisers.jl")
src/optimise/train.jl (22 additions, 0 deletions)
@@ -35,6 +35,26 @@ call(f, xs...) = f(xs...)
runall(f) = f
runall(fs::AbstractVector) = () -> foreach(call, fs)

struct SkipException <: Exception end

"""
    skip()

Call `Flux.skip()` in a callback to indicate when a callback condition is met.
This will trigger the train loop to skip the current data point and not update with the calculated gradient.

# Examples
```julia
cb = function ()
  loss() > 1e7 && Flux.skip()
end
```
"""
function skip()
  throw(SkipException())
end


struct StopException <: Exception end

"""
@@ -87,6 +107,8 @@ function train!(loss, ps, data, opt; cb = () -> ())
    catch ex
      if ex isa StopException
        break
      elseif ex isa SkipException
        continue
      else
        rethrow(ex)
      end
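The hunk above is the whole mechanism: `skip()` throws a sentinel `SkipException`, and `train!` turns it into `continue` for the current data point, just as the existing `StopException` is turned into `break`. A self-contained sketch of that pattern, independent of Flux (the names `skip_point` and `stop_loop` are illustrative, not Flux API):

```julia
struct SkipException <: Exception end
struct StopException <: Exception end

skip_point() = throw(SkipException())   # mirrors Flux.skip()
stop_loop()  = throw(StopException())   # mirrors Flux.stop()

processed = Int[]
for i in 1:10
  try
    i == 3 && skip_point()   # drop this iteration only
    i == 7 && stop_loop()    # abort the whole loop
    push!(processed, i)
  catch ex
    if ex isa StopException
      break
    elseif ex isa SkipException
      continue
    else
      rethrow(ex)
    end
  end
end
processed  # == [1, 2, 4, 5, 6]
```

Using exception types as control-flow signals lets user code deep inside a loss or callback reach the training loop without any return-value protocol.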
test/optimise.jl (19 additions, 1 deletion)
@@ -46,7 +46,25 @@ end
@testset "Training Loop" begin
  i = 0
  l = 1
  Flux.train!(
    () -> (sleep(0.1); Flux.skip(); i+=1),
    (),
    Iterators.repeated((), 10),
    Descent()
  )

  @test i==0 #all skipped

  Flux.train!(
    () -> (sleep(0.1); i==8 && Flux.skip(); i+=1),
    (),
    Iterators.repeated((), 10),
    Descent()
  )

  @test i==8 #skip after i hit 8

  i = 0
  Flux.train!(() -> (sleep(0.1); i += 1; l),
              (),
              Iterators.repeated((), 100),
@@ -121,4 +139,4 @@ end
  @test all(w̄_value .<= 1)
  w̄_norm = Optimise.apply!(ClipNorm(1.0), w, copy(w̄))
  @test norm(w̄_norm) <= 1
-end
+end
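The tests above call `Flux.skip()` from the loss closure with an empty parameter set. A companion sketch in the same style, contrasting the new `skip` with the pre-existing `Flux.stop`; the counters and conditions are illustrative only:

```julia
using Flux

let calls = 0, updates = 0
  # skip(): the current data point is dropped, but the loop keeps going.
  Flux.train!(
    () -> (calls += 1; isodd(calls) && Flux.skip(); updates += 1),
    (),
    Iterators.repeated((), 10),
    Descent()
  )
  @show calls updates   # calls = 10, updates = 5
end

let calls = 0
  # stop(): the whole training loop terminates immediately.
  Flux.train!(
    () -> (calls += 1; calls == 4 && Flux.stop(); 0.0),
    (),
    Iterators.repeated((), 10),
    Descent()
  )
  @show calls           # calls = 4
end
```

`skip` keeps consuming the remaining data points (each one is simply dropped), whereas `stop` exits the loop immediately.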
