Making Testing More Automatic #213
Another way to go about the equality problem is to avoid defining approximate equality on tangents entirely, and instead rely on adding tangents to primals and comparing the resulting primals, i.e. implement

```julia
dx_ad = _compute_ad_cotangent
dx_fd = _compute_fd_cotangent
x_ad = x + dx_ad
x_fd = x + dx_fd
test_approx(x_ad, x_fd)
```

edit: the issue with this proposal is, of course, that you can't always add a tangent to a primal, because of constraints.
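A minimal sketch of this idea, using only Base and the `LinearAlgebra` stdlib. The helper `approx_via_primal` is hypothetical, and plain arrays stand in for tangents, so `+` here is ordinary array addition rather than ChainRulesCore's primal-plus-tangent addition. What it shows is that comparing `x + dx` lets the primal type's own semantics decide what matters: for an `UpperTriangular` primal, whatever the tangents store in their unused lower triangles does not affect the result.

```julia
using LinearAlgebra

# Hypothetical helper: compare two tangents by adding each to the primal `x`
# and comparing the resulting primals, instead of comparing the tangents directly.
approx_via_primal(x, dx_ad, dx_fd; kwargs...) = isapprox(x + dx_ad, x + dx_fd; kwargs...)

# Plain vectors: behaves like an ordinary approximate comparison.
x = [1.0, 2.0]
@show approx_via_primal(x, [0.1, 0.2], [0.1, 0.2 + 1e-12])

# UpperTriangular: the wrapped matrices disagree wildly below the diagonal,
# but `+` and `isapprox` on UpperTriangular only see the upper triangle.
X = UpperTriangular([1.0 2.0; 0.0 3.0])
dX_ad = UpperTriangular([0.1 0.2; 5.0 0.3])
dX_fd = UpperTriangular([0.1 0.2; -7.0 0.3])
@show approx_via_primal(X, dX_ad, dX_fd)       # upper triangles agree
@show isapprox(parent(dX_ad), parent(dX_fd))   # raw storage does not
```

As the edit notes, this only works when `x + dx` is itself a valid primal; for constrained primals (e.g. positive definite matrices), adding an arbitrary tangent can leave that set.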
Sketch implementation for `to_vec_tangent`:

```julia
using ChainRulesCore  # provides Tangent and NoTangent

function to_vec_tangent(x::Real)
    Real_Tangent_from_vec(x_vec) = first(x_vec)
    return [x], Real_Tangent_from_vec
end

function to_vec_tangent(z::Complex)
    Complex_Tangent_from_vec(z_vec) = Complex(z_vec[1], z_vec[2])
    return [real(z), imag(z)], Complex_Tangent_from_vec
end

# Real vectors are already flat. Complex vectors go through the generic Vector
# method below, so that they are split into real and imaginary parts.
to_vec_tangent(x::Vector{<:Real}) = (x, identity)

function to_vec_tangent(x::Vector)
    x_vecs_and_backs = map(to_vec_tangent, x)
    x_vecs, backs = first.(x_vecs_and_backs), last.(x_vecs_and_backs)
    function Vector_Tangent_from_vec(x_vec)
        sz = cumsum(map(length, x_vecs))  # end index of each element's chunk
        return [backs[n](x_vec[sz[n] - length(x_vecs[n]) + 1:sz[n]]) for n in eachindex(x)]
    end
    # handle empty x
    x_vec = isempty(x_vecs) ? eltype(eltype(x_vecs))[] : reduce(vcat, x_vecs)
    return x_vec, Vector_Tangent_from_vec
end

function to_vec_tangent(x::Array)
    x_vec, Tangent_from_vec = to_vec_tangent(vec(x))
    function Array_Tangent_from_vec(x_vec)
        return collect(reshape(Tangent_from_vec(x_vec), size(x)))
    end
    return x_vec, Array_Tangent_from_vec
end

# A Char has no differentiable content: its tangent is NoTangent(), not the Char itself.
to_vec_tangent(x::Char) = (Bool[], _ -> NoTangent())

# Any struct ought to be interpretable as a Tangent, regardless of inner constructors etc.
function to_vec_tangent(x::T) where {T}
    Base.isstructtype(T) || throw(ArgumentError("Expected a struct type, got $T"))
    isempty(fieldnames(T)) && return (Bool[], _ -> NoTangent())  # singleton types
    val_vecs_and_backs = map(name -> to_vec_tangent(getfield(x, name)), fieldnames(T))
    vals = first.(val_vecs_and_backs)
    backs = last.(val_vecs_and_backs)
    # `collect` so that we hit the Vector method rather than recursing on a Tuple.
    v, Tangents_from_vec = to_vec_tangent(collect(vals))
    function structtype_Tangent_from_vec(v::Vector{<:Real})
        val_vecs = Tangents_from_vec(v)
        tangents = [b(w) for (b, w) in zip(backs, val_vecs)]
        backing = NamedTuple{fieldnames(T)}(Tuple(tangents))
        return Tangent{T}(; backing...)
    end
    return v, structtype_Tangent_from_vec
end
```

Seems like the generic method will be slightly simpler than …
I also would like to get rid of … It may be that we need to incorporate a new thing that helps teach it how to perturb an object through all its basis points, where each perturbation returns a ChainRules type. From there I would like to be using … I suspect the easy fix for issues with …
Also my thoughts. The only problem you get is the standard "can perturb out of the set of types that can be represented" problem.
Agreed -- this would be nice. However, my expectation is that it would be a non-trivial amount of work to make that work properly, while implementing …
Good point. I've edited my first comment to point out that I forgot about these kinds of problems. I agree they're important, but I think they're basically orthogonal to what I'm discussing here. Please correct me if I'm wrong.
Fun story: I managed to get the default … On balance, it's probably not worth worrying about undefined / unassigned entries unless we find more examples where they matter.
One thing to point out about Example 1 is that the "reverse" problem is also there: a primal output that is a dense matrix will error with the same dimension mismatch if it receives a …
In an ideal world, a user would test an `rrule` by writing a single, short test call, and have it work all of the time. By "work all of the time", I mean that the tests that we want to run to determine the correctness of an `rrule` implementation always run successfully, provided that the function is something that we know how to test (broadly speaking, the output is deterministic given the input), and for any input type that is either a type we already handle (`Real`, `Array`, etc), or a `struct` / `mutable struct`.

It's important that this works automatically because we want people to be testing their code using CRTU, and people like to define new types (including new `AbstractArray`s) and new functions. Unfortunately, I don't believe it's possible to automate this in all cases, but the way in which it fails (AFAICT) is very specific, and I think we can document it and make it easy for users to resolve.

Roughly speaking, the functionality that always needs to work in order to achieve this is `to_vec`, `to_vec_tangent`, `rand_tangent` and `test_approx`. I believe that `to_vec`, `to_vec_tangent` and `rand_tangent` can be made to "always work", but `test_approx` occasionally has a quirk that I don't believe we can automate.

The outcome is the following proposals:
1. remove the specialised `to_vec` implementations in favour of the generic `to_vec` implementation for `isstructtype` types, keeping only the necessary `to_vec` implementations for `isprimitivetype` types,
2. add a `to_vec_tangent` (better name welcome) function, which is like `to_vec`, but whose returned closure produces a tangent rather than a primal,
3. add `remove_junk_data`, or something similar, which applies to primals and returns another object containing only the bits of the primal relevant for defining `isapprox`; whenever we test rules, we test the composition of `remove_junk_data` and the function being tested, rather than just the function. This enables us to define `test_approx` in a really generic manner.

I'll explain throughout this issue why I believe these are sensible proposals, and how they resolve things.
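A rough sketch of the first proposal, using only the `LinearAlgebra` stdlib. `generic_to_vec` and `Point` are hypothetical stand-ins, not `FiniteDifferences.to_vec`: there are specific methods for floats and float arrays, plus a single generic method for every other `isstructtype` type that flattens its fields in order. Reconstruction here assumes the type's constructor accepts its fields in order, which holds for `Diagonal` and for plain user structs. Applied to a `Diagonal`, only the `diag` field is flattened, so a flat vector built from the dense matrix has a different length; the closure checks the length so that such a mismatch raises a `DimensionMismatch` rather than silently misreading the data.

```julia
using LinearAlgebra

# Hypothetical stand-in for a generic `to_vec`: returns (flat vector, from_vec closure).
generic_to_vec(x::AbstractFloat) = ([Float64(x)], v -> only(v))
generic_to_vec(x::Array{<:AbstractFloat}) = (Vector{Float64}(vec(x)), v -> reshape(collect(v), size(x)))

# One generic method for every other struct type: flatten the fields in order.
function generic_to_vec(x::T) where {T}
    Base.isstructtype(T) || throw(ArgumentError("expected a struct type, got $T"))
    parts = [generic_to_vec(getfield(x, n)) for n in fieldnames(T)]
    lens = [length(first(p)) for p in parts]
    flat = reduce(vcat, map(first, parts); init=Float64[])
    function from_vec(v)
        length(v) == length(flat) ||
            throw(DimensionMismatch("expected $(length(flat)) entries, got $(length(v))"))
        ends = cumsum(lens)
        fields = [last(p)(v[e - l + 1:e]) for (p, l, e) in zip(parts, lens, ends)]
        return T(fields...)  # assumes T's constructor takes its fields in order
    end
    return flat, from_vec
end

# A user-defined struct gets the same treatment with no extra methods.
struct Point
    x::Float64
    ys::Vector{Float64}
end

D = Diagonal([1.0, 2.0, 3.0])
v, from_vec = generic_to_vec(D)
@show length(v)               # only the `diag` field is flattened
@show length(vec(Matrix(D)))  # the dense representation is larger
@show from_vec(v) == D

w, point_from_vec = generic_to_vec(Point(1.0, [2.0, 3.0]))
@show w
```

Calling `from_vec(vec(Matrix(D)))` throws a `DimensionMismatch`, which is the same flavour of failure as Example 1 below.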
Additionally, while this proposal is independent from other proposed changes, it clearly favours a structural view of the world because I'm interested in automating things. See JuliaDiff/ChainRulesCore.jl#449 for a proposal for how we can do this without sacrificing usability, and how this leads to a precise definition for natural tangents.
I would be really interested to know if anyone thinks I've obviously missed something, or whether this sounds about right.
edit: I completely neglected constraint-related problems (e.g. if the tangent provided to `FiniteDifferences` needs to represent a positive definite matrix for some reason). AFAICT the things discussed here are essentially orthogonal to the constraint problems, though.

edit2: undefined references are not fun. For example, perfectly well-defined `Dict` objects can contain undefined memory. I think this probably comes under the heading of "junk" data, but it seems to cause problems for `to_vec` as it's currently defined. I wonder whether it could be generalised?

Example 1: `Diagonal` size mismatch

Consider testing
Let the output (co)tangent be
then
produces the error:
Why is this example a problem?
Firstly, we presently require that rules accept either a natural or a structural tangent. Due to the above, it's not currently possible to test functions which output a `Diagonal` with a `Tangent` tangent.

Secondly, there exist `Diagonal` matrices whose tangent cannot be represented by a `Diagonal`. Specifically, this is true of any `Diagonal` whose `diag` field doesn't provide a way to produce an `AbstractVector` as its tangent (i.e. for whatever reason, it lacks a natural tangent). Consequently, in order for our testing facilities to handle any type, they must be able to work with structural tangents.

Finally, our current implementations special-case `to_vec` for lots of different arrays (`Diagonal`, `Symmetric`, etc). This is a problem in itself, but moreover we're never entirely sure what the right thing to do is when we encounter a new array.

How to fix this problem
Remove the specialised `to_vec` methods for `Diagonal` and other `struct` `AbstractArray`s (`UpperTriangular`, `Symmetric`, etc), and instead just rely on the generic `to_vec` operation for `struct`s. Doing this immediately means that we can `to_vec` anything that is either a type we've defined `to_vec` on, or a `struct` or `mutable struct`.

This solution brings into focus a problem that we're currently solving on an ad-hoc basis in `to_vec`: "junk" data in e.g. the lower triangle of a `Symmetric` can wind up being used in approximate equality checks (and could in principle introduce non-determinism in an otherwise deterministic function, although I've yet to find an example of this in the wild), which makes no sense. We'll address this later.

Example 2: `to_vec` gives the wrong type sometimes

`to_vec` only knows about primals -- it knows nothing about tangents. This is because it was written when we knew nowhere near enough about tangents, in particular for arrays. The particular problem is in the call to `vec_to_x` on this line in `j′vp`. It attempts to convert a "flat" vector representation of a cotangent into a primal. While this works fine in some cases (a surprisingly large number, given how much mileage we've gotten out of `to_vec` over the years), we know that it doesn't work for all types.

Once you've removed the `to_vec` implementations for the various concrete subtypes of `AbstractArray`, you'll find that

yields
Why is this a problem?
Having the pullback for `identity` return anything other than whatever cotangent it is provided seems highly undesirable to me, so I'm going to assume that our `rrule` for `identity` does just that. If that is the case, then the cotangent returned by the pullback produced by that `rrule` will be a `Tangent` if the input is a `Tangent`, not a `Diagonal`. This means that our current implementation is incorrect. While this particular example seems reasonably benign, to my mind it's not correct. However, even if you believe it's correct, it's clearly only correct because `Diagonal{Float64, Vector{Float64}}`s happen to have nice natural tangents that happen to be produced by `to_vec`, rather than by design.

A more obviously incorrect / plainly uninterpretable example is a `Symmetric`: the `from_vec` output from `to_vec(::Symmetric)` will produce a `Symmetric` with an `uplo` field that is a `Char`. Since a `Char` isn't an appropriate tangent for a `Char` (it should be a `NoTangent`), this is plainly nonsensical if the goal is to obtain a tangent. If you wound up comparing this representation of the tangent with a `Tangent` output from AD, you would need to compare a `NoTangent` with this `Char`, which under any sensible definition would fail (I can't imagine a world in which I would wish to reside in which `NoTangent` is considered equal to a `Char`).

How to fix this problem
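The gist of the fix can be sketched without ChainRulesCore by keeping two reconstruction routes apart: one that rebuilds a primal, as `to_vec`'s closure does today, and one that builds a tangent. This is a minimal sketch under stated assumptions: `flatten`, `split_fields`, `primal_from_vec` and `tangent_from_vec` are hypothetical names; a `NamedTuple` of field tangents stands in for a structural `Tangent`; `nothing` stands in for `NoTangent()`; and, for brevity, the functions take the original object as an argument rather than returning a closure.

```julia
using LinearAlgebra

# Stand-ins (hypothetical, not ChainRulesCore): a NamedTuple of field tangents plays
# the role of a structural Tangent, and `nothing` plays the role of NoTangent().

# Flatten the floating-point content of `x` into a Vector{Float64}.
flatten(x::AbstractFloat) = [Float64(x)]
flatten(x::Array{<:AbstractFloat}) = Vector{Float64}(vec(x))
flatten(::Char) = Float64[]  # a Char carries no differentiable data
flatten(x::T) where {T} =
    reduce(vcat, (flatten(getfield(x, n)) for n in fieldnames(T)); init=Float64[])

# Split `v` into consecutive chunks, one per field of `x`.
function split_fields(x::T, v) where {T}
    lens = [length(flatten(getfield(x, n))) for n in fieldnames(T)]
    ends = cumsum(lens)
    return [v[e - l + 1:e] for (l, e) in zip(lens, ends)]
end

# What `to_vec`'s closure does today: rebuild a *primal* of the same type as `x`.
primal_from_vec(x::AbstractFloat, v) = only(v)
primal_from_vec(x::Array{<:AbstractFloat}, v) = reshape(collect(v), size(x))
primal_from_vec(x::Char, v) = x  # the Char is copied over, as if it were a tangent
function primal_from_vec(x::T, v) where {T}
    chunks = split_fields(x, v)
    fields = [primal_from_vec(getfield(x, n), c) for (n, c) in zip(fieldnames(T), chunks)]
    return T(fields...)  # assumes T's constructor accepts its fields in order
end

# The proposed tangent-returning variant: primitives map to primitives,
# composites to the structural stand-in, non-differentiable fields to `nothing`.
tangent_from_vec(x::AbstractFloat, v) = only(v)
tangent_from_vec(x::Array{<:AbstractFloat}, v) = reshape(collect(v), size(x))
tangent_from_vec(x::Char, v) = nothing
function tangent_from_vec(x::T, v) where {T}
    chunks = split_fields(x, v)
    vals = Tuple(tangent_from_vec(getfield(x, n), c) for (n, c) in zip(fieldnames(T), chunks))
    return NamedTuple{fieldnames(T)}(vals)
end

S = Symmetric([1.0 2.0; 3.0 4.0])
v = flatten(S)              # [1.0, 3.0, 2.0, 4.0]: all of `data`, nothing from `uplo`
p = primal_from_vec(S, v)   # a Symmetric whose `uplo` field is the Char 'U'
t = tangent_from_vec(S, v)  # (data = [1.0 2.0; 3.0 4.0], uplo = nothing)
```

Flattening is shared and only the reconstruction differs, which is why `to_vec_tangent` can mirror `to_vec` almost entirely.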
Introduce another function `to_vec_tangent` (a better name would be nice) which returns a closure that always returns an appropriate tangent representation (primitive for primitives, structural for composites). This would require roughly the same level of implementation effort as `to_vec`, and would mirror its structure almost entirely (specific methods for primitives, a generic method for all `isstructtype` types).

Example 3: Propagation of Junk Data
Consider
yields
Note that the `data` field is the relevant bit of the output from `j′vp` here, because the consistent / correct interpretation of the thing that `FiniteDifferences` outputs is a `Tangent`, not a `Symmetric`, as discussed in the previous example. Observe that the lower triangle (element `(2, 1)`) will be used when `test_approx` is computed, because the generic definition of `test_approx` doesn't know about the specific semantics of `Symmetric`. Since the standard library makes no promises about the lower triangle of a `Symmetric`, it seems intuitive to me that we shouldn't have to worry about it in our gradient definitions. I'm happy to expand on this, but there's a good example here.

The solution I believe is best is to ensure that the gradient w.r.t. irrelevant elements is always `0` by always testing the composition of `remove_junk_data` and `f`, rather than just `f`. The function `remove_junk_data` would be defined such that it doesn't propagate any junk data (data which isn't relevant for equality computations). The implementations that I have so far are things like:

Another option one could consider is trying to define equality properly on `Tangent`s. This isn't general though, because e.g. the `data` field of a `Tangent{Symmetric}` might itself be a `Tangent`, which doesn't have a conception of its own lower triangle. The benefit of composing with `remove_junk_data` is that we get to operate on primal types, whose semantics everyone is familiar with (the `data` field of a `Symmetric` definitely does know about triangles, because it's an `AbstractArray` and has `getindex` defined).

So we can instruct type authors (ourselves, for stdlib types) that if their types have any data that's essentially "junk", they must define a method of `remove_junk_data`, and accept that we'll have to expend some extra computation internally to differentiate `remove_junk_data` when testing (this can probably be optimised away in most cases, since it's the identity function in most cases).

Note that the generic fallback for composites means that we'll get overly restrictive tests by default, and type authors have to opt in to say that some bits of their type aren't important. This seems like the desirable way around to me -- I'd rather have tests yelling at me when they ought not to be, than have them fail to yell at me when they should.
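A minimal sketch of the `remove_junk_data` idea for `Symmetric`, assuming its job is simply to return a plain object whose every entry matters for `isapprox`. Both this `remove_junk_data` and the function `f` under test are hypothetical, and materialising to a dense `Matrix` is just one possible choice for `Symmetric`. The last lines check numerically that, through `remove_junk_data ∘ f`, the junk `(2, 1)` entry of the input has zero derivative.

```julia
using LinearAlgebra

# Hypothetical `remove_junk_data`: keep only what matters for `isapprox`.
# Generic fallback: everything matters (overly strict by default).
remove_junk_data(x) = x

# Opt-in for Symmetric: materialise it, so whatever sits in the unused triangle of
# `parent(x)` is dropped (Matrix(::Symmetric) mirrors the triangle selected by `uplo`).
remove_junk_data(x::Symmetric) = Matrix(x)

f(A) = Symmetric(2 .* A)   # hypothetical function under test
g = remove_junk_data ∘ f   # what we would test instead of `f`

A = [1.0 2.0; 3.0 4.0]
B = [1.0 2.0; -99.0 4.0]   # same upper triangle, different junk below the diagonal

@show isapprox(parent(f(A)), parent(f(B)))  # raw storage includes the junk
@show isapprox(g(A), g(B))                  # only the Symmetric's entries are compared

# Central difference of g w.r.t. the junk entry A[2, 1].
E21 = [0.0 0.0; 1.0 0.0]
h = 1e-6
fd = (g(A + h * E21) - g(A - h * E21)) / (2h)
@show iszero(fd)
```

Materialising with `Matrix` is the bluntest option; a type author could return something cheaper, as long as it discards exactly the data that `isapprox` shouldn't see.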
Outcomes
Assuming that this pans out, this is a win-win for developers and users.
Developers get simpler, more robust, more straightforward to understand code with fewer edge cases -- the edge cases that remain have clear semantics and it's clear why they're necessary.
Users benefit from more predictable and reliable infrastructure.
The issue with this proposal is that it requires structural tangents to actually be taken seriously by everyone. Again, see JuliaDiff/ChainRulesCore.jl#449 for a discussion of how to make this more straightforward for all involved.