Cannot test rule for structure where one (non-differentiable) field cannot be vectorized #256
Comments
`to_vec` is applied recursively to the entire structure, even when the struct is only partially differentiable
Did you try |
I just did, and unfortunately that fails too. Looking at the code,
it seems like the only influence of the tangent on the finite difference call is whether to ignore the full input or not? (This code is called from `test_frule`, in which I also get an error.)
I agree this is the ideal behaviour. CRTU is not perfect: if the tests pass, you can be almost certain that your rule is correct, but a failing test could mean either an issue with the rule or an issue with CRTU. A slightly less ideal but quicker solution would be to try defining a |
Any pointers on doing the necessary type piracy? For |
Sure, I would try with something like

    function FiniteDifferences.to_vec(x::FFTPlan) # or whatever the type of field x is
        function FFTPlan_from_vec(x_vec::Vector)
            return x
        end
        return Bool[], FFTPlan_from_vec
    end

This should work for both the |
For posterity, here's what I needed to do to get

    function FiniteDifferences.to_vec(x::InnerPlan)
        function FFTPlan_from_vec(x_vec::Vector)
            return x
        end
        return Bool[], FFTPlan_from_vec
    end
    ChainRulesTestUtils.test_approx(::ChainRulesCore.AbstractZero, x::InnerPlan, msg=""; kwargs...) = true
    ChainRulesTestUtils.rand_tangent(::AbstractRNG, x::InnerPlan) = ChainRulesCore.NoTangent()

Note that I needed the |
That is correct for that. And in the long term we want to stop using it entirely.
I am writing an `rrule` for `*(::Struct, arr)` where the structure and function look like,

However, I am unable to test my `rrule`, even with a manually provided tangent for an instance of `Struct` that looks like `Tangent{Struct}(;x=NoTangent(), y=one(T))`.

The reason seems to be that finite differences tries to `to_vec` the instance of the struct. Given that the struct is not completely ignored, only the field `x`, it ends up trying to `to_vec` the field `x` as well. But this field is a reference to a rather crazy mutable structure with circular references, and so I end up with an error, and am unable to test that the `rrule` is correct w.r.t. `y`.

(To make it more concrete, the structure in question is a `ScaledPlan` from `AbstractFFTs` and the field `x` refers to a primitive FFT plan from `FFTW`, which is mutable because of its `pinv` cache.)

Ideally, it shouldn't matter what value is in the field `x`, since it is marked as `NoTangent` in the user-provided tangent? Just as how the entire input gets ignored if it is marked as `NoTangent`.
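The failure mode and the `to_vec` workaround discussed in this thread can be illustrated with a small, package-free toy. This is only a sketch under stated assumptions: `toy_to_vec`, `Opaque`, `Scaled`, and the `handle` field are hypothetical names invented for illustration, and the code does not reproduce FiniteDifferences' internals. It only mimics the "recurse over every field" behaviour described above, and shows why a method that returns an empty vector plus a closure returning the original object sidesteps the problem.

```julia
# Toy model of recursive vectorization. All names here are hypothetical;
# this is NOT FiniteDifferences' actual implementation.

# Leaf case: a real number becomes a one-element vector.
toy_to_vec(x::Real) = (Float64[x], v -> v[1])

# Generic case: vectorize every field, concatenate, remember the lengths.
function toy_to_vec(x::T) where {T}
    Base.isstructtype(T) || error("cannot vectorize a value of type $T")
    fnames = fieldnames(T)
    parts = map(n -> toy_to_vec(getfield(x, n)), fnames)
    vecs = map(first, parts)
    backs = map(last, parts)
    lens = map(length, vecs)
    function from_vec(v::Vector)
        stops = cumsum(collect(lens))
        fields = map(1:length(fnames)) do i
            backs[i](v[(stops[i] - lens[i] + 1):stops[i]])
        end
        return T(fields...)
    end
    return vcat(Float64[], vecs...), from_vec
end

# Stand-in for a non-differentiable field that cannot be vectorized.
mutable struct Opaque
    handle::Ptr{Cvoid}
end

# Stand-in for a partially differentiable struct: only `y` matters.
struct Scaled
    x::Opaque
    y::Float64
end

s = Scaled(Opaque(C_NULL), 2.0)

# The recursion reaches s.x.handle and errors, whatever tangent x was given.
failed = try
    toy_to_vec(s)
    false
catch
    true
end

# Workaround in the spirit of the thread: the opaque field contributes no
# numbers, and reconstruction simply hands back the original object.
toy_to_vec(x::Opaque) = (Float64[], _ -> x)

v, back = toy_to_vec(s)
s2 = back(v .+ 1.0)  # perturb the vector, then rebuild the struct
```

With the extra method in place, only `y` ends up in the vector that would be perturbed, while `x` is carried through reconstruction untouched.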