It would be nice to have an extension to the ChainRules API that allowed for different rules to be written and hit depending on what combination of inputs the derivative is being taken with respect to.
(or maybe a similar analogue for forward mode)
Basically partial derivative rules.
Thunking is a simple approximation to this (with its own set of struggles).
@wsmoses requested this for Enzyme, though it is also relevant to every kind of operator-overloading-based AD (since only tracked types etc. will have derivatives taken with respect to them).
In contrast, it is useless for Zygote/Diffractor, as they do no kind of activity analysis and transform absolutely all code that is run.
This might also be useful for partial mutation support, since it is probably completely safe to have rules for things that mutate inputs that are not "active" on the derivative path? (cf JuliaDiff/ChainRules.jl#521)
Though as the main reason we don't do mutation is tied to Diffractor/Zygote not supporting it, that might be kind of moot, unless they gain at least some basic activity analysis.
(NB: we may not initially implement this in ChainRulesCore. It might be better to make a little experimental extension package for it first.)
In terms of API, I was thinking that one could pass in the cotangent values on which the pullback should accumulate. So, for example, for a function
    z = f(x, y)
The pullback would always have a signature
    # passing cotangent vector in codomain of `f` as well as cotangent vectors in domain of `f` on which to accumulate
    function f_pullback(z̄, x̄, ȳ) # or some other signature, like `f_pullback(z̄, input_cotangents=(x̄, ȳ))`
    end
with the idea that:
If x̄ (or ȳ) is an AbstractZero, simply return the regularly computed cotangent.
If x̄ (or ȳ) is an AbstractArray, you are free to overwrite it (so return add!!(x̄, computed_cotangent)).
If x̄ (or ȳ) is nothing, return nothing in its slot and skip the computation (I am not 100% familiar with the internals; this may need to be NoTangent(), with the first case being ZeroTangent()).
This way, both Thunk and InplaceableThunk are replaced by the new API, and the choice is made by the autodiff engine, which can pass an AbstractZero, an AbstractArray, or nothing to signal "just do the normal thing", "please try to accumulate in place", or "feel free to not compute this one", respectively.
Writing the average rule becomes slightly more tiresome, but one could always add fallbacks that do the "less optimal thing" but for free.
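To make the proposal above a bit more concrete, here is a minimal sketch of how such a pullback could dispatch on what the AD engine passes for each input slot. It assumes ChainRulesCore is installed and uses its existing `AbstractZero`, `ZeroTangent`, and `add!!`. The helper `accumulate_slot`, the function `f_with_pullback`, and the three-argument pullback signature are hypothetical names for illustration only, not part of any existing ChainRules API, and `f` is taken to be elementwise multiplication purely as an example.

```julia
using ChainRulesCore: ChainRulesCore, AbstractZero, ZeroTangent

# Hypothetical helper (not a ChainRules API): decide what to do for one input slot.
# `compute` is a zero-argument closure that produces the cotangent for that slot,
# so the work is only done if the slot is actually requested.
accumulate_slot(::Nothing, compute) = nothing            # inactive input: skip the computation
accumulate_slot(::AbstractZero, compute) = compute()     # "normal thing": return a fresh cotangent
accumulate_slot(acc::AbstractArray, compute) =
    ChainRulesCore.add!!(acc, compute())                 # accumulate, in place when possible

# Example primal function, assumed here to be elementwise multiplication.
f(x, y) = x .* y

# Hypothetical rule shape: returns the primal result and a pullback that also
# receives the input cotangents (x̄, ȳ) on which to accumulate.
function f_with_pullback(x, y)
    z = f(x, y)
    function f_pullback(z̄, x̄, ȳ)
        return (
            accumulate_slot(x̄, () -> z̄ .* y),
            accumulate_slot(ȳ, () -> z̄ .* x),
        )
    end
    return z, f_pullback
end

x, y = [1.0, 2.0], [3.0, 4.0]
z, pb = f_with_pullback(x, y)
z̄ = [1.0, 1.0]

# x̄ is computed fresh; ȳ is marked inactive, so its closure is never called.
fresh = pb(z̄, ZeroTangent(), nothing)

# x̄ is accumulated into an existing buffer; ȳ is computed fresh.
buf = zeros(2)
accumulated = pb(z̄, buf, ZeroTangent())
```

The closures play the role that thunks play today: the rule author writes each partial once, and the fallback methods of `accumulate_slot` provide the "less optimal thing" for free when a rule does nothing special for a given slot.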
A bit of a sketch of what that API might look like is in https://gist.github.com/oxinabox/c6ad25c468b3108f8a799bda66c147f8/