add step! #1833

Open
wants to merge 12 commits into base: master
2 changes: 2 additions & 0 deletions NEWS.md
@@ -3,6 +3,8 @@
## v0.12.9
* Fixed incorrect output and added GPU compatibility for [AlphaDropout](https://github.com/FluxML/Flux.jl/pull/1781).
* Add trilinear [Upsample layer](https://github.com/FluxML/Flux.jl/pull/1792).
* Add `optimstep!` as a single training step of `train!` to allow for more exotic
optimisers (#666)

## v0.12.8
* Optimized inference and gradient calculation of OneHotMatrix ([pr](https://github.com/FluxML/Flux.jl/pull/1756))
32 changes: 29 additions & 3 deletions docs/src/training/optimisers.md
@@ -65,7 +65,35 @@ AdaBelief

## Optimiser Interface

Flux's optimisers are built around a `struct` that holds all the optimiser parameters along with a definition of how to apply the update rule associated with it. We do this via the `apply!` function which takes the optimiser as the first argument followed by the parameter and its corresponding gradient.
Flux's optimisers are built around a `struct` that holds all the optimiser parameters along with a definition of how to apply the update rule associated with it (`optimstep!`). The default implementation of `optimstep!` looks like this:

```julia
function optimstep!(loss, params, opt)
  # Calculate the gradients of the parameters
  # with respect to the loss function
  val, grads = Flux.withgradient(loss, params)
  # Update the parameters based on the chosen
  # optimiser (opt)
  Flux.Optimise.update!(opt, params, grads)
  return val, grads
end
```

and therefore assumes that the update rule only requires the optimiser's internal state `opt`, the parameters `params` themselves, and the gradients `grads`. For optimisers which do not fit this pattern, you should overload `optimstep!` itself.
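For instance, a gradient-free optimiser has no gradients to hand to `update!`, so it would implement `optimstep!` directly. The following is only a sketch under this PR's interface; the `RandomSearch` type, its `step` field, and the accept/reject rule are hypothetical:

```julia
# Hypothetical gradient-free optimiser: perturb each parameter randomly and
# keep the perturbation only if it lowers the loss. Not part of Flux.
struct RandomSearch
  step::Float64
end

function Flux.Optimise.optimstep!(loss, params, opt::RandomSearch)
  val = loss()
  for p in params
    δ = opt.step .* randn(eltype(p), size(p))
    p .+= δ                  # propose a move
    newval = loss()
    if newval < val
      val = newval           # accept the improvement
    else
      p .-= δ                # reject: restore the parameter
    end
  end
  return val, nothing        # no gradients are computed
end
```

Since `train!` simply calls `optimstep!` for each batch, such an optimiser plugs into the same training loop.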

In the following subsection we define a simple Momentum optimiser which fits the
`update!` pattern and therefore does not have to override `optimstep!` itself.

### Gradient Based Optimiser

To obtain an `update!` method applicable to your custom optimiser type, we need to overload the `apply!` function. Flux calls `apply!` internally via `update!`, which shares the same API but ensures that multiple parameters are handled gracefully. `apply!` takes the optimiser as the first argument, followed by a parameter and its corresponding gradient.

In this manner Flux also allows one to create custom optimisers that can be used seamlessly. Let's work through a simple example.

@@ -99,8 +127,6 @@ w = w - v

`apply!` defines the update rule for an optimiser `opt`, given a parameter and its corresponding gradient, and returns the updated gradient. Here, the running state `v` for each parameter `x` is retrieved and then modified, which updates the state of the optimiser.
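The Momentum definition this paragraph refers to is collapsed in the diff view above; a minimal sketch in the spirit of the Flux docs (field names and constructor defaults here are illustrative) is:

```julia
mutable struct Momentum
  eta::Float64      # learning rate
  rho::Float64      # momentum coefficient
  velocity::IdDict  # running state, one entry per parameter
end
Momentum(eta = 0.01, rho = 0.9) = Momentum(eta, rho, IdDict())

function Flux.Optimise.apply!(o::Momentum, x, Δ)
  η, ρ = o.eta, o.rho
  v = get!(() -> zero(x), o.velocity, x)  # fetch (or create) the state for `x`
  @. v = ρ * v - η * Δ                    # update the velocity in place
  @. Δ = -v                               # `update!` then applies `x .-= Δ`
end
```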

Flux internally calls on this function via the `update!` function. It shares the API with `apply!` but ensures that multiple parameters are handled gracefully.

## Composing Optimisers

Flux defines a special kind of optimiser simply called `Optimiser` which takes in arbitrary optimisers as input. Its behaviour is similar to the usual optimisers, but differs in that it acts by calling the optimisers listed in it sequentially. Each optimiser produces a modified gradient
18 changes: 15 additions & 3 deletions docs/src/training/training.md
@@ -17,15 +17,27 @@ for d in datapoints
# `d` should produce a collection of arguments
# to the loss function

# Calculate the gradients of the parameters
# with respect to the loss function
grads = Flux.gradient(parameters) do
# Update the parameters based on the chosen
# optimiser (opt)
val, grads = optimstep!(params, opt) do
loss(d...)
end
end
```

Author comment: This is right at the beginning instead of in the Custom Training Loop section. It seems to me like the custom training loop section might either be redundant or demonstrate how to have a custom gradient calculation now.

`optimstep!` implements a single optimisation step and thus dispatches on the optimiser type. As an example, the default `optimstep!` for optimisers that use the gradient to update the parameters (e.g. gradient descent, momentum, ADAM, etc.) looks like this:

```julia
function optimstep!(loss, params, opt)
  # Calculate the gradients of the parameters
  # with respect to the loss function
  val, grads = Flux.withgradient(loss, params)
  # Update the parameters based on the chosen
  # optimiser (opt)
  Flux.Optimise.update!(opt, params, grads)
  return val, grads
end
```
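Outside of `train!`, a single step can also be taken by hand. This is just a usage sketch; `model`, `x`, `y`, and `opt` are assumed to exist already, and the mean-squared error is an arbitrary choice of loss:

```julia
ps = Flux.params(model)           # implicit-style parameters
val, grads = optimstep!(ps, opt) do
  Flux.Losses.mse(model(x), y)    # the loss closure takes no arguments
end
@show val                         # loss value computed before the update
```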

2 changes: 1 addition & 1 deletion src/optimise/Optimise.jl
@@ -3,7 +3,7 @@ module Optimise
using LinearAlgebra
import ArrayInterface

export train!, update!,
export train!, optimstep!, update!,
Descent, ADAM, Momentum, Nesterov, RMSProp,
ADAGrad, AdaMax, ADADelta, AMSGrad, NADAM, ADAMW,RADAM, OADAM, AdaBelief,
InvDecay, ExpDecay, WeightDecay, stop, skip, Optimiser,
34 changes: 31 additions & 3 deletions src/optimise/train.jl
@@ -1,5 +1,5 @@
using Juno
import Zygote: Params, gradient
import Zygote: Params, withgradient

"""
update!(x, x̄)
@@ -80,6 +80,35 @@ end
batchmemaybe(x) = tuple(x)
batchmemaybe(x::Tuple) = x

"""
optimstep!(loss, params, opt)
Author comment: I suggest `optimstep!` instead of `trainstep!` to indicate that this is the optimiser interface and keep the ML jargon to a minimum.

@mcabbott (Member, Mar 20, 2022): One vote for something evoking `train!` to stress that they are closely related.

If the longer-term plan is to use Optimisers.jl, this may not fit with `train!` at all -- some recent discussion here: #1902 (comment). In which case there will be an implicit-style `train!` & `Params` story, and an explicit-style `gradient` and `Optimisers.update!`. With such a divide, this function wants to be clearly on the `train!` & `Params` side.

Maybe it should just be 3-arg `train!`? Without a data iterator, there is no iteration, that's all:

train!(loss, ::Params, data, ::AbstractOptimiser)  # calls loss(d...) for d in data
train!(loss, ::Params, ::AbstractOptimiser)        # calls loss() since there is no data


`optimstep!` uses a `loss` function (with no inputs) to improve the [Model parameters](@ref) (`params`) based on a pluggable optimiser from [Optimisers](@ref) (`opt`). It represents a single step of the training loop `train!`.

The default implementation of `optimstep!` takes the gradient of `loss` and calls `Flux.Optimise.update!` to adjust the parameters, but you can overload `optimstep!` for specific types of `opt`. This can be useful if your optimisation routine does not follow the standard gradient descent procedure (e.g. gradient-free optimisers).

Unlike `train!`, the loss function of `optimstep!` accepts no input.
Instead, `train!` cycles through the data in a loop and calls `optimstep!`:
```julia
for d in data
  optimstep!(ps, opt) do
    loss(d...)
  end
end
```
If you are writing [Custom Training loops](@ref), then you should follow this pattern.
"""
function optimstep!(loss, params, opt)
  val, gs = withgradient(loss, params)
  update!(opt, params, gs)
  return val, gs
end

"""
train!(loss, params, data, opt; cb)

@@ -106,10 +135,9 @@ function train!(loss, ps, data, opt; cb = () -> ())
cb = runall(cb)
@progress for d in data
try
gs = gradient(ps) do
optimstep!(ps, opt) do
loss(batchmemaybe(d)...)
end
update!(opt, ps, gs)
cb()
catch ex
if ex isa StopException