Add more `Duplicated` methods for Enzyme.jl support
#2471
+309
−19
This adds a method like `gradient(f, ::Duplicated)` which, like `train!(loss, model::Duplicated, data, opt)` from #2446, uses the `Duplicated` type to signal that you want to use Enzyme, not Zygote. It returns the gradient (for compatibility?) and mutates the `Duplicated` object.

To avoid piracy, this creates a new function `Flux.gradient` which by default calls `Zygote.gradient`. Unfortunately that's going to mean every `using Flux, Zygote` now produces ambiguities... so probably it should not be exported? Which means 0.15.

There's also `withgradient`, but it doesn't allow you to return a tuple the way Zygote does, not yet.

There's also a method of `update!`, which either needs to move to Optimisers.jl, or again we need to let Flux own the function.

Finally, `@layer Chain` defines a 1-argument `Duplicated(c::Chain)` method, so that you don't need to construct the dual by hand.

WIP, RFC?

Needs tests, and docs.
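To make the dispatch idea above concrete, here is a minimal, self-contained sketch of the pattern. It is not the code in this PR. `ToyBackend`, `ToyFlux` and `ToyDuplicated` are hypothetical stand-ins for Zygote, Flux and Enzyme's `Duplicated`, and central finite differences take the place of a real AD backend. The point is only that the outer package owns its own `gradient`, forwards to the backend by default, and adds a method for the wrapper type that fills the shadow in place and also returns it.

```julia
module ToyBackend  # hypothetical: plays the role of Zygote

# Central finite differences, standing in for a real AD backend.
function gradient(f, x::AbstractVector{<:Real}; h = 1e-6)
    g = zeros(Float64, length(x))
    for i in eachindex(x)
        xp = float.(x); xp[i] += h
        xm = float.(x); xm[i] -= h
        g[i] = (f(xp) - f(xm)) / (2h)
    end
    return (g,)
end

end # module ToyBackend

module ToyFlux  # hypothetical: plays the role of Flux

import ..ToyBackend

# Hypothetical stand-in for Enzyme's Duplicated: a value plus a shadow
# of the same shape that receives the gradient.
struct ToyDuplicated{T}
    val::T
    dval::T
end

# One-argument constructor, so the zeroed shadow need not be built by hand.
ToyDuplicated(x::AbstractVector) = ToyDuplicated(x, zero(x))

# ToyFlux owns this function, so adding methods to it is not piracy.
# By default it just forwards to the backend.
gradient(f, x) = ToyBackend.gradient(f, x)

# The wrapper type selects the other code path. A real implementation would
# call a different backend here; this toy reuses the same one.
function gradient(f, d::ToyDuplicated)
    (g,) = ToyBackend.gradient(f, d.val)
    d.dval .= g          # mutate the shadow in place
    return (d.dval,)     # and also return it, for compatibility
end

end # module ToyFlux

f(x) = sum(abs2, x)
d = ToyFlux.ToyDuplicated([1.0, 2.0, 3.0])
(g,) = ToyFlux.gradient(f, d)             # wrapper path: fills d.dval
(g0,) = ToyFlux.gradient(f, [1.0, 2.0, 3.0])  # default path: forwards to backend
```

Because `ToyFlux.gradient` and `ToyBackend.gradient` are distinct functions with the same name, exporting both would make an unqualified `gradient` clash when both modules are loaded with `using`, which is the export concern raised above.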
PR Checklist