Stop training on Inf/NaN loss (#2070)
* stop training on Inf/NaN loss

* add a test

* improve test

* improve test

* Update train.jl

* Update optimise.jl
mcabbott authored Oct 16, 2022
1 parent 090f043 commit 4c38c8a
Showing 2 changed files with 20 additions and 3 deletions.
11 changes: 8 additions & 3 deletions src/optimise/train.jl
@@ -1,5 +1,5 @@
 using ProgressLogging: @progress, @withprogress, @logprogress
-import Zygote: Params, gradient
+import Zygote: Params, gradient, withgradient


 """
@@ -105,8 +105,10 @@ The optimiser should be from the `Flux.Optimise` module (see [Optimisers](@ref))
 Different optimisers can be combined using [`Flux.Optimise.Optimiser`](@ref Flux.Optimiser).
 This training loop iterates through `data` once.
+It will stop with a `DomainError` if the loss is `NaN` or infinite.
 You can use [`@epochs`](@ref) to do this several times, or
-use for instance `Iterators.repeat` to make a longer `data` iterator.
+use for instance `Itertools.ncycle` to make a longer `data` iterator.

 ## Callbacks
@@ -126,9 +128,12 @@ function train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
   n = (itrsz == Base.HasLength()) || (itrsz == Base.HasShape{1}()) ? length(data) : 0
   @withprogress for (i, d) in enumerate(data)
     try
-      gs = gradient(ps) do
+      l, gs = withgradient(ps) do
         loss(batchmemaybe(d)...)
       end
+      if !isfinite(l)
+        throw(DomainError("Loss is $l on data item $i, stopping training"))
+      end
       update!(opt, ps, gs)
       cb()
     catch ex
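The key change in this hunk is the switch from `gradient` to `withgradient`: the latter returns the loss value together with the gradients, which is what lets the loop check `isfinite(l)` before calling `update!`. A minimal sketch of the same pattern outside `train!` (the model, data and loss below are made up for illustration, not taken from the commit):

    using Flux, Zygote

    m = Dense(1 => 1)                        # toy model
    ps = Flux.params(m)
    x, y = Float32[1.0], Float32[2.0]

    # withgradient returns the loss value and the gradients;
    # plain gradient would return only the gradients
    l, gs = Zygote.withgradient(ps) do
        Flux.Losses.mse(m(x), y)
    end

    # train! now throws DomainError("Loss is $l ...") at this point instead
    isfinite(l) || error("loss is $l, refusing to update")
    Flux.Optimise.update!(Descent(0.1), ps, gs)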
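The docstring hunk above also points to `ncycle` (from the IterTools.jl package) as a way to run this single-pass loop for several epochs. A hedged sketch, with a made-up model and dataset:

    using Flux, IterTools

    m = Dense(2 => 1)
    data = [(rand(Float32, 2), rand(Float32, 1)) for _ in 1:16]   # toy dataset
    loss(x, y) = Flux.Losses.mse(m(x), y)

    # train! iterates the data exactly once, so repeat it for 10 epochs;
    # a non-finite loss in any epoch now aborts with a DomainError
    Flux.train!(loss, Flux.params(m), IterTools.ncycle(data, 10), Descent(0.1))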
12 changes: 12 additions & 0 deletions test/optimise.jl
@@ -87,6 +87,18 @@ end
   Flux.train!(loss, Flux.params(r), (r,), Descent())
 end

+@testset "Stop on NaN" begin
+  m = Dense(1 => 1)
+  m.weight .= 0
+  CNT = 0
+  @test_throws DomainError Flux.train!(Flux.params(m), 1:100, Descent(0.1)) do i
+    CNT += 1
+    (i == 51 ? NaN32 : 1f0) * sum(m([1.0]))
+  end
+  @test CNT == 51  # stopped early
+  @test m.weight[1] ≈ -5  # did not corrupt weights
+end
+
 @testset "ExpDecay" begin

   @testset "Sanity Check" begin
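About the `≈ -5` assertion: the weight starts at 0, the gradient of `sum(m([1.0]))` with respect to it is 1.0, and the loss turns `NaN32` on item 51, where `train!` now throws before calling `update!`. Only the 50 finite steps are applied, so the weight ends at 0 - 50 × 0.1 × 1.0 = -5. From user code, the change means a run with a diverging loss can be caught instead of silently filling the parameters with NaN; a small sketch, not part of the commit:

    using Flux

    m = Dense(1 => 1)
    m.weight .= 0
    badloss(i) = (i == 51 ? NaN32 : 1f0) * sum(m([1.0]))   # same shape as the test above

    try
        Flux.train!(badloss, Flux.params(m), 1:100, Descent(0.1))
    catch err
        err isa DomainError || rethrow()
        @warn "training aborted on a non-finite loss" err
    end

    m.weight[1]   # ≈ -5.0: the 50 finite steps survived, the NaN step was never applied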
