Add rules for the matrix exponential #331
Comments
Thanks for sharing your nice observation!
No problem! I have updated with a more general statement
One of two ways:
I'll probably do (1).
Thanks. If I understand correctly, this PR in JAX switched from (1) to (2) for the JVP of expm: jax-ml/jax#4314, and this one, jax-ml/jax#4331, implements expm_frechet using jvp(expm), that is, it implements (1) using (2).
Yeah, the way JAX does it is definitely better. Here we're limited in a sense by the fact that this package provides rules for all ChainRules-compatible ADs. We currently don't have a way to embed an automatically differentiated function within a custom rule, though that is planned (JuliaDiff/ChainRulesCore.jl#68). But even that wouldn't help here, where we want to compose a pushforward and a pullback, which will in general be provided by different AD packages. That's part of the cost of being general.
Thanks! I believe this is also related: https://github.com/Lezcano/expRNN/blob/830ec836521d0c295436dcafc3f0b3deea36c83c/trivializations.py#L19. Though I am not sure why there is only one transpose...
JuliaDiff/ChainRulesCore.jl#68 should be able to do that.
TL;DR: a better approach to rules for the matrix exponential of dense matrices.
We should add rules for `exp(::StridedMatrix)` that would supersede Zygote's.

The right rule for `exp` would be none, i.e. to AD through `exp`, which uses the scaling and squaring algorithm. But that implementation uses mutation, which Zygote doesn't support. It is also limited to `BlasFloat`, so I don't think ForwardDiff or any of the other operator-overloading ADs can handle it. All that to say, we should have the rules.

Zygote currently uses the power series pullback, computed with an eigendecomposition, for the backward pass. The eigendecomposition is not an accurate way to compute the exponential in general (https://epubs.siam.org/doi/10.1137/S00361445024180); it's fine for Hermitian matrices though, hence the `exp(::Hermitian)` overload in LinearAlgebra. Zygote's adjoint uses `exp` for the primal, so it only introduces potential inaccuracy in the pullback. However, it doesn't follow the same time complexity as the primal function, and it is quite wasteful.

EDIT: everything said below is still valid, but it is a general property of power series matrix functions with real coefficients that the pullback is the pushforward pre- and post-composed with `adjoint`, or, equivalently, the pushforward of the function applied to the adjoint of the primal. I.e., if `Y = f(A)`, then we have the equality `f^* = adjoint ∘ f_{*} ∘ adjoint`. This applies to `exp`, `log`, and all trigonometric and hyperbolic functions.

`Y = exp(A)` appears in the solution to an ODE. We can augment the ODE to get a new one whose solution also uses the matrix exponential and which gives us the pushforward (discussed in section 7 of https://www2.humusoft.cz/www/papers/tcp08/017_brancik.pdf, though the result is older and can be worked out from https://ieeexplore.ieee.org/document/1101743 or witty algebra).

In short, given `B = [A ΔA; zero(A) A]`, then `exp(B) = [Y ∂Y; zero(Y) Y]`.
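
The block-matrix identity can be checked numerically. The following is a minimal sketch, not a proposed rule: the helper name `exp_pushforward_augmented`, the test matrices, and the step size `h` are assumptions made for illustration. It reads `∂Y` off the upper-right block of `exp(B)` and compares it with a central finite difference.

```julia
using LinearAlgebra

# Hypothetical helper: primal and pushforward of exp at A in direction ΔA,
# read off the blocks of the exponential of the augmented matrix.
function exp_pushforward_augmented(A::AbstractMatrix, ΔA::AbstractMatrix)
    n = LinearAlgebra.checksquare(A)
    B = Matrix([A ΔA; zero(A) A])   # 2n × 2n block upper-triangular matrix
    E = exp(B)
    Y = E[1:n, 1:n]                 # upper-left block: exp(A)
    ∂Y = E[1:n, (n + 1):(2n)]       # upper-right block: the pushforward
    return Y, ∂Y
end

# Arbitrary test data (assumed for this example).
A = [0.1 0.2; -0.3 0.4]
ΔA = [1.0 0.5; -0.5 2.0]

Y, ∂Y = exp_pushforward_augmented(A, ΔA)

# Central finite-difference approximation of the directional derivative.
h = 1e-6
fd = (exp(A + h * ΔA) - exp(A - h * ΔA)) / (2h)

@assert Y ≈ exp(A)
@assert isapprox(∂Y, fd; rtol=1e-6)
```

The finite-difference comparison is only a sanity check; its tolerance is deliberately loose because of truncation and rounding error in the difference quotient.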
I didn't find a reference for this, but we can do the same thing for the pullback by constructing and solving the adjoint ODE. This is easy to verify. Given `B = [A ΔY'; zero(A) A]`, we get `exp(B) = [Y ∂A'; zero(Y) Y]`.
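
The pullback construction can be checked the same way. This is a minimal sketch with real test matrices; the helper names `exp_pushforward` and `exp_pullback` are hypothetical and exist only for this example. It builds `B` with `ΔY'` in the upper-right block, takes the adjoint of the resulting block to get `∂A`, and verifies the inner-product identity `⟨ΔY, ∂Y⟩ = ⟨∂A, ΔA⟩` that a pullback has to satisfy.

```julia
using LinearAlgebra

# Hypothetical helper: pushforward of exp via the augmented block matrix.
function exp_pushforward(A::AbstractMatrix, ΔA::AbstractMatrix)
    n = LinearAlgebra.checksquare(A)
    E = exp(Matrix([A ΔA; zero(A) A]))
    return E[1:n, (n + 1):(2n)]
end

# Pullback from the block matrix with ΔY' in the upper-right block:
# the upper-right block of exp(B) is ∂A', so take its adjoint.
exp_pullback(A::AbstractMatrix, ΔY::AbstractMatrix) = exp_pushforward(A, ΔY')'

# Arbitrary test data (assumed for this example).
A = [0.1 0.2; -0.3 0.4]
ΔA = [1.0 0.5; -0.5 2.0]
ΔY = [0.3 -1.0; 0.7 0.2]

∂Y = exp_pushforward(A, ΔA)
∂A = exp_pullback(A, ΔY)

# Defining property of the pullback with respect to the Frobenius inner product.
@assert dot(ΔY, ∂Y) ≈ dot(∂A, ΔA)

# Same result from the pushforward evaluated at the adjoint of the primal.
@assert ∂A ≈ exp_pushforward(A', ΔY)
```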
That is, the pullback `exp^*` is related to the pushforward `exp_{*}` by `exp^* = adjoint ∘ exp_{*} ∘ adjoint`.

The problem with the augmented matrix approach is that it is 8x the cost of the primal, when we should be able to get <5x. For small matrices (<100x100) this is faster than the eigendecomposition approach, and it should be more accurate, but for large dense matrices, the eigendecomposition approach is faster.

But the relationship between the pushforward and pullback motivates a solution. Namely, explicitly implement the pushforward of the scaling and squaring approach used by `LinearAlgebra.exp!`. Not only do we get the pushforward with the same time complexity as the primal, but we can then compute the pullback with the same time complexity as the primal without the need to checkpoint any of the intermediate matrices, and with mutation allowed.

I'm planning to tackle this after some of my other open PRs are wrapped up, but I wanted to get it in writing while it was fresh on my mind.
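
As a rough illustration of pushing a tangent through scaling and squaring, here is a minimal sketch. It is not the algorithm in `LinearAlgebra.exp!`: it swaps the Padé approximant for a truncated Taylor series to stay short, and the function name `exp_and_pushforward`, the `order` keyword, and the scaling threshold of 1/2 are all assumptions made for this example. The point is only that the tangent can be carried alongside the primal, with the product rule applied at each multiplication and each squaring.

```julia
using LinearAlgebra

# Hypothetical sketch: exp(A) and its pushforward in direction ΔA, computed
# together by scaling and squaring. Uses a truncated Taylor series for the
# scaled matrix (an assumption of this sketch, not what exp! does).
function exp_and_pushforward(A::AbstractMatrix, ΔA::AbstractMatrix; order::Int=18)
    # Scaling: pick s so that the 1-norm of A / 2^s is at most 1/2.
    nA = opnorm(A, 1)
    s = nA > 0.5 ? ceil(Int, log2(nA / 0.5)) : 0
    As = A / 2^s
    ΔAs = ΔA / 2^s

    # Taylor series for the scaled matrix, with the tangent carried along.
    Y = Matrix{eltype(As)}(I, size(As)...)
    ∂Y = zero(As)
    P = copy(Y)      # holds As^k / k!
    ∂P = zero(As)    # holds the tangent of As^k / k!
    for k in 1:order
        ∂P = (∂P * As + P * ΔAs) / k   # product rule, using the old P
        P = (P * As) / k
        Y += P
        ∂Y += ∂P
    end

    # Squaring: Y ← Y * Y, with tangent ∂Y * Y + Y * ∂Y (using the old Y).
    for _ in 1:s
        ∂Y = ∂Y * Y + Y * ∂Y
        Y = Y * Y
    end
    return Y, ∂Y
end

# Arbitrary test data (assumed); the norm is large enough to need squaring.
A = [1.0 2.0; -3.0 4.0]
ΔA = [0.5 -1.0; 2.0 0.25]
Y, ∂Y = exp_and_pushforward(A, ΔA)

# Compare against the augmented block-matrix construction.
E = exp([A ΔA; zero(A) A])
@assert isapprox(Y, E[1:2, 1:2]; rtol=1e-8)
@assert isapprox(∂Y, E[1:2, 3:4]; rtol=1e-8)
```

Given such a pushforward, the pullback for an incoming `ΔY` would follow from the adjoint relationship as `exp_and_pushforward(A', ΔY)[2]`, with no intermediate matrices stored from the forward pass.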