Skip to content

Commit

Permalink
FAQ: What types does my pullback need to accept? (#428)
Browse files Browse the repository at this point in the history
* FAQ: What types does my pullback need to accept?

* Update docs/src/FAQ.md

* move pullback types into writing good rules

* mention that natural tangent does not have a formal definition

Co-authored-by: willtebbutt <wt0881@my.bristol.ac.uk>

* Update docs/src/writing_good_rules.md

* Apply suggestions from code review

Co-authored-by: Miha Zgubic <mzgubic@users.noreply.github.com>

* Update docs/src/writing_good_rules.md

* Update docs/src/writing_good_rules.md

Co-authored-by: willtebbutt <wt0881@my.bristol.ac.uk>
Co-authored-by: Miha Zgubic <mzgubic@users.noreply.github.com>
  • Loading branch information
3 people authored Aug 17, 2021
1 parent 09c133b commit 2208660
Show file tree
Hide file tree
Showing 2 changed files with 66 additions and 1 deletion.
2 changes: 1 addition & 1 deletion docs/src/FAQ.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,4 +111,4 @@ This is morally the same as similar issues [discussed in ColPrac](https://github

On a practical level, it's important that this is the case because thunks are a bit of a hack,
and over time it is hoped that the need for them will reduce, as they increase
code-complexity and place additional stress on the compiler.
code-complexity and place additional stress on the compiler.
65 changes: 65 additions & 0 deletions docs/src/writing_good_rules.md
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,71 @@ end
```
to define the rules.

## Ensure your pullback can accept the right types
As a rule the number of types you need to accept in a pullback is theoretically unlimitted, but practically highly constrained to be in line with the primal return type.
The three kinds of inputs you will practically need to accept one or more of: _natural tangents_, _structural tangents_, and _thunks_.
You do not in general have to handle `AbstractZero`s as the AD system will not call the pullback if the input is a zero, since the output will also be.
Some more background information on these types can be found in [the design notes](@ref manytypes).
In many cases all these tangents can be treated the same: tangent types overload a bunch of linear-operators, and the majority of functions used inside a pullback are linear operators.
If you find linear operators from Base/stdlibs that are not supported, consider opening an issue or a PR on the [ChainRulesCore.jl repo](https://github.com/JuliaDiff/ChainRulesCore.jl/).

### Natural tangents
Natural tangent types are the types you might feel the tangent should be, to represent a small change in the primal value.
For example, if the primal is a `Float32`, the natural tangent is also a `Float32`.
Slightly more complex, for a `ComplexF64` the natural tangent is again also a `ComplexF64`, we almost never want to use the structural tangent `Tangent{ComplexF64}(re=..., im=...)` which is defined.
For other cases, this gets a little more complicated, see below.
These are a purely human notion, they are the types the user wants to use because they make the math easy.
There is currently no formal definition of what constitutes a natural tangent, but there are a few heuristics.
For example, if a primal type `P` overloads subtraction (`-(::P,::P)`) then that generally returns a natural tangent type for `P`; but this is not required to be defined and sometimes it is defined poorly.

Common cases for types that represent a [vector-space](https://en.wikipedia.org/wiki/Vector_space) (e.g. `Float64`, `Array{Float64}`) is that the natural tangent type is the same as the primal type.
However, this is not always the case.
For example for a [`PDiagMat`](https://github.com/JuliaStats/PDMats.jl) a natural tangent is `Diagonal` since there is no requirement that a positive definite diagonal matrix has a positive definite tangent.
Another example is for a `DateTime`, any `Period` subtype, such as `Millisecond` or `Nanosecond` is a natural differential.
There are often many different natural tangent types for a given primal type.
However, they are generally closely related and duck-type the same.
For example, for most `AbstractArray` subtypes, most other `AbstractArray`s (of right size and element type) can be considered as natural tangent types.

Not all types have natural tangent types.
For example there is no natural differential for a `Tuple`.
It is not a `Tuple` since that doesn't have any method for `+`.
Similar is true for many `struct`s.
For those cases there is only a structural differential.

### Structural tangents

Structural tangents are tangent types that shadow the structure of the primal type.
They are represented by the [`Tangent`](@ref) type.
They can represent any composite type, such as a tuple, or a structure (or a `NamedTuple`) etc.


!!! info "Do I have to support the structural tangents as well?"
Technically, you might not actually have to write rules to accept structural tangents; if the AD system never has to decompose down to the level of `getfield`.
This is common for types that don't support user `getfield`/`getproperty` access, and that have a lot of rules for the ways they are accessed (such cases include some `AbstractArray` subtypes).
You really should support it just in case; especially if the primal type in question is not restricted to a well-tested concrete type.
But if it is causing struggles, then you can leave it off til someone complains.

### Thunks

A thunk (either a [`Thunk`](@ref), or a [`InplaceableThunk`](@ref)), represents a delayed computation.
They can be thought of as a wrapper of the value the computation returns.
In this sense they wrap either a natural or structural tangent.

!!! warning "You should to support AbstractThunk inputs even if you don't use thunks"
Unfortunately the AD sytems do not know which rules support thunks and which do not.
So all rules have to; at least if they want to play nice with arbitary AD systems.
Luckily it is not hard: much of the time they will duck-type as the object they wrap.
If not, then just add a [`unthunk`](@ref) after the start of your pullback.
(Even when they do duck-type, if they are used multiple times then unthunking at the start will prevent them from being recomputed.)
If you are using [`@thunk`](@ref) and the input is only needed for one of them then the `unthunk` should be in that one.
If not, and you have a bunch of pullbacks you might like to write a little helper `unthunking(f) = x̄ -> f(unthunk(x̄))` that you can wrap your pullback function in before returning it from the `rrule`.
Yes, this is a bit of boiler-plate, and it is unfortunate.
Sadly, it is needed because if the AD wants to benefit it can't get that benifit unless things are not unthunked unnecessarily.
Which eventually allows them in some cases to never be unthunked at all.
There are two ways common things are never unthunked.
One is if the unthunking happens inside a `@thunk` which is never unthunked itself because it is the tangent for a primal input that never has it's tangent queried.
The second is if they are not unthunked because the rule does not need to know what is inside: consider the pullback for `identity`: `x̄ -> (NoTangent(), x̄)`.

## Use `@not_implemented` appropriately

One can use [`@not_implemented`](@ref) to mark missing differentials.
Expand Down

0 comments on commit 2208660

Please sign in to comment.