diff --git a/Project.toml b/Project.toml
index e0d1318029..2cad212d03 100644
--- a/Project.toml
+++ b/Project.toml
@@ -41,8 +41,6 @@ ProgressLogging = "0.1"
 Reexport = "0.2, 1.0"
 SpecialFunctions = "1.8.2, 2.1.2"
 StatsBase = "0.33"
-Tracker = "0.2.22"
-Yota = "0.8.1"
 Zygote = "0.6.34"
 julia = "1.6"
 
@@ -52,9 +50,7 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
 IterTools = "c8e1da08-722c-5040-9ed9-7db0dc04731e"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
-Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
-Yota = "cd998857-8626-517d-b929-70ad188a48f0"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays", "Tracker", "Yota"]
+test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"]
diff --git a/src/train.jl b/src/train.jl
index 783536755b..046126328a 100644
--- a/src/train.jl
+++ b/src/train.jl
@@ -6,7 +6,7 @@ using Functors: fmap
 import ..Flux.Optimise: train!, update!  # during 0.13, we add methods to the old functions
 
-export setup, @train_autodiff
+export setup, train!
 
 using ProgressLogging: @progress, @withprogress, @logprogress
 using Zygote: Zygote, Params
@@ -14,28 +14,33 @@ using Zygote: Zygote, Params
 """
     opt = setup(rule, model)
 
-This is a version of `Optimisers.setup`, and is the first step before using `train!`.
+This is a version of `Optimisers.setup`, and is the first step before using [`train!`](@ref Flux.train!).
 It differs from `Optimisers.setup` in that it:
 * has one extra check for mutability
 * has methods which accept Flux's old optimisers, and convert them.
 
+# Example
 ```jldoctest
 julia> model = Dense(2=>1, leakyrelu; init=Flux.ones32);
 
-julia> opt = Flux.setup(Momentum(0.11), model)
-(weight = Leaf(Momentum{Float64}(0.11, 0.9), Float32[0.0 0.0]), bias = Leaf(Momentum{Float64}(0.11, 0.9), Float32[0.0]), σ = ())
+julia> opt = Flux.setup(Momentum(0.1), model)  # this encodes the optimiser and its state
+(weight = Leaf(Momentum{Float64}(0.1, 0.9), Float32[0.0 0.0]), bias = Leaf(Momentum{Float64}(0.1, 0.9), Float32[0.0]), σ = ())
 
-julia> Flux.train!(model, opt) do m  # 3-arg train!, for one data point (x = [0.2, -0.3], y = [0.4])
-         sum(m([0.2, -0.3]) .- [0.4]) * 100
+julia> x1, y1 = [0.2, -0.3], [0.4];  # use the same data for two steps:
+
+julia> Flux.train!(model, [(x1, y1), (x1, y1)], opt) do m, x, y
+         sum(abs.(m(x) .- y)) * 100
        end
--40.1
+2-element Vector{Float32}:
+ 40.1
+ 38.7
 
 julia> model.bias  # was zero, mutated by Flux.train!
 1-element Vector{Float32}:
- -0.11
+ 10.190001
 
 julia> opt  # mutated by Flux.train!
-(weight = Leaf(Momentum{Float64}(0.11, 0.9), Float32[0.022 -0.033]), bias = Leaf(Momentum{Float64}(0.11, 0.9), Float32[0.11]), σ = ())
+(weight = Leaf(Momentum{Float64}(0.1, 0.9), Float32[-2.018 3.027]), bias = Leaf(Momentum{Float64}(0.1, 0.9), Float32[-10.09]), σ = ())
 ```
 """
 function setup(rule::Optimisers.AbstractRule, model)
@@ -51,18 +56,8 @@ end
     train!(loss, model, data, opt)
 
 Uses a `loss` function and training `data` to improve the `model`'s parameters
-according to a particular optimisation rule `opt`.
-
-!!! note
-    This method has significant changes from the one in Flux ≤ 0.13:
-    * It now takes the `model` itself, not the result of [`Flux.params`](@ref).
-      (This is to move away from Zygote's implicit parameter handling.)
-    * Instead of `loss` being a function which typically accepts two arguments
-      (the input `x` and expected output `y` from each element of `data`)
-      now it should typically accept three, the first of which is the `model` itself.
-    * `data` must iterate tuples. Each `d in data` is used as `loss(model, d...)`.
-    * `opt` should be the result of [`Flux.setup`](@ref), it will warn you if not.
-    * Callback functions are not supported.
+according to a particular optimisation rule `opt`. Iterates through `data` once,
+evaluating `loss(model, d...)` for each `d in data`.
 
 For example, with these definitions...
 ```
 loss3(m, x, y) = norm(m(x) .- y)  # the model is the first argument
 opt = Flux.setup(Adam(), model)   # explicit setup of optimiser momenta
 ```
-...calling `train!(loss3, model, data, opt)` runs a loop much like this:
+...calling `Flux.train!(loss3, model, data, opt)` runs a loop much like this,
+using Zygote's "explicit" mode for the gradient:
 ```
 for d in data
-    ∂L∂m = Zygote.gradient(loss3, model, d...)[1]
-    Optimisers.update!(opt, model, ∂L∂m)
+    ∂L∂m = gradient(loss3, model, d...)[1]
+    update!(opt, model, ∂L∂m)  # method for "explicit" gradient
 end
 ```
 You can also write this loop yourself, if you need more flexibility.
-Besides the loop, `train!` will:
+For this reason `train!` is not highly extensible.
+It adds only a few features to the loop above:
 
 * Stop with a `DomainError` if the loss is infinite or `NaN` at any point.
 
@@ -91,20 +88,36 @@ Besides the loop, `train!` will:
 
 Note that the built-in loss functions accept 3 arguments, allowing for instance
 `train!(Flux.Losses.mse, model, data, opt)` instead of defining `loss3` as above.
 
-Note that callback functions are not supported. But arbitrary code can be inserted into the loop.
+!!! note
+    This method has significant changes from the one in Flux ≤ 0.13:
+    * It now takes the `model` itself, not the result of [`Flux.params`](@ref).
+      (This is to move away from Zygote's "implicit" parameter handling, with `Grads`.)
+    * Instead of `loss` being a function which typically accepts two arguments
+      (the input `x` and expected output `y` from each element of `data`)
+      now it should typically accept three, the first of which is the `model` itself.
+    * `data` must iterate tuples, otherwise you get an error.
+      (Previously non-tuple types were not splatted into the loss.
+      Pass in `((d,) for d in data)` to simulate this.)
+    * `opt` should be the result of [`Flux.setup`](@ref). Using an optimiser
+      such as `Adam()` without this step should give you a warning.
+    * Callback functions are not supported.
+      But any code can be included in the above `for` loop.
 """
-function train!(loss, model, data, opt)
+function train!(loss, model, data, opt; cb = nothing)
+  isnothing(cb) || error("""train! does not support callback functions.
+                            For more control use a loop with `gradient` and `update!`.""")
   losses = Float32[]
   @withprogress for (i,d) in enumerate(data)
     d isa Tuple || error("""train! expects as data an iterator producing tuples,
                             but got $(typeof(d)). Pass it `((d,) for d in data)`,
                             or use `gradient` and `update!` for more control.""")
-    l, (g, _...) = explicit_withgradient(loss, model, d...)
+    # l, (g, _...) = explicit_withgradient(loss, model, d...)  # BTW this un-thunks gradient w.r.t. data. Could avoid that
+    l, (g, _...) = explicit_withgradient(m -> loss(m, d...), model)
     isfinite(l) || throw(DomainError("loss function returned $l, stopping training"))
     opt, model = Optimisers.update!(opt, model, g)
     push!(losses, l)
     @logprogress Base.haslength(data) ? i/length(data) : nothing
   end
-  return losses  # Not entirely sure returning losses is a good idea
+  return losses  # Not entirely sure returning losses is a good idea, as it may conflict with later returning immutable models à la Optimisers.jl
 end
 
 # This method lets you use Optimisers.Descent() without setup, when there is no state
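
For reference, here is a minimal sketch (not part of the patch) of the workflow this diff enables. It assumes a Flux build with this change applied; `model`, `data`, and all numbers are made-up examples:

```julia
using Flux

model = Dense(2 => 1, leakyrelu; init=Flux.ones32)

# `data` must iterate tuples; each `d` is used as `loss(model, d...)`:
data = [([0.2, -0.3], [0.4]), ([0.1, 0.5], [0.7])]

opt = Flux.setup(Momentum(0.1), model)  # converts the old-style rule and encodes its state

# New-style train!: the do-block is the loss, taking the model as its first argument.
losses = Flux.train!(model, data, opt) do m, x, y
    sum(abs.(m(x) .- y)) * 100
end
```

`train!` mutates both `model` and `opt` in place, and (per this patch) returns the vector of per-step losses.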
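And the equivalent hand-written loop that the new docstring describes, for cases that need callbacks or other custom logic; again a sketch under the same assumptions, reusing `model`, `data`, and `opt` from above:

```julia
using Flux, Optimisers

loss3(m, x, y) = sum(abs2, m(x) .- y)  # the model is the first argument

for (x, y) in data
    ∂L∂m = Flux.gradient(loss3, model, x, y)[1]  # "explicit" gradient w.r.t. the model
    Optimisers.update!(opt, model, ∂L∂m)         # mutates both the state tree and the model
end
```

Any per-step code (logging, early stopping, the old callbacks) can simply go inside this loop.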