Add rrules for binary linear algebra operations #29
This test failure is particularly interesting:
The result computed by the |
oof, test failure looks real? Either FDM or Nabla have an incorrect definition? |
Yeah test failures are definitely real. The problem is I don't know which part is wrong. 😬 |
@willtebbutt, would you be able to provide any guidance here? |
So Nabla's definitions do indeed seem to be correct, even though we aren't properly testing them in Nabla itself. After speaking a bit with Will, our current thinking is that ChainRules' test framework integration with FDM is doing something weird with transpose wrappers. |
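After discussing with Will, Wessel, and Lyndon:
"Nabla's definitions do indeed seem to be correct": They are not.
"we aren't properly testing them in Nabla itself": This is in fact still the case, even after I added tests I thought exercised those methods.
"ChainRules' test framework integration with FDM is doing something weird with transpose wrappers": This is no longer our current thinking. The sensitivities are wrong and should be removed from both Nabla and this PR. |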
LGTM now that everything passes.
Adding tests for rectangular matrices revealed an issue with the sensitivity definitions themselves. See invenia/Nabla.jl#175. I'll update this PR once a solution has been found there. |
Please see my PR to Zygote relating to this. Should be a straightforward copy-paste to this PR. |
Force-pushed from 98c9999 to ffce382.
Things aren't so straightforward, unfortunately: ChainRules can't deal with the omission of a rule for |
My original thinking on this is that ChainRules shouldn't really have to worry about this -- it can just assume that whichever AD package uses it can successfully backprop through stuff. I suppose that this isn't the case for Nabla though, so we should probably implement this by hand. What have you tried so far for the hand implementation? |
Sorry, accidental close of PR there... |
Composing rules for |
So the programme to compute this is
    # uses the identity A / B == (B' \ A')'
    A_ = adjoint(A)
    B_ = adjoint(B)
    C_ = B_ \ A_
    C = adjoint(C_)
so the adjoint programme should be something like
    function rrule(::typeof(/), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real})
        A_, A_back = rrule(adjoint, A)
        B_, B_back = rrule(adjoint, B)
        C_, C_back = rrule(\, B_, A_)
        C, Cback = rrule(adjoint, C_)
        return C, function(Cadj)
            # call the reverse-rules in reverse order
        end
    end
You could implement this without using the |
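For reference, here is a sketch of where the composed rule should end up, worked by hand. It assumes `B` is square and invertible, so it says nothing about the rectangular case mentioned earlier in this thread. The bar notation for sensitivities is an assumption of this note, not something taken from the PR.

```latex
\begin{aligned}
C &= A B^{-1}, \qquad C^\top = B^{-\top} A^\top
   \quad (\text{i.e. } C^\top = B^\top \backslash A^\top), \\
dC &= dA\, B^{-1} - C\, dB\, B^{-1}, \\
\langle \bar{C}, dC \rangle
   &= \langle \bar{C} B^{-\top}, dA \rangle
    - \langle C^\top \bar{C} B^{-\top}, dB \rangle, \\
\bar{A} &= \bar{C} B^{-\top}, \\
\bar{B} &= -C^\top \bar{C} B^{-\top} = -C^\top \bar{A}.
\end{aligned}
```

Whatever the composed `rrule` returns for `A` and `B` should agree with the last two lines when `B` is invertible.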
What I came up with is
    function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real})
        Aᵀ, dA = rrule(adjoint, A)
        Bᵀ, dB = rrule(adjoint, B)
        Cᵀ, (dBᵀ, dAᵀ) = rrule(\, Bᵀ, Aᵀ)
        C, dC = rrule(adjoint, Cᵀ)
        ∂A = Rule(dC∘dAᵀ∘dA)
        ∂B = Rule(dC∘dBᵀ∘dB)
        return C, (∂A, ∂B)
    end
which works for matrices but not for vectors... |
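A rough way to sanity-check this kind of composition outside the ChainRules machinery is to write the pullbacks as plain closures and compare them against central finite differences. The sketch below does that. The names `backslash_with_pullback`, `slash_with_pullback`, and `fd_gradient` are hypothetical helpers invented for this illustration (they are not ChainRules, Nabla, or FDM API), and it assumes square, invertible, real matrices only.

```julia
using LinearAlgebra

# Hypothetical helper: C = A \ B plus a closure mapping Cbar to (Abar, Bbar).
# Assumes A is square and invertible.
function backslash_with_pullback(A, B)
    C = A \ B
    function pullback(Cbar)
        Bbar = A' \ Cbar      # sensitivity w.r.t. B
        Abar = -Bbar * C'     # sensitivity w.r.t. A
        return Abar, Bbar
    end
    return C, pullback
end

# Hypothetical helper: A / B composed as (B' \ A')'.
# Adjoints are materialised to plain matrices to sidestep wrapper-type questions.
function slash_with_pullback(A, B)
    Ct, back = backslash_with_pullback(Matrix(B'), Matrix(A'))
    C = Matrix(Ct')
    function pullback(Cbar)
        # Reverse order: adjoint first, then the `\` pullback, then adjoint again.
        Btbar, Atbar = back(Matrix(Cbar'))
        return Matrix(Atbar'), Matrix(Btbar')
    end
    return C, pullback
end

# Central finite-difference gradient of a scalar function f at X.
function fd_gradient(f, X; h=1e-6)
    G = zeros(size(X))
    for i in eachindex(X)
        Xp = copy(X); Xp[i] += h
        Xm = copy(X); Xm[i] -= h
        G[i] = (f(Xp) - f(Xm)) / (2h)
    end
    return G
end

A = [1.0 2.0; 3.0 4.0]
B = [4.0 1.0; 2.0 3.0]
W = [0.5 -1.0; 2.0 1.5]   # arbitrary seed for the reverse pass

C, back = slash_with_pullback(A, B)
Abar, Bbar = back(W)
```

Comparing `Abar` with `fd_gradient(X -> sum(W .* (X / B)), A)` (and likewise for `Bbar`) gives a quick check that the reverse rules were chained in the right order.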
Any idea what to do for vectors, @willtebbutt? |
What sort of error are you getting for them? |
Never mind, I permuted |
Okay, I'm almost there. The outstanding issue is that |
FDM is doing the right thing -- we definitely want the adjoint to be represented as a |
Force-pushed from ffce382 to 5202be2.
Alright, things should be in order now. |
This looks great. Would you mind adding a few tests to ensure that the sensitivity of a |
I don't understand why it would be |
Right, yeah, so this is something I've thought about a fair bit. There are two situations:
1.
    function(a)
        D = Diagonal(a)
        return foo(D)
    end
   In this case, it's totally fine to drop the off-diagonal bits of
2.
    D::Diagonal -> foo(D)
   In this case, in some sense we need to think more about how the user is planning to use
As a consequence of the above, I've always felt justified employing this type of optimisation. Frankly, if you don't, there's essentially no point at all in using structured matrices in code whose adjoint you plan to compute, because you lose the structure on the reverse pass and revert to whatever complexity is associated with the dense version of your programme. |
That makes sense I suppose, though
makes me nervous... Nabla also does not appear to do things that way. |
Which way is that sorry? |
Sorry, wasn't clear. For example, in the case of multiplication with diagonal matrices, Nabla produces a dense matrix as the sensitivity for the diagonal argument. |
Ah okay. I would view that as a bug. FWIW, the other thing to think about is what is actually happening computationally under the hood. Ultimately the It's this slightly weird situation in which thinking about a |
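To make the diagonal-multiplication case concrete, here is a minimal sketch of a pullback for `D * X` that keeps the sensitivity of `D` as a `Diagonal` instead of forming the dense matrix. The function name `diag_mul_with_pullback` is hypothetical and the closure-based interface is an assumption for illustration, not how ChainRules or Nabla actually define this rule.

```julia
using LinearAlgebra

# Hypothetical helper: Y = D * X for a real Diagonal D, with a structured pullback.
function diag_mul_with_pullback(D::Diagonal, X::AbstractMatrix)
    Y = D * X
    function pullback(Ybar)
        # The dense sensitivity would be Ybar * X'. Only its diagonal is kept:
        # entry i is sum_j Ybar[i, j] * X[i, j], computed without the dense product.
        Dbar = Diagonal(vec(sum(Ybar .* X; dims=2)))
        Xbar = D' * Ybar
        return Dbar, Xbar
    end
    return Y, pullback
end

a = [1.0, 2.0, 3.0]
X = [1.0 2.0; 3.0 4.0; 5.0 6.0]
Ybar = [1.0 0.0; 0.5 2.0; -1.0 1.0]   # arbitrary seed

Y, back = diag_mul_with_pullback(Diagonal(a), X)
Dbar, Xbar = back(Ybar)
```

The diagonal of `Dbar` matches the diagonal of the dense `Ybar * X'`, so a gradient with respect to `a` is unchanged. The difference is that the reverse pass never allocates the dense `n × n` sensitivity.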
I see, that explanation makes sense to me. I think that will require a larger effort to audit everything in here (and Nabla) to ensure that's what it's doing, as many things are typed as |
Yeah, I completely agree. Fortunately this kind of bug doesn't make things wrong, it just makes them slow, so it's not the end of the world most of the time. |
Since nothing is actually using ChainRules right now, can we go ahead and merge this to master, then on the plane tomorrow I'll look into a larger refactor to ensure consistent matrix types? Easier if I just do everything together on one branch. |
Please feel free to merge :) |
Thanks for the review and explanations! |
These are ported from Nabla.
WIP because:
- AbstractArray and Transpose/Adjoint cases don't pass
- Transpose/Adjoint in FDM requires the currently untagged 0.5.0 (see New version FDM: 0.5.0 JuliaRegistries/General#487)