
How to represent structural derivatives #462

@oxinabox


Consider our current rule for `Base.adjoint`:

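```julia
using Zygote  # provides the @adjoint macro

@adjoint function Base.adjoint(x)
  back(Δ) = (Δ',)                                 # natural differential
  back(Δ::NamedTuple{(:parent,)}) = (Δ.parent,)  # structural differential
  return x', back
end
```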

The structural derivative here is `Δ::NamedTuple{(:parent,)}`.
This gives the correct behaviour if the primal type matching that derivative is `Adjoint`.
It gives incorrect behaviour if the primal type is some other wrapper that also has a `.parent` field, because the `NamedTuple` records only the field names, not which primal type it belongs to.
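To make the ambiguity concrete, here is a small self-contained sketch. `MyWrapper` and `back_parent` are hypothetical names invented for illustration; `MyWrapper` simply happens to name its field `parent`, as `Adjoint` does.

```julia
using LinearAlgebra: Adjoint

# Hypothetical wrapper (not from any package) that, like Adjoint,
# stores its wrapped value in a field named `parent`.
struct MyWrapper{T}
    parent::T
end

# Both primal types have the same field names, (:parent,) ...
fieldnames(Adjoint) == fieldnames(MyWrapper)

# ... so structural differentials in the current NamedTuple style are identical.
Δ_for_adjoint = (parent = [1.0, 2.0],)   # intended for an Adjoint
Δ_for_wrapper = (parent = [1.0, 2.0],)   # intended for a MyWrapper
same_type = typeof(Δ_for_adjoint) == typeof(Δ_for_wrapper)  # true

# Mimics the `back(Δ::NamedTuple{(:parent,)})` method in the rule above:
# nothing in Δ's type says which primal it came from, so it accepts both.
back_parent(Δ::NamedTuple{(:parent,)}) = (Δ.parent,)
back_parent(Δ_for_wrapper)  # ([1.0, 2.0],): treated as if it came from an Adjoint
```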

Fortunately, `Adjoint` is an outlier among wrapper types.
Almost all other wrapper types call that field `.data` and use the `parent(...)` accessor method.
But odds are that somewhere there is one that also uses `parent` as its field name.

The question this issue raises is how we should represent structural differentials.
I am strongly of the opinion that we should use ChainRules' `Composite{Primal}` type.
If it is unsuitable, that should be fixed in ChainRules.
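A minimal sketch of why carrying the primal type helps. `TypedComposite` is a hypothetical, simplified stand-in for the idea behind `Composite{Primal}` (the primal type as a type parameter), not ChainRulesCore's actual implementation or API; `adjoint_pullback` and `MyWrapper` are likewise made up for illustration.

```julia
using LinearAlgebra: Adjoint

# Hypothetical, simplified stand-in for the idea behind Composite{Primal}:
# the primal type P is a type parameter, so dispatch can see it.
struct TypedComposite{P, T<:NamedTuple}
    backing::T
end

# Keyword constructor: TypedComposite{P}(field = value, ...)
TypedComposite{P}(; kwargs...) where {P} =
    TypedComposite{P, typeof(values(kwargs))}(values(kwargs))

# Forward field access (Δ.parent) to the backing NamedTuple.
Base.getproperty(c::TypedComposite, s::Symbol) = getproperty(getfield(c, :backing), s)

# Hypothetical wrapper that also has a `parent` field.
struct MyWrapper{T}
    parent::T
end

# Pullback sketch for adjoint(x): accepts a natural (array) differential, or a
# structural one only if its primal type really is an Adjoint.
adjoint_pullback(Δ::AbstractArray) = (adjoint(Δ),)
adjoint_pullback(Δ::TypedComposite{<:Adjoint}) = (Δ.parent,)

x = [1.0, 2.0]
Δa = TypedComposite{typeof(x')}(parent = [10.0, 20.0])
adjoint_pullback(Δa)  # ([10.0, 20.0],)

Δw = TypedComposite{MyWrapper{Vector{Float64}}}(parent = [10.0, 20.0])
applicable(adjoint_pullback, Δw)  # false: no method claims it
```

Because the primal type is part of the differential's type, a differential meant for `MyWrapper` finds no matching method, instead of being silently unwrapped as though it came from an `Adjoint`.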

I had been thinking of that changeover as step 2 of integrating ChainRules: switching to ChainRules' types internally, with step 1 (#366, switching to ChainRules' rules) done first.
But actually the two steps are independent.
Using the types doesn't even require loading ChainRules, only ChainRulesCore.

Note: this discussion is not about whether we should be using structural differentials for array types in the first place. See JuliaDiff/ChainRulesCore.jl#85 and #445 for that.
I am sure we can find another example where there is no clear natural differential type.
