Commit 89074bc

tweaks

1 parent 93a1a96 commit 89074bc
File tree

Project.toml
docs/src/training/train_api.md
docs/src/training/training.md
src/Flux.jl

4 files changed: +35 -8 lines


Project.toml

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ MacroTools = "0.5"
 NNlib = "0.8.9"
 NNlibCUDA = "0.2.4"
 OneHotArrays = "0.1, 0.2"
-Optimisers = "0.2.10"
+Optimisers = "0.2.11"
 ProgressLogging = "0.1"
 Reexport = "0.2, 1.0"
 SpecialFunctions = "1.8.2, 2.1.2"

docs/src/training/train_api.md

Lines changed: 16 additions & 2 deletions
@@ -9,14 +9,28 @@ Flux.Optimise.train!(loss, model, data, opt; cb)
 To see one in a terminal, you will need to install [TerminalLoggers.jl](https://github.com/JuliaLogging/TerminalLoggers.jl)
 and follow its setup instructions.
 
-The new version of Flux's training code was written as an independent package, called Optimisers.jl.
-However, at present all Flux models contain parameter arrays (such as `Array`s and `CuArray`s)
+The new version of Flux's training code was written as an independent package, [Optimisers.jl](https://github.com/FluxML/Optimisers.jl).
+This is designed to allow for immutable objects.
+But at present all Flux models contain parameter arrays (such as `Array`s and `CuArray`s)
 which can be updated in-place. Thus objects returned by `update!` can be ignored.
 
 ```@docs
 Optimisers.update!
 ```
 
+### Modifiers
+
+The state returned by `setup` can be modified to temporarily prevent training of
+some parts of the model, or to change the learning rate used.
+The functions for doing so may be accessed as `Flux.freeze!`, `Flux.thaw!`, and `Flux.adjust`:
+
+```@docs
+Optimisers.adjust
+Optimisers.freeze!
+Optimisers.thaw!
+```
+
+
 ## Implicit style
 
 Flux used to handle gradients, training, and optimisation rules quite differently.
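
A rough sketch of how the API documented above fits together; the model, loss, data and layer names are made up for illustration, and `adjust` is assumed to be the non-mutating form from Optimisers 0.2.11:

```julia
using Flux

# Hypothetical two-part model; the layer names and sizes are placeholders.
model = Chain(enc = Dense(2 => 4, relu), dec = Dense(4 => 1))

opt_state = Flux.setup(Adam(), model)   # optimiser state mirroring the model's structure

x, y = rand(Float32, 2, 8), rand(Float32, 1, 8)
grads = Flux.gradient(m -> Flux.mse(m(x), y), model)

# Parameters are mutated in place, so the pair returned by `update!` can be ignored:
Flux.update!(opt_state, model, grads[1])

# Modifiers: change the learning rate, or freeze and thaw part of the model.
opt_state = Flux.adjust(opt_state, 1e-3)  # non-mutating, so re-assign the state
Flux.freeze!(opt_state.layers.enc)        # stop updating the encoder's parameters
Flux.thaw!(opt_state.layers.enc)          # resume training them
```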

docs/src/training/training.md

Lines changed: 17 additions & 5 deletions
@@ -326,15 +326,27 @@ The first, [`WeightDecay`](@ref) adds `0.42` times original parameter to the gradient,
 matching the gradient of the penalty above (with the same, unrealistically large, constant).
 After that, in either case, [`Adam`](@ref) computes the final update.
 
-The same mechanism can be used for other purposes, such as gradient clipping with [`ClipGrad`](@ref ).
+The same `OptimiserChain` mechanism can be used for other purposes, such as gradient clipping with [`ClipGrad`](@ref ).
 
 Besides L2 / weight decay, another common and quite different kind of regularisation is
-provided by the [`Dropout`](@ref Flux.Dropout) layer. This turns off some ... ??
-
-?? do we discuss test/train mode here too?
+provided by the [`Dropout`](@ref Flux.Dropout) layer. This turns off some outputs of the
+previous layer during training.
+It should switch automatically, but see [trainmode!](@ref Flux.trainmode!) / [testmode!](@ref Flux.testmode!) to manually enable or disable this layer.
 
 ## Freezing, Schedules
 
-?? maybe these also fit in here.
+Finer control of training is possible, for instance by freezing part of the model:
+
+```julia
+model = Chain(enc = encoder, dec = decoder)
+
+opt = Flux.setup(Adam(), model)
+
+Flux.freeze!(opt.layers.enc)  # corresponds to model.layers.enc
+```
 
+!!! note
+    This `freeze!` goes with the "explicit" style.
+    The earlier "implicit" equivalent was to pass to `gradient` an object referencing only
+    part of the model, such as `Flux.params(model.layers.enc)`.
 
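
To make the `OptimiserChain` / `ClipGrad` remark and the schedule idea above concrete, a small illustrative sketch; the model and rule constants are placeholders (the `0.42` echoes the deliberately unrealistic constant in the docs text):

```julia
using Flux

# Hypothetical model; layer sizes are placeholders.
model = Chain(Dense(2 => 4, relu), Dense(4 => 1))

# Compose gradient clipping, weight decay and Adam with OptimiserChain.
# The rules are qualified via Flux.Optimisers to avoid clashing with Flux's exported Adam.
rule = Flux.Optimisers.OptimiserChain(
    Flux.Optimisers.ClipGrad(1.0),      # clip each gradient entry to [-1, 1]
    Flux.Optimisers.WeightDecay(0.42),  # the same unrealistically large constant as above
    Flux.Optimisers.Adam(0.01),
)
opt_state = Flux.setup(rule, model)

# A crude schedule: lower the learning rate partway through training,
# e.g. once per epoch. `adjust` is non-mutating here, so re-assign the state:
opt_state = Flux.adjust(opt_state, 0.001)
```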

src/Flux.jl

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ using MacroTools: @forward
 @reexport using NNlib
 using MLUtils
 import Optimisers: Optimisers, trainable, destructure # before v0.13, Flux owned these functions
+using Optimisers: freeze!, thaw!, adjust
 
 using Zygote, ChainRulesCore
 using Zygote: Params, @adjoint, gradient, pullback, @nograd
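
A small check of what this one-line import provides; assuming the names are brought into Flux's namespace without being exported, they are called qualified, as the docs above do:

```julia
using Flux

# The functions live in Optimisers.jl; Flux merely brings them into its namespace,
# so user code refers to them as Flux.freeze! etc.
@assert Flux.adjust  === Flux.Optimisers.adjust
@assert Flux.freeze! === Flux.Optimisers.freeze!
@assert Flux.thaw!   === Flux.Optimisers.thaw!
```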
