
Add rrules for binary linear algebra operations #29

Merged · 1 commit · Jun 13, 2019

Conversation

@ararslan (Member) commented May 1, 2019

These are ported from Nabla.

WIP because:

@ararslan (Member, Author) commented May 2, 2019

This test failure is particularly interesting:

*(::Transpose, ::AbstractArray): Test Failed at /home/travis/build/JuliaDiff/ChainRules.jl/test/test_util.jl:63
  Expression: isapprox(Δx_ad, Δx_fd; rtol=rtol, atol=atol, kwargs...)
   Evaluated: isapprox([-1.17704 6.70007 … 2.30549 -4.44818; -0.861323 2.1737 … -0.446239 -0.175601; … ; 0.838357 -0.806677 … 0.963402 -0.61938; 0.818415 0.918045 … 1.84326 -0.308386], [-1.17704 -0.861323 … 0.838357 0.818415; 6.70007 2.1737 … -0.806677 0.918045; … ; 2.30549 -0.446239 … 0.963402 1.84326; -4.44818 -0.175601 … -0.61938 -0.308386]; rtol=1.0e-9, atol=1.0e-9)

The results computed by the rrule and by finite differencing are transposes of each other. It's not clear to me whether that actually signifies an issue with the transpose support I added in FDM, since the definitions I've added in this PR are exactly the ones we use in Nabla.

@nickrobinson251 (Contributor)

oof, the test failure looks real? Either FDM or Nabla has an incorrect definition?

@ararslan (Member, Author) commented May 3, 2019

Yeah, the test failures are definitely real. The problem is I don't know which part is wrong. 😬

@ararslan (Member, Author) commented May 3, 2019

@willtebbutt, would you be able to provide any guidance here?

@ararslan (Member, Author) commented May 7, 2019

So Nabla's definitions do indeed seem to be correct, even though we aren't properly testing them in Nabla itself. After speaking a bit with Will, our current thinking is that ChainRules' test framework integration with FDM is doing something weird with transpose wrappers.

@ararslan (Member, Author)

After discussing with Will, Wessel, and Lyndon:

> So Nabla's definitions do indeed seem to be correct

They are not.

> even though we aren't properly testing them in Nabla itself

This is in fact still the case, even after I added tests I thought exercised those methods.

> our current thinking is that ChainRules' test framework integration with FDM is doing something weird with transpose wrappers.

This is no longer our current thinking. The sensitivities are wrong and should be removed from both Nabla and this PR.

@willtebbutt (Member) left a review
LGTM now that everything passes.

Review thread on src/rules/linalg/dense.jl (outdated, resolved)
@ararslan (Member, Author) commented Jun 7, 2019

Adding tests for rectangular matrices revealed an issue with the sensitivity definitions themselves. See invenia/Nabla.jl#175. I'll update this PR once a solution has been found there.

@willtebbutt (Member)

Please see my PR to Zygote relating to this. Should be a straightforward copy-paste to this PR.

@ararslan (Member, Author)

Things aren't so straightforward, unfortunately: ChainRules can't deal with the omission of a rule for / even though it's implemented internally in LinearAlgebra in terms of \. My own attempts at devising a correct implementation of that in terms of transposes (using a / b = (b' \ a')') have failed, so any further guidance on that would be great.
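As a sanity check (mine, not from the thread), the identity a / b = (b' \ a')' mentioned above can be verified numerically for real matrices:

```julia
using LinearAlgebra, Random

Random.seed!(1)  # reproducibility

# For real matrices, right division A / B solves X * B = A, which is
# equivalent to solving B' \ A' and transposing the result.
A = randn(3, 4)
B = randn(4, 4)

@assert isapprox(A / B, (B' \ A')'; rtol=1e-6)
```

So the identity itself is fine; the difficulty discussed here is purely about composing the reverse rules correctly.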

@willtebbutt (Member)

> ChainRules can't deal with the omission of a rule for / even though it's implemented internally in LinearAlgebra in terms of \.

My original thinking on this is that ChainRules shouldn't really have to worry about this -- it can just assume that whichever AD package uses it can successfully backprop through stuff. I suppose that this isn't the case for Nabla though, so we should probably implement this by hand.

What have you tried so far for the hand implementation?

@willtebbutt willtebbutt reopened this Jun 11, 2019
@willtebbutt (Member)

Sorry, accidental close of PR there...

@ararslan (Member, Author)

> What have you tried so far for the hand implementation?

Composing rules for / and ', simply transposing the result, multiplying the incoming sensitivity by the transpose of the rules for B' \ A'... Throwing math at the wall and seeing what sticks, and so far nothing has. 😬

@willtebbutt (Member) commented Jun 11, 2019

So the programme to compute A / B in terms of \ is something like

A_ = adjoint(A)
B_ = adjoint(B)
C_ = B_ \ A_
C = adjoint(C_)

so the adjoint programme should be something like

function rrule(::typeof(/), A::AbstractMatrix{<:Real}, B::AbstractMatrix{<:Real})
    A_, A_back = rrule(adjoint, A)
    B_, B_back = rrule(adjoint, B)
    C_, C_back = rrule(\, B_, A_)
    C, Cback = rrule(adjoint, C_)
    return C, function(Cadj)
        # call the reverse rules in reverse order
    end
end

You could implement this without using the rrules for adjoint, but it might be safer to use them.

@ararslan (Member, Author)

What I came up with is

function rrule(::typeof(/), A::AbstractVecOrMat{<:Real}, B::AbstractVecOrMat{<:Real})
    Aᵀ, dA = rrule(adjoint, A)
    Bᵀ, dB = rrule(adjoint, B)
    Cᵀ, (dBᵀ, dAᵀ) = rrule(\, Bᵀ, Aᵀ)
    C, dC = rrule(adjoint, Cᵀ)
    ∂A = Rule(dC∘dAᵀ∘dA)
    ∂B = Rule(dC∘dBᵀ∘dB)
    return C, (∂A, ∂B)
end

which works for matrices but not for vectors...
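For reference, the pullback for / can also be written directly from the matrix-calculus identities for C = A * inv(B), namely Ā = C̄ / B' and B̄ = -C' * Ā. A hedged sketch (divide_pullback is a hypothetical name, not the API that landed in this PR), spot-checked against a finite difference:

```julia
using LinearAlgebra, Random

Random.seed!(1)

# Hypothetical helper (not the PR's code): value plus pullback for C = A / B,
# using the standard identities Ā = C̄ / B' and B̄ = -C' * Ā.
function divide_pullback(A::AbstractMatrix, B::AbstractMatrix)
    C = A / B
    back(C̄) = (C̄ / B', -C' * (C̄ / B'))
    return C, back
end

# Finite-difference spot check on f(A, B) = sum(A / B).
A, B = randn(2, 3), randn(3, 3)
C, back = divide_pullback(A, B)
Ā, B̄ = back(ones(size(C)))  # C̄ for f = sum is all ones

ϵ = 1e-6
E11 = zeros(size(A)); E11[1, 1] = 1.0  # perturb only A[1, 1]
fd = (sum((A + ϵ * E11) / B) - sum(A / B)) / ϵ
@assert isapprox(Ā[1, 1], fd; rtol=1e-3, atol=1e-6)
```

Since f is linear in A, the finite-difference estimate for Ā is accurate up to floating-point error, which makes this a convenient check.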

@ararslan (Member, Author)

Any idea what to do for vectors, @willtebbutt?

@willtebbutt (Member)

What sort of error are you getting for them?

@ararslan (Member, Author)

Never mind, I permuted dA/dB and dC in the above function compositions.

@ararslan (Member, Author)

Okay, I'm almost there. The outstanding issue is that the rule for Diagonal \ Vector returns a dense matrix with nonzero off-diagonals as the sensitivity for the Diagonal argument, while FDM returns a Diagonal. The diagonal elements are equal, but I'm not sure which one is right in this case. 🤔

@willtebbutt (Member)

> The outstanding issue is that Diagonal \ Vector returns a dense matrix with nonzero off-diagonals and FDM returns a Diagonal. The diagonal elements are equal, but I'm not sure which one is right in this case. 🤔

FDM is doing the right thing -- we definitely want the adjoint to be represented as a Diagonal matrix to ensure that when we're dealing with diagonal matrices, the whole thing has O(N) time / memory complexity.

@ararslan ararslan marked this pull request as ready for review June 13, 2019 19:27
@ararslan ararslan changed the title WIP: Add rrules for binary linear algebra operations Add rrules for binary linear algebra operations Jun 13, 2019
@ararslan (Member, Author)

Alright, things should be in order now.

@willtebbutt (Member)

This looks great. Would you mind adding a few tests to ensure that the sensitivity of a Diagonal matrix involved in *, \ or / is itself Diagonal?

@ararslan (Member, Author)

I don't understand why it would be Diagonal in all cases. For example, if X is dense and D is diagonal, X * D is dense, and the corresponding sensitivity is Ȳ * D', which is similarly dense.

@willtebbutt (Member) commented Jun 13, 2019

Right, yeah, so this is something I've thought about a fair bit. There are two situations.

  1. We have a programme that constructs a Diagonal matrix at some point in it, e.g.

function(a)
    D = Diagonal(a)
    return foo(D)
end

In this case, it's totally fine to drop the off-diagonal bits of the sensitivity for D because we know for a fact that they will never be required, by virtue of the manner in which D was constructed. That is, we can safely drop the off-diagonal bits since the sensitivity for a is unaffected by them.

  2. D is an argument to the function in question, e.g.:

D::Diagonal -> foo(D)

In this case, in some sense we need to think more about how the user is planning to use the sensitivity (clearly this is a bit more hand-wavy). The position I've always adopted is that if the user had intended to do something with the gradient w.r.t. the off-diagonal elements, they would have provided a dense matrix, not a Diagonal one.

As a consequence of the above, I've always felt justified employing this type of optimisation. Frankly, if you don't, there's essentially no point at all in using structured matrices in code whose adjoint you plan to compute, because you lose the structure on the reverse pass and revert to whatever complexity is associated with the dense version of your programme.
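To illustrate the optimisation being described (a sketch under the assumption C = D * X with D::Diagonal; none of these variable names come from the PR): the structured sensitivity for D is just the diagonal of the dense candidate C̄ * X', and it can be computed without ever forming an N×N matrix:

```julia
using LinearAlgebra, Random

Random.seed!(1)

D = Diagonal(randn(4))
X = randn(4, 5)
C̄ = randn(4, 5)  # incoming sensitivity for C = D * X

# Naive dense rule: materialises a 4×4 matrix whose off-diagonals are never used.
dense_D̄ = C̄ * X'

# Structured rule: (C̄ * X')[i, i] == sum_j C̄[i, j] * X[i, j], so the diagonal
# can be computed with an elementwise product and a row sum, in O(N·M).
diag_D̄ = Diagonal(vec(sum(C̄ .* X; dims=2)))

@assert diag(dense_D̄) ≈ diag_D̄.diag
```

The structured version keeps the reverse pass at the same complexity as the forward pass, which is the whole point of using a Diagonal in the first place.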

@ararslan (Member, Author)

That makes sense I suppose, though

> clearly this is a bit more hand-wavy

makes me nervous... Nabla also does not appear to do things that way.

@willtebbutt (Member)

> Nabla also does not appear to do things that way.

Which way is that sorry?

@ararslan (Member, Author)

Sorry, wasn't clear. For example, in the case of multiplication with diagonal matrices, Nabla produces a dense matrix as the sensitivity for the diagonal argument.

@willtebbutt (Member)

Ah okay. I would view that as a bug.

FWIW, the other thing to think about is what is actually happening computationally under the hood. Ultimately the Diagonal matrix type doesn't use any off-diagonal elements when used in, e.g., a matrix-matrix multiply; the Diagonal type simply doesn't allow you to have nonzero off-diagonal elements, so it's a slightly odd question to ask what happens if you perturb the off-diagonals by an infinitesimal amount (i.e. compute the gradient w.r.t. them).

It's this slightly weird situation in which thinking about a Diagonal matrix as a regular dense matrix that happens to contain zeros on its off-diagonals isn't really faithful to the semantics of the type (not sure if I've really phrased that correctly, but hopefully the gist is clear).
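This semantic point can be demonstrated directly (a small illustration of mine, not from the thread): Diagonal stores only its diagonal, and Julia refuses to set an off-diagonal entry to a nonzero value, so there is no representable off-diagonal perturbation to differentiate with respect to:

```julia
using LinearAlgebra

D = Diagonal([1.0, 2.0, 3.0])
@assert D[1, 2] == 0.0  # off-diagonals read as zero...

# ...but cannot be made nonzero: setindex! throws ArgumentError for any
# attempt to store a nonzero value off the diagonal.
threw = try
    D[1, 2] = 0.1
    false
catch err
    err isa ArgumentError
end
@assert threw
```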

@ararslan (Member, Author)

I see, that explanation makes sense to me. I think it will require a larger effort to audit everything in here (and in Nabla) to ensure that's what it's doing, as many things are typed as AbstractMatrix without considering the return types of the rules.

@willtebbutt (Member)

Yeah, I completely agree. Fortunately this kind of bug doesn't make things wrong, it just makes them slow, so it's not the end of the world most of the time.

@ararslan (Member, Author)

Since nothing is actually using ChainRules right now, can we go ahead and merge this to master? Then on the plane tomorrow I'll look into a larger refactor to ensure consistent matrix types; it's easier if I just do everything together on one branch.

@willtebbutt (Member)

Please feel free to merge :)

@ararslan (Member, Author)

Thanks for the review and explanations!
