From 86f25ca3f68d2d80c3e14830dff5a78467d9b7f4 Mon Sep 17 00:00:00 2001
From: Moelf
Date: Tue, 16 Jun 2020 16:37:10 -0700
Subject: [PATCH 1/3] add Flux.skip()

---
 src/optimise/train.jl | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)

diff --git a/src/optimise/train.jl b/src/optimise/train.jl
index 7e7b516c9b..74b2200eb0 100644
--- a/src/optimise/train.jl
+++ b/src/optimise/train.jl
@@ -37,6 +37,26 @@ call(f, xs...) = f(xs...)
 runall(f) = f
 runall(fs::AbstractVector) = () -> foreach(call, fs)
 
+struct SkipException <: Exception end
+
+"""
+    skip()
+
+Call `Flux.skip()` in a callback to indicate when a callback condition is met.
+This will trigger the train loop to skip the current data point and not update with the calculated gradient.
+
+# Examples
+```julia
+cb = function ()
+  loss() > 1e7 && Flux.skip()
+end
+```
+"""
+function skip()
+  throw(SkipException())
+end
+
+
 struct StopException <: Exception end
 
 """
@@ -93,6 +113,8 @@ function train!(loss, ps, data, opt; cb = () -> ())
       catch ex
         if ex isa StopException
           break
+        elseif ex isa SkipException
+          continue
         else
           rethrow(ex)
         end

From 857566d8c94f355a494db8e8f8aaa799cc988c8a Mon Sep 17 00:00:00 2001
From: Moelf
Date: Thu, 20 Aug 2020 14:41:05 -0400
Subject: [PATCH 2/3] add Flux.skip() test and docs

---
 docs/src/utilities.md    |  1 +
 src/Flux.jl              |  1 +
 src/optimise/Optimise.jl |  2 +-
 test/optimise.jl         | 10 +++++++++-
 4 files changed, 12 insertions(+), 2 deletions(-)

diff --git a/docs/src/utilities.md b/docs/src/utilities.md
index 7986ec238d..95ef098ea5 100644
--- a/docs/src/utilities.md
+++ b/docs/src/utilities.md
@@ -46,4 +46,5 @@ Flux.destructure
 ```@docs
 Flux.throttle
 Flux.stop
+Flux.skip
 ```
diff --git a/src/Flux.jl b/src/Flux.jl
index c28a7d3637..8d485ecfbd 100644
--- a/src/Flux.jl
+++ b/src/Flux.jl
@@ -19,6 +19,7 @@ export Chain, Dense, Maxout, RNN, LSTM, GRU, SamePad, Conv, CrossCor, ConvTransp
 include("optimise/Optimise.jl")
 using .Optimise
 using .Optimise: @epochs
+using .Optimise: skip
 export Descent, ADAM, Momentum, Nesterov, RMSProp,
   ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, RADAM,
   InvDecay, ExpDecay, WeightDecay,
diff --git a/src/optimise/Optimise.jl b/src/optimise/Optimise.jl
index 0f5e644f0d..37672706d2 100644
--- a/src/optimise/Optimise.jl
+++ b/src/optimise/Optimise.jl
@@ -5,7 +5,7 @@ using LinearAlgebra
 export train!, update!,
   Descent, ADAM, Momentum, Nesterov, RMSProp,
   ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,RADAM,
-  InvDecay, ExpDecay, WeightDecay, stop, Optimiser,
+  InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser,
   ClipValue, ClipNorm
 
 include("optimisers.jl")
diff --git a/test/optimise.jl b/test/optimise.jl
index 8a338bcb66..5198444a9a 100644
--- a/test/optimise.jl
+++ b/test/optimise.jl
@@ -39,6 +39,14 @@ end
 @testset "Training Loop" begin
   i = 0
   l = 1
+  Flux.train!(
+    () -> (sleep(0.1); Flux.skip(); i+=1),
+    (),
+    Iterators.repeated((), 10),
+    Descent()
+  )
+
+  @test i==0 #all skipped
 
   Flux.train!(() -> (sleep(0.1); i += 1; l),
     (),
@@ -110,4 +118,4 @@ end
   @test all(w̄_value .<= 1)
   w̄_norm = Optimise.apply!(ClipNorm(1.0), w, copy(w̄))
   @test norm(w̄_norm) <= 1
-end
\ No newline at end of file
+end

From 56fafa695151fcac2127c96493885b2b2c81a63c Mon Sep 17 00:00:00 2001
From: Moelf
Date: Fri, 9 Oct 2020 10:50:47 -0400
Subject: [PATCH 3/3] more test

---
 test/optimise.jl | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/test/optimise.jl b/test/optimise.jl
index e868d6a927..4d90ec8e94 100644
--- a/test/optimise.jl
+++ b/test/optimise.jl
@@ -55,6 +55,16 @@ end
 
   @test i==0 #all skipped
 
+  Flux.train!(
+    () -> (sleep(0.1); i==8 && Flux.skip(); i+=1),
+    (),
+    Iterators.repeated((), 10),
+    Descent()
+  )
+
+  @test i==8 #skip after i hit 8
+
+  i = 0
   Flux.train!(() -> (sleep(0.1); i += 1; l),
     (),
     Iterators.repeated((), 100),
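
For reference, here is a minimal sketch of how the new `Flux.skip()` hook might be used from a training callback, in the spirit of the docstring example in PATCH 1/3. The model, data, and the `1e3` loss threshold are placeholder assumptions for illustration, not part of the patch.

```julia
using Flux

# Toy regression problem; W, b, x, y, and the 1e3 threshold are placeholders.
W, b = rand(2, 5), rand(2)
ps = Flux.params(W, b)
x, y = rand(5, 32), rand(2, 32)

# Zero-argument loss, matching the `Iterators.repeated((), n)` style used in the tests.
loss() = sum(abs2, W * x .+ b .- y)

data = Iterators.repeated((), 100)

# Ask the train loop to skip the current data point whenever the loss is already huge.
cb = () -> loss() > 1e3 && Flux.skip()

Flux.train!(loss, ps, data, Descent(0.1); cb = cb)
```

Compared with `Flux.stop()`, which breaks out of the training loop entirely, `skip()` only drops the current data point and continues with the next one, as the `continue` branch added in PATCH 1/3 shows.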