Description
ChainRules embraces multiple possible representations of a cotangent: for example, AbstractZero, Composite, and AbstractArray are all valid representations for the cotangent of a Diagonal. However, this flexibility increases the burden on rule implementers, in that there is in principle no upper bound on the number of types that one might have to accept as the cotangent w.r.t. the output of some function foo that returns a Diagonal.
I wonder whether some design-orthogonalisation might help to deal with this -- could we separate out the standardisation of the representation of cotangents from the rule implementation?
Consider a function canonicalise(primal, cotangent) whose job it is to map a cotangent onto a well-defined, predictable, finite set of types for any given primal type. For example, you might implement this as follows for Diagonal:
canonicalise(::Any, dX::AbstractZero) = dX
canonicalise(X::Diagonal, dX::Composite{T}) where {T} = Composite{T}(diag=canonicalise(X.diag, dX.diag))
canonicalise(X::Diagonal, dX::AbstractMatrix) = Composite{typeof(X)}(diag=canonicalise(X.diag, diag(dX)))
Note that I've chosen to make the canonical cotangent type for a Diagonal a Composite rather than an AbstractMatrix for the usual performance-related reasons discussed extensively in JuliaDiff/ChainRules.jl#232. An AbstractMatrix doesn't count as a "canonical" type in my definition here since it's abstract, so it doesn't meet the finiteness criterion.
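To make the idea concrete, here is a minimal, self-contained sketch of how such a canonicalisation step could behave. It deliberately does not use ChainRulesCore: ZeroSketch, CompositeSketch and canonicalise_sketch are hypothetical stand-ins invented for illustration, and the extra method for real vectors is my assumption about what the base case for X.diag might look like.

```julia
using LinearAlgebra

# Hypothetical stand-ins for AbstractZero and Composite, defined here so the
# sketch runs without ChainRulesCore. They are not the package's real types.
abstract type AbstractZeroSketch end
struct ZeroSketch <: AbstractZeroSketch end

struct CompositeSketch{P,T<:NamedTuple}
    backing::T
end
function CompositeSketch{P}(; kwargs...) where {P}
    backing = (; kwargs...)
    return CompositeSketch{P,typeof(backing)}(backing)
end

# Zeros are already canonical for any primal.
canonicalise_sketch(::Any, dX::AbstractZeroSketch) = dX

# Assumed base case: a real vector is its own canonical cotangent.
canonicalise_sketch(::AbstractVector{<:Real}, dx::AbstractVector{<:Real}) = dx

# A structural cotangent stays structural; its field is canonicalised recursively.
canonicalise_sketch(X::Diagonal, dX::CompositeSketch{P}) where {P} =
    CompositeSketch{P}(diag=canonicalise_sketch(X.diag, dX.backing.diag))

# Any matrix cotangent is reduced to its diagonal and wrapped structurally.
canonicalise_sketch(X::Diagonal, dX::AbstractMatrix) =
    CompositeSketch{typeof(X)}(diag=canonicalise_sketch(X.diag, diag(dX)))

X = Diagonal([1.0, 2.0, 3.0])
dense = [1.0 9.0 9.0; 9.0 2.0 9.0; 9.0 9.0 3.0]

c1 = canonicalise_sketch(X, dense)
c2 = canonicalise_sketch(X, CompositeSketch{typeof(X)}(diag=[1.0, 2.0, 3.0]))
c3 = canonicalise_sketch(X, ZeroSketch())
```

In this sketch the dense matrix and the structural cotangent end up with the same concrete type, so downstream code only has to handle that type and the zero.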
If we did this, we would certainly be able to avoid defining + on so many things: you just assume that things have been canonicalised before hitting +. Similarly, Zygote's automatic constructor pullback generation ought to have an easier time because, if you ensure that everything is appropriately canonicalised, constructors should always receive appropriate NamedTuples.
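As a rough illustration of the point about +, the sketch below shows how small the set of accumulation methods could become if inputs are assumed to be canonical already. ZeroCot, DiagCot and add_sketch are hypothetical names made up for this example and are not part of ChainRulesCore; DiagCot plays the role of the canonical structural cotangent of a Diagonal.

```julia
# Hypothetical canonical cotangent types for a Diagonal primal.
struct ZeroCot end
struct DiagCot{V<:AbstractVector}
    diag::V
end

# If every cotangent has been canonicalised, these four combinations are the
# only ones accumulation has to handle; no dense-matrix methods are needed.
add_sketch(::ZeroCot, ::ZeroCot) = ZeroCot()
add_sketch(::ZeroCot, b::DiagCot) = b
add_sketch(a::DiagCot, ::ZeroCot) = a
add_sketch(a::DiagCot, b::DiagCot) = DiagCot(a.diag + b.diag)

total = add_sketch(DiagCot([1.0, 2.0]), DiagCot([3.0, 4.0]))
```

The trade-off is that every rule boundary would need to call the canonicalisation step, or trust that its caller has done so.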
@sethaxen pointed out that this is something that we might want to concern ourselves with in #160, but I wanted to raise it separately, as I think it's an interesting thing to consider on its own.
edit: not sure whether we want to choose a different name from canonicalise, given that we already have a function with that name. Possibly we could extend it to handle the more general class of things described here.