Add Enzyme train function #2446
Conversation
Bumping @CarloLucibello or @ToucheSir for review.
This looks fine, and I guess the minimal version is to define this only in the tests.
Here's one idea for a public-facing API. We could have two methods like this:
train!(loss, model, data, opt) --> Zygote
train!(loss, model_and_shadow::Duplicated, data, opt) --> Enzyme
ideally with Enzyme.Duplicated(x) = Duplicated(x, Enzyme.make_zero(x)), so that you can call it train!(loss, Duplicated(model), data, opt).
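A minimal sketch of how that could look in use (illustrative only; the one-arg constructor is the proposal above, and the loss and optimiser names are just examples, not necessarily what the PR implements):

```julia
using Flux, Enzyme

# Proposed convenience constructor (not assumed to exist in Enzyme itself):
# pair the model with a zero shadow that gradients are accumulated into.
Enzyme.Duplicated(x) = Enzyme.Duplicated(x, Enzyme.make_zero(x))

model = Dense(2 => 1)
opt_state = Flux.setup(Adam(), model)
data = [(randn(Float32, 2, 8), randn(Float32, 1, 8))]
loss(m, x, y) = Flux.mse(m(x), y)

Flux.train!(loss, model, data, opt_state)              # existing Zygote path
Flux.train!(loss, Duplicated(model), data, opt_state)  # proposed Enzyme path
```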
That suggests also defining methods of withgradient like this:
val, grads = Flux.withgradient(loss, model, data) --> Zygote
val, grads = Flux.withgradient(loss, Duplicated(model), data) --> Enzyme
That's a minimal change to select Enzyme. But unlike just passing some token like AutoEnzyme(), this Duplicated struct does other things... you can make it in advance if you wish. And it's fairly obvious that you cannot do this without using Enzyme.
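A corresponding sketch for withgradient (again illustrative; the Duplicated method only exists if defined as proposed):

```julia
using Flux, Enzyme

model = Dense(2 => 1)
x = randn(Float32, 2, 8)
loss(m, x) = sum(abs2, m(x))

# Existing Zygote path:
val, grads = Flux.withgradient(loss, model, x)

# Proposed Enzyme path: the wrapper selects the backend and also carries the
# shadow, so the gradient remains reachable afterwards as dup.dval.
dup = Duplicated(model, Enzyme.make_zero(model))
val_e, grads_e = Flux.withgradient(loss, dup, x)
```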
We could go one further and define a method
update!(opt_state, model_and_grad::Duplicated)
That's a bigger change away from calling val, grads = Flux.withgradient(...), as you would discard what that returns and hold onto the Duplicated instead. But perhaps quite neat.
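One hypothetical way that method could be written, shown only as a sketch (the .val/.dval field names come from Enzyme's Duplicated; the three-argument update! is the existing explicit-style method):

```julia
using Flux, Enzyme

# Hypothetical: the wrapper already pairs the model (.val) with its
# accumulated gradient (.dval), so no separate gradient argument is needed.
function Flux.update!(opt_state, m::Duplicated)
    Flux.update!(opt_state, m.val, m.dval)  # existing three-argument method
    return opt_state, m
end
```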
@mcabbott made the suggested API change.
Thanks! Thoughts on defining the one-arg Duplicated(x)?
I'm quite hesitant to do so, for a couple of reasons (including that Duplicated isn't necessarily all you want, and making the shadow an explicit second argument makes the user aware that something else is being updated in place). Analogously, if a one-arg Duplicated is passed directly into autodiff, as in autodiff(Reverse, sum, Duplicated(x)), the shadow will of course be updated in place; but since the user never stored the dval, they don't have the derivative available anywhere and will end up confused. Of course they could construct Duplicated(x) before the autodiff and store x.dval somewhere, but I'd like to avoid confusion by design if possible.
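A small Enzyme-only sketch of the point being made (exact activity annotations may vary):

```julia
using Enzyme

x = [2.0, 3.0]

# "One-arg style": the shadow is created inline, receives the gradient,
# and is then lost because nothing else references it.
autodiff(Reverse, sum, Active, Duplicated(x, Enzyme.make_zero(x)))

# Explicit shadow: the caller owns dx, so the in-place accumulation is usable.
dx = Enzyme.make_zero(x)
autodiff(Reverse, sum, Active, Duplicated(x, dx))
dx  # == [1.0, 1.0], the gradient of sum with respect to x
```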
@mcabbott the present API now hits the following. Thoughts?
MethodError: train!(::var"#2104#loss#143"{Matrix{Float64}}, ::Duplicated{@NamedTuple{weight::Matrix{Float64}, bias::Vector{Float64}, ignore::Nothing}}, ::Base.Generator{UnitRange{Int64}, var"#140#144"}, ::Descent) is ambiguous.
Candidates:
train!(loss, model, data, opt::Flux.Optimise.AbstractOptimiser; cb)
@ Flux ~/work/Flux.jl/Flux.jl/src/deprecations.jl:110
train!(loss, model_and_shadow::Duplicated, data, opt_state)
@ Flux.Train ~/work/Flux.jl/Flux.jl/src/train.jl:124
Possible fix, define
train!(::Any, ::Duplicated, ::Any, ::Flux.Optimise.AbstractOptimiser)
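For reference, ambiguities like this are usually resolved by defining the intersection method explicitly; a generic sketch (not necessarily the fix adopted in this PR) could look like:

```julia
using Flux, Enzyme

# Sketch: the most specific intersection method removes the ambiguity.
# Here it simply rejects the combination, since a Duplicated model implies
# the new explicit-style API rather than the deprecated implicit optimisers.
function Flux.train!(loss, model::Duplicated, data, opt::Flux.Optimise.AbstractOptimiser; cb=nothing)
    error("train! with a Duplicated model expects optimiser state from Flux.setup, not a legacy optimiser")
end
```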
@darsnack @ToucheSir @mcabbott bumping for review.
Looks good, minus figuring out the dispatch issue and moving the docstring. Sorry, the dispatch is a mess from trying to keep support for the old optimizer interface.
We could consider merging this change together with dropping support for implicit optimizers completely.
@darsnack made the appropriate changes; mind giving it a final once-over and merging?
Gentle bump.
For some reason the "allow edits by maintainers" setting is not letting me push to your fork. Can you manually add write permissions on your fork for me?
@darsnack weird, but in any case added!
@darsnack at least one of these failures is due to removing the custom rule for params |
Okay, it looks ready to me. Can you give it a once-over, then I can merge it?
@darsnack lgtm! After this we should open a tracking PR to follow the status of Flux+CUDA/AMDGPU+Enzyme training.
A quick test on the README input seems positive. I have no opinions on the design/API, and I will leave this PR to you all to shape however you see fit (and I will go back to staring at CUDA).
I will note that performance at the moment is unclear and is worth investigating. However, before we do that, having a good way to run and test things is critical, hence this PR.