Add Enzyme train function
wsmoses committed May 14, 2024
1 parent 89ecf4c commit 546ff16
Showing 3 changed files with 53 additions and 12 deletions.
1 change: 1 addition & 0 deletions Project.toml
@@ -6,6 +6,7 @@ version = "0.14.15"
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
36 changes: 35 additions & 1 deletion src/train.jl
@@ -5,8 +5,9 @@ using Optimisers: Optimisers
using Functors: fmap, fmapstructure
using ..Flux: Flux # used only in docstring
import ..Flux.Optimise: train!, update! # during 0.13, we add methods to the old functions
import Enzyme

export setup, train!
export setup, train!, train_enzyme!

using ProgressLogging: @progress, @withprogress, @logprogress
using Zygote: Zygote, Params
@@ -109,11 +110,44 @@ function train!(loss, model, data, opt; cb = nothing)
end
end

# Zero every array leaf of a (shadow) model in place, recursing with fmap.
_make_zero!(x::AbstractArray) = fill!(x, 0)
_make_zero!(x) = x
make_zero!(model) = fmap(_make_zero!, model)

# Wrapper so `loss` itself can be passed to Enzyme.autodiff as a Const argument.
applyloss(loss, model, d...) = loss(model, d...)

"""
    train_enzyme!(loss, model, data, opt; [cb])

Like [`train!`](@ref), but the gradient is computed in place using
[Enzyme](https://github.com/EnzymeAD/Enzyme.jl).
"""
function train_enzyme!(loss, model, data, opt; cb = nothing)
isnothing(cb) || error("""train_enzyme! does not support callback functions.
For more control use a loop with `gradient` and `update!`.""")
dmodel = Enzyme.make_zero(model)
@withprogress for (i,d) in enumerate(data)
d_splat = d isa Tuple ? d : (d,)
make_zero!(dmodel)
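# Reverse-mode AD with primal: the gradient w.r.t. `model` accumulates into the shadow `dmodel`, and the loss value is returned as the primal.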
_, l = Enzyme.autodiff(Enzyme.ReverseWithPrimal, applyloss, Enzyme.Active, Enzyme.Const(loss), Enzyme.Duplicated(model, dmodel), map(Enzyme.Const, d_splat)...)

if !isfinite(l)
throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
end
opt, model = Optimisers.update!(opt, model, dmodel)
@logprogress Base.haslength(data) ? i/length(data) : nothing
end
end

# This method lets you use Optimisers.Descent() without setup, when there is no state
function train!(loss, model, data, rule::Optimisers.AbstractRule; cb = nothing)
train!(loss, model, data, _rule_to_state(model, rule); cb)
end

# This method lets you use Optimisers.Descent() without setup, when there is no state
function train_enzyme!(loss, model, data, rule::Optimisers.AbstractRule; cb = nothing)
train_enzyme!(loss, model, data, _rule_to_state(model, rule); cb)
end

function _rule_to_state(model, rule::Optimisers.AbstractRule)
state = setup(rule, model)
@gensym warn_id
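For orientation, a minimal usage sketch of the new function follows. The model, loss function, and data below are illustrative assumptions, not part of this commit:

using Flux

model = Dense(10 => 1)                                    # any Flux model
data  = [(randn(Float32, 10), randn(Float32, 1)) for _ in 1:100]
loss(m, x, y) = Flux.mse(m(x), y)

# With explicit optimiser state from Flux.setup:
opt = Flux.setup(Adam(0.01), model)
Flux.train_enzyme!(loss, model, data, opt)

# Or pass a bare rule; the AbstractRule method above builds the state via setup:
Flux.train_enzyme!(loss, model, data, Descent(0.1))
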
28 changes: 17 additions & 11 deletions test/train.jl
@@ -5,7 +5,8 @@ import Optimisers
using Test
using Random

@testset "Explicit Flux.train! with Zygote" begin
for (trainfn!, name) in ((Flux.train!, "Zygote"), (Flux.train_enzyme!, "Enzyme"))
@testset "Explicit Flux.train! with $name" begin
Random.seed!(84)
w = randn(10, 10)
w2 = randn(10, 10) # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset.
@@ -18,7 +19,7 @@ using Random
@test loss(model, rand(10, 10)) > 1

opt = Flux.setup(rule, model)
Flux.train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt)
trainfn!(loss, model, ((rand(10),) for _ in 1: 10^5), opt)
@test loss(model, rand(10, 10)) < 0.01
end

@@ -27,17 +28,19 @@ using Random
loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
model = (weight=copy(w2), bias=zeros(10), ignore=nothing)
@test loss(model, rand(10, 10)) > 1
Flux.train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt)
trainfn!(loss, model, ((rand(10),) for _ in 1: 10^5), opt)
@test loss(model, rand(10, 10)) < 0.01
end
end
end

@testset "Explicit Flux.train! features" begin
for (trainfn!, name) in ((Flux.train!, "Zygote"), (Flux.train_enzyme!, "Enzyme"))
@testset "Explicit Flux.train! features with $name" begin
@testset "Stop on NaN" begin
m1 = Dense(1 => 1)
m1.weight .= 0
CNT = 0
@test_throws DomainError Flux.train!(m1, tuple.(1:100), Descent(0.1)) do m, i
@test_throws DomainError trainfn!(m1, tuple.(1:100), Descent(0.1)) do m, i
CNT += 1
(i == 51 ? NaN32 : 1f0) * sum(m([1.0]))
end
@@ -51,16 +54,17 @@ end
loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias)
model = (weight=copy(w2), bias=zeros(10))
opt = Flux.setup(AdamW(), model)
Flux.train!(loss, model, (rand(10) for _ in 1: 10^5), opt)
trainfn!(loss, model, (rand(10) for _ in 1: 10^5), opt)
@test loss(model, rand(10, 10)) < 0.01
end

@testset "callbacks give helpful error" begin
m1 = Dense(1 => 1)
cb = () -> println("this should not be printed")
@test_throws ErrorException Flux.train!((args...,) -> 1, m1, [(1,2)], Descent(0.1); cb)
@test_throws ErrorException trainfn!((args...,) -> 1, m1, [(1,2)], Descent(0.1); cb)
end
end
end

@testset "Explicit Flux.update! features" begin
m = Chain(Dense(2=>3, tanh), Dense(3=>1), only)
@@ -98,7 +102,8 @@ end
@test y5 < y4
end

@testset "L2 regularisation" begin
for (trainfn!, name) in ((Flux.train!, "Zygote"), (Flux.train_enzyme!, "Enzyme"))
@testset "L2 regularisation with $name" begin
# New docs claim an exact equivalent. It's a bit long to put the example in there,
# but perhaps the tests should contain it.

@@ -108,7 +113,7 @@

# Take 1: explicitly add a penalty in the loss function
opt = Flux.setup(Adam(0.1), model)
Flux.train!(model, data, opt) do m, x, y
trainfn!(model, data, opt) do m, x, y
err = Flux.mse(m(x), y)
l2 = sum(abs2, m.weight)/2 + sum(abs2, m.bias)/2
err + 0.33 * l2
@@ -120,7 +125,7 @@ end
model.bias .= 0
pen2(x::AbstractArray) = sum(abs2, x)/2
opt = Flux.setup(Adam(0.1), model)
Flux.train!(model, data, opt) do m, x, y
trainfn!(model, data, opt) do m, x, y
err = Flux.mse(m(x), y)
l2 = sum(pen2, Flux.params(m))
err + 0.33 * l2
@@ -132,12 +137,13 @@ end
model.weight .= init_weight
model.bias .= 0
decay_opt = Flux.setup(OptimiserChain(WeightDecay(0.33), Adam(0.1)), model);
Flux.train!(model, data, decay_opt) do m, x, y
trainfn!(model, data, decay_opt) do m, x, y
Flux.mse(m(x), y)
end
diff3 = model.weight .- init_weight
@test diff1 ≈ diff3
end
end

@testset "Flux.setup bugs" begin
# https://github.com/FluxML/Flux.jl/issues/2144
