Ability to specify different rules based on what combinations of inputs are actually being used #452

Open
oxinabox opened this issue Sep 1, 2021 · 1 comment
Labels
design Requires some design before changes are made

Comments

@oxinabox
Member

oxinabox commented Sep 1, 2021

It would be nice to have an extension to the ChainRules API that allows different rules to be written, and dispatched to, depending on which combination of inputs the derivative is being taken with respect to.
(Or perhaps an analogous mechanism for forward mode.)
Basically, partial derivative rules.
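To make the "partial derivative rules" idea concrete, here is a minimal Julia sketch for matrix multiplication. It assumes a hypothetical `partial_rrule` function and hypothetical `Active`/`Inactive` marker types; neither is ChainRulesCore API, and this is not the API from the gist linked below. The point is only that each combination of active inputs gets its own method, so a pullback never computes a cotangent nobody asked for.

```julia
using LinearAlgebra

# Hypothetical activity markers: illustrative only, not part of ChainRulesCore.
struct Active end
struct Inactive end

# Hypothetical `partial_rrule`: one method per combination of active inputs,
# so each pullback only computes the cotangents that are actually needed.
function partial_rrule(::typeof(*), A::AbstractMatrix, ::Active, B::AbstractMatrix, ::Active)
    C = A * B
    pullback(dC) = (dC * B', A' * dC)   # both inputs active
    return C, pullback
end

function partial_rrule(::typeof(*), A::AbstractMatrix, ::Active, B::AbstractMatrix, ::Inactive)
    C = A * B
    pullback(dC) = (dC * B', nothing)   # the cotangent for B is never computed
    return C, pullback
end

function partial_rrule(::typeof(*), A::AbstractMatrix, ::Inactive, B::AbstractMatrix, ::Active)
    C = A * B
    pullback(dC) = (nothing, A' * dC)   # the cotangent for A is never computed
    return C, pullback
end

# Usage: only A is active, e.g. because B is an untracked constant.
A = [1.0 2.0; 3.0 4.0]
B = [0.5 0.0; 0.0 2.0]
C, back = partial_rrule(*, A, Active(), B, Inactive())
dA, dB = back(ones(2, 2))
```

An operator-overloading AD (or Enzyme, via its activity analysis) already knows which inputs are tracked, so it could pick the method directly; in this sketch that choice is made by the marker arguments.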

Thunking is a simple approximation to this (with its own set of struggles).

@wsmoses requested this for Enzyme, though it is also relevant to every kind of operator-overloading-based AD (since only tracked types etc. will have derivatives taken with respect to them).
In contrast, it is useless for Zygote/Diffractor, as they do no activity analysis and transform absolutely all code that is run.

A rough sketch of what that API might look like is in https://gist.github.com/oxinabox/c6ad25c468b3108f8a799bda66c147f8/

This might also be useful for partial mutation support, since it is probably completely safe to have rules for functions that mutate inputs which are not "active" on the derivative path? (cf JuliaDiff/ChainRules.jl#521)
However, since the main reason we don't support mutation is that Diffractor/Zygote don't support it, that might be moot unless they gain at least some basic activity analysis.

(NB: we may not initially implement this in ChainRulesCore. It might be better to make a little experimental extension package for it first.)

oxinabox added the design label (Requires some design before changes are made) on Sep 1, 2021
@piever
Contributor

piever commented Sep 10, 2021

In terms of API, I was thinking that one could pass the pullback the cotangents into which it should accumulate its results. So, for example, for a function

z = f(x, y)

The pullback would always have a signature

# passing cotangent vector in codomain of `f` as well as cotangent vectors in domain of `f` on which to accumulate
function f_pullback(z̄, x̄, ȳ) # or some other signature, like `f_pullback(z̄, input_cotangents=(x̄, ȳ))`

end

with the idea that:

  1. If x̄ (or ȳ) is an AbstractZero, simply return the regularly computed cotangent.
  2. If x̄ (or ȳ) is an abstract array, you are free to overwrite it (so return add!!(x̄, computed_cotangent)).
  3. If x̄ (or ȳ) is nothing, return nothing in that slot and skip the computation. (I'm not 100% familiar with the internals; this may need to be NoTangent(), with the first case using ZeroTangent().)

This way, both Thunk and InplaceableThunk are replaced by the new API, and it is the autodiff engine that "asks": it can pass AbstractZero, an AbstractArray, or nothing to signal "just do the normal thing", "please try to accumulate in place", or "feel free not to compute this one", respectively.
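A minimal Julia sketch of the proposed pullback shape, for `f(x, y) = x .* y`. `AbstractZero`, `ZeroTangent` and `add!!` are local stand-ins mimicking ChainRulesCore names, defined only so the code runs without the package; `rrule_f` and `resolve_slot` are hypothetical. Each input slot is resolved by dispatch on the hint the engine passed, following rules 1 to 3 above, and the cotangent is wrapped in a closure so it is only computed when the slot asks for it.

```julia
# Minimal stand-ins for ChainRulesCore's `AbstractZero`, `ZeroTangent` and `add!!`,
# defined here only so the sketch runs without the package.
abstract type AbstractZero end
struct ZeroTangent <: AbstractZero end
add!!(acc::AbstractArray, t) = (acc .+= t; acc)

# Hypothetical helper applying rules 1-3 to one input slot.
# `compute` is a zero-argument closure, so the cotangent is only built when needed.
resolve_slot(::AbstractZero, compute) = compute()                 # 1. return the plain cotangent
resolve_slot(acc::AbstractArray, compute) = add!!(acc, compute()) # 2. accumulate into the caller's buffer
resolve_slot(::Nothing, compute) = nothing                        # 3. skip the computation entirely

f(x, y) = x .* y

# Hypothetical rrule whose pullback takes one hint per input, as in f_pullback(z̄, x̄, ȳ).
function rrule_f(x, y)
    z = f(x, y)
    function f_pullback(zbar, xbar, ybar)
        dx = resolve_slot(xbar, () -> zbar .* y)
        dy = resolve_slot(ybar, () -> zbar .* x)
        return dx, dy
    end
    return z, f_pullback
end

x = [1.0, 2.0]
y = [3.0, 4.0]
z, back = rrule_f(x, y)
buf = [10.0, 10.0]
dx, dy = back([1.0, 1.0], buf, nothing)   # accumulate into buf for x, skip y entirely
```

The in-place branch here still allocates `zbar .* y` before adding it into the buffer; a hand-written rule could fuse that into a single broadcast such as `acc .+= zbar .* y`.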

Writing the average rule becomes slightly more tedious, but one could always add fallbacks that provide the "less optimal thing" for free.
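The fallback idea could be sketched generically: wrap any ordinary pullback, which always computes every cotangent, so that it accepts one hint per input. `with_hints` and `apply_hint` are hypothetical names, and the zero types are again local stand-ins for the ChainRulesCore ones. Such a fallback is correct but saves no work, since everything is computed before the hints are applied; that is the "less optimal thing" a hand-written rule could improve on.

```julia
# Minimal stand-ins for ChainRulesCore's zero types (only so this runs standalone).
abstract type AbstractZero end
struct ZeroTangent <: AbstractZero end

# Apply one caller-supplied hint to an already computed cotangent `t`.
apply_hint(::AbstractZero, t) = t                     # plain result
apply_hint(acc::AbstractArray, t) = (acc .+= t; acc)  # accumulate in place
apply_hint(::Nothing, t) = nothing                    # caller did not want this one

# Hypothetical generic fallback: adapt an ordinary pullback to the
# hint-per-input form. Every cotangent is still computed up front.
function with_hints(pullback)
    return function (zbar, hints...)
        computed = pullback(zbar)
        return map(apply_hint, hints, computed)
    end
end

x = [1.0, 2.0]
y = [3.0, 4.0]
plain_pullback(zbar) = (zbar .* y, zbar .* x)   # an existing-style pullback for x .* y
back = with_hints(plain_pullback)
buf = zeros(2)
dx, dy = back([1.0, 1.0], buf, nothing)
```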
