TL;DR: a better approach to rules for the matrix exponential of dense matrices.
We should add rules for exp(::StridedMatrix) that would supersede Zygote's.
The ideal rule for exp would be none at all, i.e. to AD through exp, which uses the scaling-and-squaring algorithm. However, that implementation uses mutation, which Zygote doesn't support, and it is limited to BlasFloat, so ForwardDiff (and, I suspect, the other operator-overloading ADs) cannot handle it either. All that to say, we should have the rules.
Zygote currently computes the backward pass via the power-series pullback applied through an eigendecomposition. The eigendecomposition is not an accurate way to compute the exponential in general (https://epubs.siam.org/doi/10.1137/S00361445024180); it's fine for Hermitian matrices, though, hence the exp(::Hermitian) overload in LinearAlgebra. Zygote's adjoint uses exp for the primal, so the potential inaccuracy enters only in the pullback. However, the pullback does not have the same time complexity as the primal function, and it is quite wasteful.
EDIT: everything said below is still valid, but it is a general property of power-series matrix functions with real coefficients that the pullback is the pushforward pre- and post-composed with the adjoint, or, equivalently, the pushforward of the function evaluated at the adjoint of the primal input. That is, if Y = f(A), then

(f^*)_{Y}(ΔY) = ((f_*)_{A}(ΔY'))' = (f_*)_{A'}(ΔY)

This applies to exp, log, and all trigonometric and hyperbolic functions.
Y = exp(A) appears in the solution to an ODE. We can augment the ODE to get a new one whose solution also uses the matrix exponential and which gives us the pushforward (discussed in section 7 of https://www2.humusoft.cz/www/papers/tcp08/017_brancik.pdf, though the result is older and can be worked out from https://ieeexplore.ieee.org/document/1101743 or witty algebra).
In short, given B = [A ΔA; zero(A) A], then exp(B) = [Y ∂Y; zero(Y) Y]
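For concreteness, here is a minimal sketch of a pushforward built on this augmented matrix; the helper name exp_pushforward is hypothetical, not part of any package:

```julia
using LinearAlgebra

# Sketch of an frule-style pushforward for exp via the augmented
# block upper-triangular matrix B = [A ΔA; 0 A].
function exp_pushforward(A::AbstractMatrix, ΔA::AbstractMatrix)
    n = LinearAlgebra.checksquare(A)
    E = exp([A ΔA; zero(A) A])
    Y = E[1:n, 1:n]          # exp(A)
    ∂Y = E[1:n, n+1:2n]      # directional derivative of exp at A along ΔA
    return Y, ∂Y
end

# Check against a central finite difference:
A, ΔA = randn(4, 4), randn(4, 4)
Y, ∂Y = exp_pushforward(A, ΔA)
h = 1e-6
∂Y_fd = (exp(A + h * ΔA) - exp(A - h * ΔA)) / (2h)
Y ≈ exp(A)      # holds
∂Y ≈ ∂Y_fd      # holds up to finite-difference error
```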
I didn't find a reference for this, but we can do the same thing for the pullback by constructing and solving the adjoint ODE. This is easy to verify:
Given B = [A ΔY'; zero(A) A], then exp(B) = [Y ∂A'; zero(Y) Y]
That is, the pullback exp^* is related to the pushforward exp_{*} by exp^* = adjoint ∘ exp_{*} ∘ adjoint.
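This gives an rrule-style pullback for free; a minimal sketch (exp_pullback is a hypothetical name), verified against the defining adjoint identity of a pullback:

```julia
using LinearAlgebra

# Pullback of exp via exp^* = adjoint ∘ exp_* ∘ adjoint, using the
# same augmented-matrix trick with tangent ΔY'.
function exp_pullback(A::AbstractMatrix, ΔY::AbstractMatrix)
    n = LinearAlgebra.checksquare(A)
    E = exp([A ΔY'; zero(A) A])
    Y = E[1:n, 1:n]            # exp(A), recomputed here for simplicity
    ∂A_adj = E[1:n, n+1:2n]    # pushforward of exp at A along ΔY'
    return Y, ∂A_adj'          # post-compose with adjoint
end

# Defining identity of a pullback w.r.t. the real inner product
# ⟨X, Z⟩ = real(tr(X'Z)):  ⟨ΔY, exp_*(ΔA)⟩ == ⟨exp^*(ΔY), ΔA⟩
n = 4
A, ΔA, ΔY = randn(ComplexF64, n, n), randn(ComplexF64, n, n), randn(ComplexF64, n, n)
∂Y = exp([A ΔA; zero(A) A])[1:n, n+1:2n]   # pushforward along ΔA
_, ∂A = exp_pullback(A, ΔY)
real(dot(ΔY, ∂Y)) ≈ real(dot(∂A, ΔA))      # should hold to roundoff
```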
```julia
julia> using LinearAlgebra, FiniteDifferences

julia> A, Δ = randn(ComplexF64, 30, 30), randn(ComplexF64, 30, 30);

julia> only(j′vp(central_fdm(5, 1), exp, Δ, A)) ≈ jvp(central_fdm(5, 1), exp, (A, Δ'))'
true
```

The problem with the augmented-matrix approach is that it is 8x the cost of the primal, when we should be able to get under 5x. For small matrices (<100x100) this is faster than the eigendecomposition approach, and it should be more accurate, but for large dense matrices, the eigendecomposition approach is faster.
But the relationship between the pushforward and the pullback motivates a solution: explicitly implement the pushforward of the scaling-and-squaring algorithm used by LinearAlgebra.exp!. Not only do we get a pushforward with the same time complexity as the primal, but we can then compute the pullback with the same time complexity as the primal, without needing to checkpoint any of the intermediate matrices, and with mutation allowed.
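To illustrate, the squaring phase of such a pushforward is just the product rule applied to repeated squaring. A minimal sketch, where square_phase is a hypothetical helper and the Padé phase is stood in for by the augmented-matrix trick (a real implementation would differentiate the Padé approximant too):

```julia
using LinearAlgebra

# Propagate a directional derivative through the squaring phase of
# scaling and squaring: for X ↦ X², the product rule gives
# ΔX ↦ X*ΔX + ΔX*X.
function square_phase(P::AbstractMatrix, ∂P::AbstractMatrix, s::Integer)
    for _ in 1:s
        ∂P = P * ∂P + ∂P * P   # product rule, before P is overwritten
        P = P * P
    end
    return P, ∂P
end

# Demonstration: exp(A) = exp(A / 2^s)^(2^s). For brevity we get the
# scaled exponential and its derivative from the augmented matrix.
n, s = 4, 3
A, ΔA = randn(n, n), randn(n, n)
E = exp([A ΔA; zero(A) A] / 2^s)
P, ∂P = E[1:n, 1:n], E[1:n, n+1:2n]
Y, ∂Y = square_phase(P, ∂P, s)
Y ≈ exp(A)   # the squaring recurrence recovers the full exponential
```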
I'm planning to tackle this after some of my other open PRs are wrapped up, but I wanted to get it in writing while it was fresh on my mind.