Add rules for the matrix exponential #331

@sethaxen

Description

TL;DR: a better approach to rules for the matrix exponential of dense matrices.

We should add rules for exp(::StridedMatrix) that would supersede Zygote's.

The right rule for exp would be none at all, i.e. to AD through exp, which uses the scaling and squaring algorithm. But that implementation uses mutation, which Zygote doesn't support. It is also limited to BlasFloat, so ForwardDiff (and, I suspect, the other operator-overloading ADs) can't handle it either. All that to say, we should have the rules.

Zygote currently uses the power-series pullback, computed via an eigendecomposition, for the backward pass. The eigendecomposition is not an accurate way to compute the exponential in general (https://epubs.siam.org/doi/10.1137/S00361445024180); it's fine for Hermitian matrices though, hence the exp(::Hermitian) overload in LinearAlgebra. Since Zygote's adjoint uses exp for the primal, the potential inaccuracy only enters the pullback. Still, the pullback does not match the time complexity of the primal, and it is quite wasteful.

EDIT: everything said below is still valid, but it is a general property of power-series matrix functions with real coefficients that the pullback is the pushforward pre- and post-composed with the adjoint, or, equivalently, the pushforward evaluated at the adjoint of the primal input. I.e. if Y = f(A), then we have the equality

(f^*)_A(ΔY) = ((f_*)_A(ΔY'))' = (f_*)_{A'}(ΔY)

This applies to exp, log, and all trigonometric and hyperbolic functions.
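As a sanity check, the identity can be verified directly on a truncated power series with real coefficients (series_pushforward is an illustrative helper, not a proposed API; the coefficients below are the first few terms of exp):

```julia
using LinearAlgebra

# Pushforward of f(A) = Σ_m c[m] * A^m in direction ΔA:
# (f_*)_A(ΔA) = Σ_m c[m] * Σ_{k=0}^{m-1} A^k * ΔA * A^(m-1-k)
function series_pushforward(c, A, ΔA)
    out = zero(A)
    for (m, cm) in enumerate(c)
        for k in 0:(m - 1)
            out += cm * A^k * ΔA * A^(m - 1 - k)
        end
    end
    return out
end

c = [1.0, 1/2, 1/6, 1/24]   # exp coefficients 1/m! (truncated; constant term has zero derivative)
A, ΔA, ΔY = (randn(ComplexF64, 3, 3) for _ in 1:3)
pf = series_pushforward(c, A, ΔA)
pb = series_pushforward(c, A', ΔY)   # pullback = pushforward evaluated at A'
# ⟨pullback(ΔY), ΔA⟩ == ⟨ΔY, pushforward(ΔA)⟩, the defining adjoint identity
@assert dot(pb, ΔA) ≈ dot(ΔY, pf)
```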

Y = exp(A) appears in the solution to an ODE. We can augment the ODE to get a new one whose solution also uses the matrix exponential and which gives us the pushforward (discussed in section 7 of https://www2.humusoft.cz/www/papers/tcp08/017_brancik.pdf, though the result is older and can be worked out from https://ieeexplore.ieee.org/document/1101743 or witty algebra).
In short, given B = [A ΔA; zero(A) A], then exp(B) = [Y ∂Y; zero(Y) Y]
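The block identity above is easy to check against a central finite difference; a minimal sketch (frechet_exp is an illustrative name, not a proposed API):

```julia
using LinearAlgebra

# Pushforward of exp at A in direction ΔA via the augmented-matrix trick:
# exp([A ΔA; 0 A]) = [Y ∂Y; 0 Y], so the top-right block is the pushforward.
function frechet_exp(A, ΔA)
    n = size(A, 1)
    E = exp([A ΔA; zero(A) A])
    return E[1:n, (n + 1):2n]
end

A, ΔA = randn(4, 4), randn(4, 4)
h = 1e-6
fd = (exp(A + h * ΔA) - exp(A - h * ΔA)) / (2h)   # central difference
@assert isapprox(frechet_exp(A, ΔA), fd; rtol=1e-6)
```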

I didn't find a reference for this, but we can do the same thing for the pullback by constructing and solving the adjoint ODE
Given B = [A ΔY'; zero(A) A], then exp(B) = [Y ∂A'; zero(Y) Y]
That is, the pullback exp^* is related to the pushforward exp_{*} by exp^* = adjoint ∘ exp_{*} ∘ adjoint.
This is easy to verify:

julia> using LinearAlgebra, FiniteDifferences

julia> A, Δ = randn(ComplexF64, 30, 30), randn(ComplexF64, 30, 30);

julia> only(j′vp(central_fdm(5, 1), exp, Δ, A)) ≈ jvp(central_fdm(5, 1), exp, (A, Δ'))'
true
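The same adjoint construction can be sketched without FiniteDifferences, checking the two augmented-matrix computations against each other via the defining inner-product identity (exp_pushforward/exp_pullback are illustrative names, not proposed APIs):

```julia
using LinearAlgebra

# Pushforward: exp([A ΔA; 0 A]) = [Y ∂Y; 0 Y]
function exp_pushforward(A, ΔA)
    n = size(A, 1)
    E = exp([A ΔA; zero(A) A])
    return E[1:n, (n + 1):2n]
end

# Pullback: exp([A ΔY'; 0 A]) = [Y ∂A'; 0 Y], so adjoint the top-right block
function exp_pullback(A, ΔY)
    n = size(A, 1)
    E = exp([A ΔY'; zero(A) A])
    return E[1:n, (n + 1):2n]'
end

A, ΔA, ΔY = (randn(ComplexF64, 5, 5) for _ in 1:3)
# ⟨pullback(ΔY), ΔA⟩ == ⟨ΔY, pushforward(ΔA)⟩
@assert dot(exp_pullback(A, ΔY), ΔA) ≈ dot(ΔY, exp_pushforward(A, ΔA))
```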

The problem with the augmented-matrix approach is that it costs 8x the primal (exp of a 2n×2n matrix, with cubic scaling), when we should be able to get below 5x. For small matrices (< 100×100) it is faster than the eigendecomposition approach, and it should be more accurate, but for large dense matrices the eigendecomposition approach is faster.

But the relationship between the pushforward and pullback motivates a solution: explicitly implement the pushforward of the scaling-and-squaring approach used by LinearAlgebra.exp!. Not only do we get a pushforward with the same time complexity as the primal, but we can then compute the pullback at the same time complexity as the primal, without needing to checkpoint any of the intermediate matrices, and with mutation allowed.
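For the squaring phase this is just the product rule; a minimal sketch of how the pushforward composes through the repeated squarings (the Padé base case is elided here and replaced, for illustration only, by an exact exp of the scaled matrix via the augmented-matrix trick from above):

```julia
using LinearAlgebra

# Pushforward through the squaring phase: if Y = X^2, then ∂Y = ∂X*X + X*∂X.
function square_phase(X, ΔX, s)
    for _ in 1:s
        ΔX = ΔX * X + X * ΔX   # product rule for squaring
        X = X * X
    end
    return X, ΔX
end

A, ΔA = randn(4, 4), randn(4, 4)
s, n = 3, size(A, 1)
As, ΔAs = A / 2^s, ΔA / 2^s
# Stand-in for the Padé base case: exact exp of the scaled matrix and its
# Fréchet derivative, obtained via the augmented-matrix trick.
E = exp([As ΔAs; zero(As) As])
X0, ΔX0 = E[1:n, 1:n], E[1:n, (n + 1):2n]
Y, ΔY = square_phase(X0, ΔX0, s)
@assert Y ≈ exp(A)   # squaring recovers the primal
@assert ΔY ≈ exp([A ΔA; zero(A) A])[1:n, (n + 1):2n]   # and the pushforward
```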

I'm planning to tackle this after some of my other open PRs are wrapped up, but I wanted to get it in writing while it was fresh on my mind.
