Documenting Optimiser Interface #904

Closed · wants to merge 13 commits
2 changes: 1 addition & 1 deletion docs/src/saving.md
@@ -113,6 +113,6 @@ You can even store optimiser state alongside the model, to resume training
exactly where you left off.

```julia
opt = ADAM(params(model))
opt = ADAM()
@save "model-$(now()).bson" model opt
```
80 changes: 80 additions & 0 deletions docs/src/training/optimisers.md
@@ -58,3 +58,83 @@ AMSGrad
NADAM
ADAMW
```

## Optimiser Interface

Flux's optimisers are built around a `struct` that holds all the optimiser's parameters, along with a definition of how to apply the associated update rule. This is done via the `apply!` function, which takes the optimiser as its first argument, followed by a parameter and its corresponding gradient.

This also makes it possible to create custom optimisers that work seamlessly with the rest of Flux. Let's work through a simple example.

```julia
mutable struct Momentum
  eta       # learning rate
  rho       # momentum coefficient
  velocity  # per-parameter state, keyed by the parameter itself
end

Momentum(eta::Real, rho::Real) = Momentum(eta, rho, IdDict())
```

The `Momentum` type will act as our optimiser in this case. Notice that we have added all the hyperparameters as fields, along with `velocity`, which we will use as our state dictionary: each parameter in our models gets its own entry there. We can now define the rule to apply when this optimiser is invoked.

```julia
function Flux.Optimise.apply!(o::Momentum, x, Δ)
  η, ρ = o.eta, o.rho
  # fetch (or initialise) the velocity associated with this particular parameter
  v = get!(o.velocity, x, zero(x))::typeof(x)
  @. v = ρ * v - η * Δ
  # return the step that `update!` will subtract from the parameter
  @. Δ = -v
end
```

This is the basic definition of the Momentum update rule, which in terms of a parameter `w` reads:

```math
v = ρ * v - η * Δ
w = w + v
```

`apply!` defines the update rule for an optimiser, given a single parameter and its gradient. It returns the step that will be subtracted from the parameter: here `-v`, so the parameter effectively moves by `+v`, as in the rule above. The velocity for each parameter `x` is looked up in (or initialised into) the optimiser's state dictionary, and that state is updated in place.

Flux calls this function internally via `update!`, which is why the method above extends `Flux.Optimise.apply!` rather than defining a new function. `update!` shares the API of `apply!`, but also handles collections of parameters gracefully.
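
To make the relationship concrete, here is a minimal sketch (not Flux's actual source) of how an `update!`-style function could be layered on top of `apply!`; the name `my_update!` is purely illustrative:

```julia
using Flux

# A minimal sketch, assuming `apply!` returns the step to be subtracted
# from the parameter, as in the `Momentum` example above.
my_update!(opt, x::AbstractArray, x̄) = (x .-= Flux.Optimise.apply!(opt, x, x̄))

# For a collection of parameters, apply the rule to each parameter in turn.
function my_update!(opt, xs, gs)
  for x in xs
    gs[x] === nothing && continue   # skip parameters that received no gradient
    my_update!(opt, x, gs[x])
  end
end
```

With these definitions, `my_update!(Momentum(0.01, 0.9), ps, gs)` would step every parameter in `ps` using its gradient from `gs`.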

## Composing Optimisers

Flux defines a special optimiser, simply called `Optimiser`, which takes an arbitrary list of optimisers as input. It behaves like any other optimiser, but differs in that it calls the optimisers it wraps sequentially: each optimiser produces a modified gradient that is fed into the next, and the resulting update is applied to the parameter as usual. A classic use case is adding decays; Flux defines some basic ones, including `ExpDecay`, `InvDecay`, and `WeightDecay`.

```julia
opt = Optimiser(ExpDecay(0.001, 0.1, 1000, 1e-4), Descent())
```

Here we apply exponential decay to the `Descent` optimiser. The arguments passed to `ExpDecay` (which match its defaults) decay the learning rate by a factor of 0.1 every 1000 steps, clipped to a minimum of `1e-4`. The composed optimiser is then used like any other optimiser.

```julia
using Flux

w = randn(10, 10)
w1 = randn(10, 10)
ps = Params([w, w1])

loss(x) = Flux.mse(w * x, w1 * x)

loss(rand(10)) # around 9

for t = 1:10^5
  θ̄ = gradient(() -> loss(rand(10)), ps)   # gradients w.r.t. the tracked parameters
  Flux.Optimise.update!(opt, ps, θ̄)        # apply the composed optimiser
end

loss(rand(10)) # around 0.9
```

In this manner it is possible to compose optimisers for some added flexibility.

## Decays

Similar to optimisers, Flux also defines some simple decays that can be used in conjunction with other optimisers, or standalone.
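
For example, `WeightDecay` can be composed with another optimiser to add L2-style weight decay to its updates. A minimal sketch, with purely illustrative hyperparameter values:

```julia
using Flux

# WeightDecay(1e-4) adds `1e-4 .* w` to each gradient before ADAM computes
# its step, which corresponds to L2 regularisation of the parameters.
opt = Optimiser(WeightDecay(1e-4), ADAM(0.001))
```

The composed `opt` can then be passed to `update!` or `train!` like any other optimiser.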

```@docs
ExpDecay
InvDecay
WeightDecay
```