Description
In an ideal world, a user would test an rrule by writing something like the following, and have it work all of the time:
test_rrule(foo, args...)
By work all of the time, I mean that the tests that we want to run to determine the correctness of an rrule implementation always run successfully, provided that the function is something that we know how to test (broadly speaking, the output is deterministic given the input), and for any input type that is either
- a primitive that we know about (Real, Array, etc.), or
- a composite type.
It's important that this works automatically because we want people to be testing their code using CRTU, and people like to define new types (including new AbstractArrays) and new functions. Unfortunately, I don't believe it's possible to automate in all cases, but the way in which it fails (AFAICT) is very specific, and I think we can document it and make it easy to resolve for users.
Roughly speaking, the functionality that always needs to work in order to achieve this is
- to_vec
- to_vec_tangent (a new function)
- rand_tangent
- test_approx
to_vec, to_vec_tangent and rand_tangent can be made to "always work", but test_approx occasionally has a quirk that I don't believe we can automate.
The outcome is the following proposals:
- remove all (or at least most) to_vec implementations in favour of the generic to_vec implementation for isstructtype types, plus the necessary to_vec implementations for isprimitivetype types,
- introduce a to_vec_tangent (better name welcome) function, which is like to_vec, but the closure returned returns a tangent rather than a primal,
- add a function called remove_junk_data, or something similar, which applies to primals and returns another object which contains only the bits of the primal relevant for defining isapprox, and whenever we test rules, we test the composition of remove_junk_data and the function being tested, rather than just the function. This enables us to define test_approx in a really generic manner.
I'll explain throughout this issue why I believe these are sensible proposals, and how they resolve things.
Additionally, while this proposal is independent from other proposed changes, it clearly favours a structural view of the world because I'm interested in automating things. See JuliaDiff/ChainRulesCore.jl#449 for a proposal for how we can do this without sacrificing usability, and how this leads to a precise definition for natural tangents.
I would be really interested to know if anyone thinks I've obviously missed something, or whether this sounds about right.
edit: I completely neglected constraint-related problems (e.g. if the tangent provided to FiniteDifferences needs to represent a positive definite matrix for some reason). AFAICT the things discussed are essentially orthogonal to the constraint problems though.
edit2: note: undefined references are not fun. For example, perfectly well-defined Dict objects can contain undefined memory. I think this probably comes under the heading of "junk" data, but it seems to cause problems for to_vec as it's currently defined. I wonder whether it could be generalised?
Example 1: Diagonal size mismatch
Consider testing
f(x::Diagonal) = 5x
Let the output (co)tangent be
x = Diagonal(randn(2))
dx = Tangent{typeof(x)}(diag=randn(2))
then
FiniteDifferences.j′vp(central_fdm(5, 1), f, dx, x)
produces the error:
ERROR: DimensionMismatch("second dimension of A, 4, does not match length of x, 2")
Stacktrace:
[1] gemv!(y::Vector{Float64}, tA::Char, A::Matrix{Float64}, x::Vector{Float64}, α::Bool, β::Bool)
@ LinearAlgebra /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:530
[2] mul!
@ /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:97 [inlined]
[3] mul!
@ /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:275 [inlined]
[4] *(transA::Transpose{Float64, Matrix{Float64}}, x::Vector{Float64})
@ LinearAlgebra /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:87
[5] _j′vp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::Function, ȳ::Vector{Float64}, x::Vector{Float64})
@ FiniteDifferences ~/.julia/packages/FiniteDifferences/W3rQO/src/grad.jl:80
[6] j′vp(fdm::FiniteDifferences.AdaptedFiniteDifferenceMethod{5, 1, FiniteDifferences.UnadaptedFiniteDifferenceMethod{7, 5}}, f::Function, ȳ::Tangent{Diagonal{Float64, Vector{Float64}}, NamedTuple{(:diag,), Tuple{Vector{Float64}}}}, x::Diagonal{Float64, Vector{Float64}})
@ FiniteDifferences ~/.julia/packages/FiniteDifferences/W3rQO/src/grad.jl:73
[7] top-level scope
@ REPL[24]:1
Why is this example a problem?
Firstly, we presently require that rules accept either a natural or structural tangent. Due to the above, it's not currently possible to test functions which output a Diagonal with a Tangent tangent.
Secondly, there exist Diagonal matrices whose tangent cannot be represented by a Diagonal. Specifically, any Diagonal whose diag field doesn't provide a way to produce an AbstractVector as its tangent (i.e. for whatever reason, lacks a natural tangent). Consequently, in order for our testing facilities to handle any type, they must be able to work with structural tangents.
Finally, our current implementations special-case to_vec for lots of different arrays (Diagonal, Symmetric, etc). This is a problem in itself, but moreover we're never entirely sure what the right thing to do is when we encounter a new array.
How to fix this problem
Remove the specialised to_vec methods for Diagonal and other struct AbstractArrays (UpperTriangular, Symmetric, etc), and instead just rely on the generic to_vec operation for structs.
Doing this immediately means that we can to_vec anything that is either
- a primitive that we've defined to_vec on, or
- any struct or mutable struct.
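As a rough illustration of what "rely on the generic operation for structs" could look like, here is a minimal sketch using only the standard library. The name flatten is hypothetical and stands in for to_vec; this is not the FiniteDifferences implementation, and it assumes every struct can be rebuilt by calling its type with the field values in order.

```julia
using LinearAlgebra

# Hypothetical stand-in for `to_vec`: returns (flat vector, closure rebuilding the primal).
flatten(x::Real) = (Float64[x], v -> v[1])

function flatten(x::Array{<:Real})
    sz = size(x)
    return (Float64.(vec(x)), v -> reshape(copy(v), sz))
end

# Generic fallback: flatten any struct field by field, with no array-specific methods.
function flatten(x::T) where {T}
    Base.isstructtype(T) || error("don't know how to flatten $T")
    parts = map(name -> flatten(getfield(x, name)), fieldnames(T))
    lens = map(p -> length(first(p)), parts)
    function from_vec(v)
        fields = Any[]
        offset = 0
        for (part, len) in zip(parts, lens)
            push!(fields, last(part)(v[offset+1:offset+len]))
            offset += len
        end
        # Assumption: the concrete type can be constructed from its fields in order.
        return T(fields...)
    end
    return (reduce(vcat, map(first, parts); init=Float64[]), from_vec)
end

x = Diagonal([1.0, 2.0])
v, back = flatten(x)   # v holds only the two entries of the diag field
```

With this fallback the flat vector for a 2×2 Diagonal has length 2, which matches the length of a structural tangent's diag field rather than the 4 entries of the dense matrix.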
This solution brings into focus a problem that we're currently solving on an ad-hoc basis in to_vec: "junk" data in e.g. the lower triangle of a Symmetric can wind up being used in approximate equality checks (and could in principle introduce non-determinism in an otherwise deterministic function, although I've yet to find an example of this in the wild), which makes no sense. We'll address this later.
Example 2: to_vec gives the wrong type sometimes
to_vec only knows about primals -- it knows nothing about tangents. The reason for this is that it was written when we knew nowhere near enough about tangents, in particular for arrays. The particular problem is in the call to vec_to_x on this line in j′vp. It attempts to convert a "flat" vector representation of a cotangent into a primal. While this works fine in some cases (a surprisingly large number, given how much mileage we've gotten out of to_vec over the years), we know that it doesn't work for all types.
Once you've removed the to_vec implementations for the various concrete subtypes of AbstractArray, you'll find that
FiniteDifferences.j′vp(central_fdm(5, 1), identity, dx, x)[1]
yields
2×2 Diagonal{Float64, Vector{Float64}}:
9.97302 ⋅
⋅ -0.329386
Why is this a problem?
Having the pullback for identity return anything other than whatever cotangent it is provided seems highly undesirable to me, so I'm going to assume that our rrule for identity returns exactly the cotangent it is given. If that is the case, then the cotangent returned by the pullback produced by that rrule will be a Tangent if the input is a Tangent, not a Diagonal. This means that our current implementation is incorrect. While this particular example seems reasonably benign, to my mind it's not correct. However, even if you believe it's correct, it's clearly only correct because Diagonal{Float64, Vector{Float64}}s happen to have nice natural tangents that happen to be produced by to_vec, rather than by design.
A more obviously incorrect / plainly-uninterpretable example is a Symmetric -- the from_vec output from to_vec(::Symmetric) will produce a Symmetric with an uplo field that is a Char. Since a Char isn't an appropriate tangent for a Char (it should be a NoTangent), this is plainly nonsensical if the goal is to obtain a tangent. If you wound up comparing between this representation of the tangent and a Tangent output from AD, you would need to compare a NoTangent with this Char, which under any sensible definition would fail (I can't imagine a world in which I would wish to reside in which NoTangent is considered equal to a Char).
How to fix this problem
Introduce another function to_vec_tangent (better name would be nice) which returns a closure that always returns an appropriate tangent representation (primitive for primitives, structural for composites). This would require roughly the same level of implementation effort as to_vec, and would mirror its structure almost entirely (specific methods for primitives, generic method for all isstructtype types).
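To make the shape of such a function concrete, here is a rough sketch using only the standard library. Everything named here is hypothetical: flatten_tangent stands in for to_vec_tangent, a NamedTuple keyed by field name stands in for a ChainRulesCore Tangent, and NoTangentStandIn stands in for NoTangent. It only illustrates the "primitive for primitives, structural for composites" idea.

```julia
using LinearAlgebra

# Hypothetical stand-in for ChainRulesCore's `NoTangent()`.
struct NoTangentStandIn end

# Each method returns (flat vector, closure); the closure rebuilds a *tangent*, not a primal.
flatten_tangent(x::Real) = (Float64[x], v -> v[1])

function flatten_tangent(x::Array{<:Real})
    sz = size(x)
    return (Float64.(vec(x)), v -> reshape(copy(v), sz))
end

# Non-differentiable primitives contribute nothing to the flat vector.
flatten_tangent(x::Char) = (Float64[], v -> NoTangentStandIn())

# Generic method for composites: build a structural tangent field by field.
function flatten_tangent(x::T) where {T}
    Base.isstructtype(T) || error("don't know how to flatten $T")
    fnames = fieldnames(T)
    parts = map(name -> flatten_tangent(getfield(x, name)), fnames)
    lens = map(p -> length(first(p)), parts)
    function to_tangent(v)
        fields = Any[]
        offset = 0
        for (part, len) in zip(parts, lens)
            push!(fields, last(part)(v[offset+1:offset+len]))
            offset += len
        end
        # The NamedTuple plays the role of `Tangent{T}(; field = value, ...)`.
        return NamedTuple{fnames}(Tuple(fields))
    end
    return (reduce(vcat, map(first, parts); init=Float64[]), to_tangent)
end

x = Symmetric([1.0 2.0; 3.0 4.0])
v, to_tangent = flatten_tangent(x)
t = to_tangent(v)   # a structural stand-in with fields `data` and `uplo`
```

Here the uplo field comes back as the no-tangent stand-in rather than a Char, which is the distinction the Symmetric example above turns on.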
Example 3: Propagation of Junk Data
Consider
x = Symmetric(randn(2, 2))
dx = Tangent{typeof(x)}(data=randn(2, 2))
FiniteDifferences.j′vp(central_fdm(5, 1), identity, dx, x)[1].data
yields
2×2 Matrix{Float64}:
-0.180472 -0.793039
0.740994 0.900423
Note that the data field is the relevant bit of the output from j′vp here, because the consistent / correct interpretation of the thing that FiniteDifferences outputs is a Tangent, not a Symmetric, as discussed in the previous example. Observe that the lower triangle (element (2, 1)) will be used when test_approx is computed, because the generic definition of test_approx doesn't know about the specific semantics of Symmetric. Since the standard library makes no promises about the lower triangle of a Symmetric, it seems intuitive to me that we shouldn't have to worry about it in our gradient definitions. I'm happy to expand on this, but there's a good example here.
The solution I believe is best is to ensure that the gradient w.r.t. irrelevant elements is always 0 by always testing
x -> remove_junk_data(f(x))
rather than just f. The function remove_junk_data would be defined such that it doesn't propagate any junk data (data which isn't relevant for equality computations). The implementations that I have so far are things like:
using LinearAlgebra

remove_junk_data(x::Number) = x
remove_junk_data(x::StridedArray) = map(remove_junk_data, x)
remove_junk_data(x::Symmetric{T, <:StridedArray{T}}) where {T} = collect(x)
remove_junk_data(x::UpperTriangular{T, <:StridedArray{T}}) where {T} = collect(x)
remove_junk_data(x::LowerTriangular{T, <:StridedArray{T}}) where {T} = collect(x)
function remove_junk_data(x::T) where {T}
    Base.isstructtype(T) || error("Expected a struct type")
    # Recurse into the field values of the composite, not the field names.
    return map(name -> remove_junk_data(getfield(x, name)), fieldnames(T))
end
Another option one could consider is trying to define equality properly on Tangents. This isn't general though, because e.g. the data field of a Tangent{Symmetric} might itself be a Tangent, which doesn't have a conception of its own lower triangle. The benefit of composing with remove_junk_data is that we get to operate on primal types, whose semantics everyone is familiar with (the data field of a Symmetric definitely does know about triangles because it's an AbstractArray and has getindex defined).
So we can instruct type authors (ourselves for stdlib types) that if their types have any data that's essentially "junk" they must define a method of remove_junk_data, and accept that we'll have to expend some extra computation internally to differentiate remove_junk_data when testing (this can probably be optimised away in most cases, since it's the identity function in most cases).
Note that the generic fallback for composites means that we'll get overly restrictive tests by default, and type authors have to opt in to say that some bits of their type aren't important. This seems like the desirable way around to me -- I'd rather have tests yelling at me when they ought not to than have them fail to yell at me when they should.
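A small standard-library sketch of why the junk matters for comparisons: two Symmetric matrices that are equal as matrices can still differ in their underlying data field. The name strip_junk is hypothetical and plays the role of the proposed remove_junk_data for just the types used here.

```julia
using LinearAlgebra

# Hypothetical, minimal stand-in for the proposed `remove_junk_data`.
strip_junk(x::Number) = x
strip_junk(x::Array) = map(strip_junk, x)
strip_junk(x::Symmetric) = collect(x)   # dense copy built via getindex, so junk is dropped

# Same upper triangle, different junk in the lower triangle of the parent array.
a = Symmetric([1.0 2.0; 999.0 3.0])
b = Symmetric([1.0 2.0; -7.0 3.0])

same_as_matrices = (a == b)                            # compares via getindex
same_raw_fields = (a.data == b.data)                   # compares the stored fields, junk included
same_after_strip = (strip_junk(a) == strip_junk(b))
```

A field-by-field comparison sees the junk and reports a difference, whereas comparing after strip_junk agrees with the matrix semantics. That is the behaviour the composition x -> remove_junk_data(f(x)) is meant to give the tests.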
Outcomes
Assuming that this pans out, this is a win-win for developers and users.
Developers get simpler, more robust, more straightforward to understand code with fewer edge cases -- the edge cases that remain have clear semantics and it's clear why they're necessary.
Users benefit from more predictable and reliable infrastructure.
The issue with this proposal is that it requires structural tangents to actually be taken seriously by everyone. Again, see JuliaDiff/ChainRulesCore.jl#449 for a discussion of how to make this more straightforward for all involved.