From 546ff165f620542ad014257012c6b2bc211f86e4 Mon Sep 17 00:00:00 2001
From: Billy Moses
Date: Mon, 13 May 2024 18:39:27 -0700
Subject: [PATCH] Add Enzyme train function

---
 Project.toml  |  1 +
 src/train.jl  | 36 +++++++++++++++++++++++++++++++++++-
 test/train.jl | 28 +++++++++++++++++-----------
 3 files changed, 53 insertions(+), 12 deletions(-)

diff --git a/Project.toml b/Project.toml
index 2283464f5e..201d52560e 100644
--- a/Project.toml
+++ b/Project.toml
@@ -6,6 +6,7 @@ version = "0.14.15"
 Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
 ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
+Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
diff --git a/src/train.jl b/src/train.jl
index e72eedebf3..76e708a536 100644
--- a/src/train.jl
+++ b/src/train.jl
@@ -5,8 +5,9 @@ using Optimisers: Optimisers
 using Functors: fmap, fmapstructure
 using ..Flux: Flux # used only in docstring
 import ..Flux.Optimise: train!, update!  # during 0.13, we add methods to the old functions
+import Enzyme

-export setup, train!
+export setup, train!, train_enzyme!

 using ProgressLogging: @progress, @withprogress, @logprogress
 using Zygote: Zygote, Params
@@ -109,11 +110,44 @@ function train!(loss, model, data, opt; cb = nothing)
   end
 end

+_make_zero!(x::AbstractArray) = fill!(x, 0)
+_make_zero!(x) = x
+make_zero!(model) = fmap(_make_zero!, model)
+
+applyloss(loss, model, d...) = loss(model, d...)
+
+"""
+    train_enzyme!(loss, model, data, opt; [cb])
+
+Like [`train!`](@ref), but the gradient is computed in place using [Enzyme](https://github.com/EnzymeAD/Enzyme.jl).
+"""
+function train_enzyme!(loss, model, data, opt; cb = nothing)
+  isnothing(cb) || error("""train_enzyme! does not support callback functions.
+                            For more control use a loop with `gradient` and `update!`.""")
+  dmodel = Enzyme.make_zero(model)
+  @withprogress for (i,d) in enumerate(data)
+    d_splat = d isa Tuple ? d : (d,)
+    make_zero!(dmodel)
+    _, l = Enzyme.autodiff(Enzyme.ReverseWithPrimal, applyloss, Enzyme.Active, Enzyme.Const(loss), Enzyme.Duplicated(model, dmodel), map(Enzyme.Const, d_splat)...)
+
+    if !isfinite(l)
+      throw(DomainError(lazy"Loss is $l on data item $i, stopping training"))
+    end
+    opt, model = Optimisers.update!(opt, model, dmodel)
+    @logprogress Base.haslength(data) ? i/length(data) : nothing
+  end
+end
+
 # This method let you use Optimisers.Descent() without setup, when there is no state
 function train!(loss, model, data, rule::Optimisers.AbstractRule; cb = nothing)
   train!(loss, model, data, _rule_to_state(model, rule); cb)
 end

+# This method lets you use Optimisers.Descent() without setup, when there is no state
+function train_enzyme!(loss, model, data, rule::Optimisers.AbstractRule; cb = nothing)
+  train_enzyme!(loss, model, data, _rule_to_state(model, rule); cb)
+end
+
 function _rule_to_state(model, rule::Optimisers.AbstractRule)
   state = setup(rule, model)
   @gensym warn_id
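(Usage note, not part of the patch: a minimal sketch of calling the new `train_enzyme!`, mirroring the tests in test/train.jl below. It assumes this patch is applied and Enzyme is installed; the toy linear model, `Adam(0.1)`, and the data generator are illustrative choices only.)

    using Flux, Random

    w = randn(10, 10)                                   # fixed "true" weights that the loss compares against
    model = (weight = randn(10, 10), bias = zeros(10))  # a plain NamedTuple model, as used in the tests
    loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias)

    opt_state = Flux.setup(Adam(0.1), model)            # optimiser state, exactly as for Flux.train!
    # same call shape as train!, but Enzyme accumulates the gradient in place
    Flux.train_enzyme!(loss, model, ((rand(10),) for _ in 1:10^4), opt_state)
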
with $name" begin Random.seed!(84) w = randn(10, 10) w2 = randn(10, 10) # NB outside the inner @testset, else it will be exactly == w, as the RNG seed is reset. @@ -18,7 +19,7 @@ using Random @test loss(model, rand(10, 10)) > 1 opt = Flux.setup(rule, model) - Flux.train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt) + trainfn!(loss, model, ((rand(10),) for _ in 1: 10^5), opt) @test loss(model, rand(10, 10)) < 0.01 end @@ -27,17 +28,19 @@ using Random loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias) model = (weight=copy(w2), bias=zeros(10), ignore=nothing) @test loss(model, rand(10, 10)) > 1 - Flux.train!(loss, model, ((rand(10),) for _ in 1: 10^5), opt) + trainfn!(loss, model, ((rand(10),) for _ in 1: 10^5), opt) @test loss(model, rand(10, 10)) < 0.01 end end +end -@testset "Explicit Flux.train! features" begin +for (trainfn!, name) in ((Flux.train!, "Zygote"), (Flux.train_enzyme!, "Enzyme")) +@testset "Explicit Flux.train! features with $name" begin @testset "Stop on NaN" begin m1 = Dense(1 => 1) m1.weight .= 0 CNT = 0 - @test_throws DomainError Flux.train!(m1, tuple.(1:100), Descent(0.1)) do m, i + @test_throws DomainError Flux.trainfn!(m1, tuple.(1:100), Descent(0.1)) do m, i CNT += 1 (i == 51 ? NaN32 : 1f0) * sum(m([1.0])) end @@ -51,16 +54,17 @@ end loss(m, x) = Flux.Losses.mse(w*x, m.weight*x .+ m.bias) model = (weight=copy(w2), bias=zeros(10)) opt = Flux.setup(AdamW(), model) - Flux.train!(loss, model, (rand(10) for _ in 1: 10^5), opt) + trainfn!(loss, model, (rand(10) for _ in 1: 10^5), opt) @test loss(model, rand(10, 10)) < 0.01 end @testset "callbacks give helpful error" begin m1 = Dense(1 => 1) cb = () -> println("this should not be printed") - @test_throws ErrorException Flux.train!((args...,) -> 1, m1, [(1,2)], Descent(0.1); cb) + @test_throws ErrorException trainfn!((args...,) -> 1, m1, [(1,2)], Descent(0.1); cb) end end +end @testset "Explicit Flux.update! features" begin m = Chain(Dense(2=>3, tanh), Dense(3=>1), only) @@ -98,7 +102,8 @@ end @test y5 < y4 end -@testset "L2 regularisation" begin +for (trainfn!, name) in ((Flux.train!, "Zygote"), (Flux.train_enzyme!, "Enzyme")) +@testset "L2 regularisation with $name" begin # New docs claim an exact equivalent. It's a bit long to put the example in there, # but perhaps the tests should contain it. @@ -108,7 +113,7 @@ end # Take 1: explicitly add a penalty in the loss function opt = Flux.setup(Adam(0.1), model) - Flux.train!(model, data, opt) do m, x, y + trainfn!(model, data, opt) do m, x, y err = Flux.mse(m(x), y) l2 = sum(abs2, m.weight)/2 + sum(abs2, m.bias)/2 err + 0.33 * l2 @@ -120,7 +125,7 @@ end model.bias .= 0 pen2(x::AbstractArray) = sum(abs2, x)/2 opt = Flux.setup(Adam(0.1), model) - Flux.train!(model, data, opt) do m, x, y + trainfn!(model, data, opt) do m, x, y err = Flux.mse(m(x), y) l2 = sum(pen2, Flux.params(m)) err + 0.33 * l2 @@ -132,12 +137,13 @@ end model.weight .= init_weight model.bias .= 0 decay_opt = Flux.setup(OptimiserChain(WeightDecay(0.33), Adam(0.1)), model); - Flux.train!(model, data, decay_opt) do m, x, y + trainfn!(model, data, decay_opt) do m, x, y Flux.mse(m(x), y) end diff3 = model.weight .- init_weight @test diff1 ≈ diff3 end +end @testset "Flux.setup bugs" begin # https://github.com/FluxML/Flux.jl/issues/2144