add step! #1833

Open
wants to merge 12 commits into base: master
2 changes: 2 additions & 0 deletions NEWS.md
@@ -3,6 +3,8 @@
## v0.12.9
* Fixed incorrect output and added GPU compatibility for [AlphaDropout](https://github.com/FluxML/Flux.jl/pull/1781).
* Add trilinear [Upsample layer](https://github.com/FluxML/Flux.jl/pull/1792).
* Add `step!` as a single training step of `train!` to allow for more exotic
optimisers (#666)

## v0.12.8
* Optimized inference and gradient calculation of `OneHotMatrix` ([PR](https://github.com/FluxML/Flux.jl/pull/1756))
2 changes: 1 addition & 1 deletion src/optimise/Optimise.jl
@@ -3,7 +3,7 @@ module Optimise
using LinearAlgebra
import ArrayInterface

-export train!, update!,
+export train!, step!, update!,
Descent, ADAM, Momentum, Nesterov, RMSProp,
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW, RADAM, OADAM, AdaBelief,
InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser,
32 changes: 29 additions & 3 deletions src/optimise/train.jl
@@ -1,5 +1,5 @@
using Juno
-import Zygote: Params, gradient
+import Zygote: Params, withgradient

"""
update!(x, x̄)
@@ -80,6 +80,33 @@ end
batchmemaybe(x) = tuple(x)
batchmemaybe(x::Tuple) = x

"""
step!(loss, params, opt)

`step!` uses a zero-argument `loss` function to improve the [Model parameters](@ref) (`params`)
based on a pluggable [Optimisers](@ref) (`opt`). It performs a single step of
the training loop `train!`. A default implementation is provided for optimisers
that are based on the `update!` function and only require gradient information;
more general optimisers have to overload `step!` (a sketch of such an overload
follows the definition below).

While the loss function of `train!` accepts the data as input, the loss function
passed to `step!` takes no arguments. `train!` cycles through the data
roughly like this:

```julia
for d in data
    step!(ps, opt) do
        loss(d)
    end
end
```

"""
function step!(loss, params, opt)
val, gs = withgradient(loss, params)
update!(opt, params, gs)
return val, gs
end
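
For illustration only (not part of this diff): a gradient-free optimiser could supply its own `step!` method rather than rely on the `withgradient`-based default above. The `RandomSearch` type and its `σ` field below are hypothetical, a minimal sketch of what such an overload might look like:

```julia
using Flux
using Zygote: Params

# Hypothetical zero-order optimiser: it only evaluates the loss, so it
# overloads `step!` directly instead of using the gradient-based default.
struct RandomSearch
    σ::Float64                       # perturbation scale (illustrative field)
end

function Flux.Optimise.step!(loss, params::Params, opt::RandomSearch)
    for p in params
        before = loss()
        Δ = opt.σ .* randn(eltype(p), size(p))
        p .+= Δ                      # try a random perturbation of this array
        loss() > before && (p .-= Δ) # keep it only if the loss did not get worse
    end
    return loss(), nothing           # mirror the (value, gradients) return shape
end
```

With such a method defined, the `train!` loop below dispatches to it automatically whenever `opt isa RandomSearch`.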

"""
train!(loss, params, data, opt; cb)

@@ -106,10 +133,9 @@ function train!(loss, ps, data, opt; cb = () -> ())
cb = runall(cb)
@progress for d in data
try
-gs = gradient(ps) do
+step!(ps, opt) do
loss(batchmemaybe(d)...)
end
-update!(opt, ps, gs)
cb()
catch ex
if ex isa StopException