Reverse differentiation through nlsolve #205
You might want to take a look at https://arxiv.org/abs/1812.01892 . Yes, forward is fast but doesn't scale, and hard-coded adjoints do well. But I think the golden solution might be to just wait for source-to-source like Zygote.jl since then reverse mode can be done without operation tracking. |
Even then, Zygote still has to record all the history of the iterative method and then run it backwards, so that'll likely be slow and memory-consuming, won't it? |
I don't think it has to build a tape to handle loops? But then again, I don't know how it knows how many times to go back through the loop. @MikeInnes |
Well, if you backward-differentiate through a loop, don't you have to keep track of all the intermediary steps? I took a look at your (very nice, btw) paper; I think one key difference between adjoints for solving equations and for differential equations is that in diff eqs you are interested in the whole solution, and so have no choice but to keep it in memory, at which point reverse-mode differentiating through the thing doesn't look so bad. When solving equations you just want the final solution and discard the iterates. This enables more efficient adjoints for nonlinear solves, where you can just discard the convergence history.

To make my point above clearer, take a simpler case: computing the gradient of y -> <b, A^-1 y>, where the A solve is done through a simple iterative method like CG (say). If you run reverse AD through CG, you're going to need to store all iterates (or use fancy techniques), and potentially a lot of memory. Instead, you can just write the gradient as A^-T b and CG-solve that. Obviously this is a trivial example, but it generalizes to arbitrary nonlinear systems and outputs. |
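To make that concrete, here is a minimal sketch of the CG example (a toy 2×2 problem with hypothetical values; it assumes A is symmetric positive definite so that CG applies):

```julia
# Toy illustration of the comment above: the gradient of y -> <b, A⁻¹y> is
# A⁻ᵀb, so it costs one extra CG solve and no storage of CG iterates.
using LinearAlgebra, IterativeSolvers

A = [4.0 1.0; 1.0 3.0]   # symmetric positive definite, so CG applies
b = [1.0, 2.0]

obj(y) = dot(b, cg(A, y))   # y -> <b, A⁻¹y>, with the solve done by CG
grad_y = cg(A', b)          # gradient wrt y: A⁻ᵀb, again a single CG solve
grad_y ≈ A' \ b             # true
```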
I see what you're saying and it totally makes sense in this case. I don't know the right place to overload for this though. We did it directly on Flux.Tracker types. |
This sounds very useful. How about starting from something very simple, like this?

```julia
using NLsolve
using Zygote
using Zygote: @adjoint, forward

@adjoint nlsolve(f, j, x0; kwargs...) =
    let result = nlsolve(f, j, x0; kwargs...)
        result, function(vresult)
            # This backpropagator returns (- v' (df/dx)⁻¹ (df/dp))'
            v = vresult[].zero
            x = result.zero
            J = j(x)
            _, back = forward(f -> f(x), f)
            return (back(-(J' \ v))[1], nothing, nothing)
        end
    end
```

It looks like it's working:

```julia
julia> d, = gradient(p -> nlsolve(x -> [x[1]^3 - p],
                                  x -> fill(3x[1]^2, (1, 1)),
                                  [1.0]).zero[1],
                     8.0)
(0.08333333333333333,)

julia> d ≈ 1/3 * 8.0^(1/3 - 1)
true
```
|
Is Zygote ready to be used in the wild? |
Wow, that is so cool. Now we just need JuliaDiff/ChainRulesCore.jl#22 to be able to take a dependency on ChainRulesCore, put that code in nlsolve, add the iterative solve for the zeroth-order methods, and we rule the world! (well, except for mutation...) |
What do you mean? |
@antoine-levitt FYI it looks like the complex-number interface could become another blocker for Zygote users (see FluxML/Zygote.jl#142 (comment)), although I guess we can still use other AD packages based on ChainRulesCore? But are there other AD packages closer to production-ready than Zygote.jl? I'm wondering if it makes sense to define it just for Zygote.jl for now. Reading FluxML/Zygote.jl#291 it seems that ChainRulesCore's API would be close to Zygote.jl's, so migration doesn't sound hard. Of course, people can just define their own wrapper. I did it already (https://github.com/tkf/SteadyStateFit.jl/blob/239b18252ea5a596b780ddfbeb483b7e80b17572/src/znlsolve.jl) so this is not a blocker for me personally anymore. |
I don't think it's such a blocker: both Zygote and ChainRules support complex differentials in their full generality; it's just a question of putting the APIs together and of optimization. It looks like Zygote and ChainRules are going to mesh in the short term, so we might as well wait until then. @oxinabox, does that sound reasonable? I think it's better for nlsolve to take on a dependency on ChainRulesCore than on Zygote. |
Sounds reasonable to me |
Yeah, I don't think it will be much of a blocker.
No later than the end of the year, hopefully much sooner. There is also ZygoteRules.jl, which is, I think, a Zygote-specific equivalent of ChainRulesCore. |
OK, that's good news. |
Lyndon mentioned it, but just linking ZygoteRules explicitly. RE using Zygote in the wild: the marker for that is really going to be when we release Flux + Zygote; once it's ready for that it's going to have been pretty heavily user-tested. OTOH it's still a relatively large dependency. My suggestion would be to use ZygoteRules to add this adjoint for now, and then switch to ChainRules once it's ready. |
So here's a (very crude) prototype of reverse-diffing through a nonlinear PDE solve (based on @tkf's code, but with fully iterative methods, to get something representative of large-scale applications):

```julia
using NLsolve
using Zygote
using Zygote: @adjoint, forward
using IterativeSolvers
using LinearMaps
using SparseArrays
using LinearAlgebra   # for Diagonal, used below
using BenchmarkTools  # for @btime, used below

# nlsolve maps f to the solution x of f(x) = 0
# We have ∂x = -(df/dx)^-1 ∂f, and so the adjoint is df = -(df/dx)^-T dx
@adjoint nlsolve(f, x0; kwargs...) =
    let result = nlsolve(f, x0; kwargs...)
        result, function(vresult)
            dx = vresult[].zero
            x = result.zero
            _, back_x = forward(f, x)
            JT(df) = back_x(df)[1]
            # solve JT*df = -dx
            L = LinearMap(JT, length(x0))
            df = gmres(L, -dx)
            _, back_f = forward(f -> f(x), f)
            return (back_f(df)[1], nothing, nothing)
        end
    end

const N = 10000
const nonlin = 0.1
const A = spdiagm(0 => fill(10.0, N), 1 => fill(-1.0, N-1), -1 => fill(-1.0, N-1))
const p0 = randn(N)
f(x, p) = A*x + nonlin*x.^2 - p
solve_x(p) = nlsolve(x -> f(x, p), zeros(N), method=:anderson, m=10).zero
obj(p) = sum(solve_x(p))

Zygote.refresh()
g_auto, = gradient(obj, p0)
g_analytic = gmres((A + Diagonal(2*nonlin*solve_x(p0)))', ones(N))
display(g_auto)
display(g_analytic)
@btime gradient(obj, p0)
@btime gmres((A + Diagonal(2*nonlin*solve_x(p0)))', ones(N))
```

Performance is not great, essentially 20x compared to the analytic version. However, profiling shows that this overhead is pretty localized, so it might be possible to optimize it away and get essentially the same perf as the analytic one (this should be a relatively easy case for reverse diff, since there are only vector operations and no loop). I'm not quite sure what's going on here; one possibility is that Zygote tries to diff wrt globally defined constants. |
You could try explicitly dropping gradients of globals to see if that's the issue. |
OK, but how do I do that? |
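(For reference, one way to do that, as a sketch: Zygote has offered a `dropgrad` marker for exactly this, though the exact API has varied across versions.)

```julia
# Sketch: wrap the globals so Zygote does not try to differentiate w.r.t. them
# (assumes Zygote.dropgrad; newer Zygote/ChainRulesCore versions use
# ignore_derivatives instead).
f(x, p) = Zygote.dropgrad(A)*x + Zygote.dropgrad(nonlin)*x.^2 - p
```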
I see a similar hiccup with closures and would like to know any solution/workaround FluxML/Zygote.jl#323 |
@tkf I see you closed the issue there, but the discussion there was too technical for me to follow. Could you summarize what it means for the code above? Will it be fixed by the chainrules integration? |
@antoine-levitt Short answer is, IIUC, it'll be solved by switching to ChainRulesCore. I posted a longer answer with step-by-step code in FluxML/Zygote.jl#323 (comment) explaining why I thought it was solved. |
If I use LinearAlgebra.diagm in f(x, p), it raises "Need an adjoint for constructor Pair". How can I write an adjoint method similar to the one above? Many thanks! |
I don't know, you probably need to take it up with Zygote (or ChainRules). Also note that the above code was for an older version of Zygote, it needs updating (if anyone does so, please post the result and check whether the above-mentioned slowdown is still present!) |
Here's @antoine-levitt's example as of today:

```julia
using NLsolve
using Zygote
using Zygote: @adjoint
using IterativeSolvers
using LinearMaps
using SparseArrays
using LinearAlgebra
using BenchmarkTools

# nlsolve maps f to the solution x of f(x) = 0
# We have ∂x = -(df/dx)^-1 ∂f, and so the adjoint is df = -(df/dx)^-T dx
@adjoint nlsolve(f, x0; kwargs...) =
    let result = nlsolve(f, x0; kwargs...)
        result, function(vresult)
            dx = vresult[].zero
            x = result.zero
            _, back_x = Zygote.pullback(f, x)
            JT(df) = back_x(df)[1]
            # solve JT*df = -dx
            L = LinearMap(JT, length(x0))
            df = gmres(L, -dx)
            _, back_f = Zygote.pullback(f -> f(x), f)
            return (back_f(df)[1], nothing, nothing)
        end
    end

const N = 10000
const nonlin = 0.1
const A = spdiagm(0 => fill(10.0, N), 1 => fill(-1.0, N-1), -1 => fill(-1.0, N-1))
const p0 = randn(N)
f(x, p) = A*x + nonlin*x.^2 - p
solve_x(p) = nlsolve(x -> f(x, p), zeros(N), method=:anderson, m=10).zero
obj(p) = sum(solve_x(p))

Zygote.refresh()
g_auto, = gradient(obj, p0)
g_analytic = gmres((A + Diagonal(2*nonlin*solve_x(p0)))', ones(N))
display(g_auto)
display(g_analytic)
@show sum(abs.(g_auto - g_analytic))
@btime gradient(obj, p0);
@btime gmres((A + Diagonal(2*nonlin*solve_x(p0)))', ones(N));
```

My local timings:
|
Thanks @niklasschmitz for providing the updated code. |
It now seems to me that the big slowdown is caused by the sparse matrix:

```diff
- const A = spdiagm(0 => fill(10.0, N), 1 => fill(-1.0, N-1), -1 => fill(-1.0, N-1))
+ const A = Array(spdiagm(0 => fill(10.0, N), 1 => fill(-1.0, N-1), -1 => fill(-1.0, N-1))) # try dense A, for comparison only
```

For this I get the following timings:

```julia
@btime gradient(obj, p0);                                      # 26.382 ms (624 allocations: 63.30 MiB)
@btime gmres((A + Diagonal(2*nonlin*solve_x(p0)))', ones(N));  # 16.002 ms (446 allocations: 9.52 MiB)
```

So the previous 100x relative slowdown seems gone, cc @antoine-levitt @rkube. The large performance penalties when using |
Could also be that the cost of matvecs with dense matrices is much larger than the sparse ones so that it hides the overhead? |
I now tried to double-check by trying the sparse case again, but with a custom rrule for the inner function:

```julia
using ChainRulesCore

function ChainRulesCore.rrule(::typeof(f), x, p)
    y = f(x, p)
    function f_pullback(ȳ)
        ∂x = @thunk(A'ȳ + 2nonlin*x.*ȳ)
        ∂p = @thunk(-ȳ)
        return (NO_FIELDS, ∂x, ∂p)
    end
    return y, f_pullback
end

Zygote.refresh()
```

Going back to the original example problem from above (i.e. the sparse A), I get:

```julia
@btime gradient(obj, p0);                                      # 22.756 ms (986 allocations: 23.99 MiB)
@btime gmres((A + Diagonal(2*nonlin*solve_x(p0)))', ones(N));  # 23.065 ms (786 allocations: 21.23 MiB)
```
|
JuliaDiff/ChainRulesCore.jl#363 is required to avoid the Zygote dependency. |
https://github.com/SciML/DiffEqSensitivity.jl/blob/master/src/steadystate_adjoint.jl#L2-L81 is an implementation in DiffEqSensitivity. It has a ton of options for how the vjp can be calculated, https://diffeq.sciml.ai/stable/analysis/sensitivity/#Internal-Automatic-Differentiation-Options-(ADKwargs), but that should get replaced by AbstractDifferentiation.jl; see JuliaDiff/AbstractDifferentiation.jl#1. And @YingboMa did one for NonlinearSolve.jl: https://gist.github.com/YingboMa/4e4496f828c6a3179004f6d0ca224d2a |
Thanks Chris for the pointers!

```julia
using NLsolve
using ChainRulesCore
using IterativeSolvers
using LinearMaps

function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(nlsolve), f, x0; kwargs...)
    result = nlsolve(f, x0; kwargs...)
    function nlsolve_pullback(Δresult)
        Δx = Δresult[].zero
        x = result.zero
        _, f_pullback = rrule_via_ad(config, f, x)
        JT(v) = f_pullback(v)[2] # w.r.t. x
        # solve JT*Δfx = -Δx
        L = LinearMap(JT, length(x0))
        Δfx = gmres(L, -Δx)
        ∂f = f_pullback(Δfx)[1] # w.r.t. f itself (implicitly closed-over variables)
        return (NoTangent(), ∂f, ZeroTangent())
    end
    return result, nlsolve_pullback
end
```

Full gist is here: https://gist.github.com/niklasschmitz/b00223b9e9ba2a37ed09539a264bf423 |
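As a usage sketch (assuming the `f`, `p0` and `N` definitions from the earlier benchmark are in scope), a recent Zygote picks this rrule up through its ChainRules integration:

```julia
# Usage sketch: with the rrule above defined, Zygote differentiates through
# nlsolve without tracking the solver iterations.
using Zygote

solve_x(p) = nlsolve(x -> f(x, p), zeros(N), method=:anderson, m=10).zero
g, = Zygote.gradient(p -> sum(solve_x(p)), p0)
```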
Yay! So, @pkofod it seems everything is in place for putting that rule into nlsolve. That means taking on a dependency on ChainRulesCore and IterativeSolvers, are you OK with that? Also should we do it in the new nlsolvers or not? |
And can we add a keyword argument for dispatching, i.e. |
This should work
This is harder. One way that could be done for this particular case is having Other possibilities are to write the dispatch on the kwargs, which I think is possible with sufficient evil? |
The way we do this in SciML is to make a drop method: https://github.com/SciML/DiffEqBase.jl/blob/v6.64.0/src/solve.jl#L66-L71 and then define the adjoint dispatch on a given set of types, making the rrule undefined on the https://github.com/SciML/DiffEqBase.jl/blob/v6.64.0/src/solve.jl#L297-L302 by making it not part of the abstract type. |
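For concreteness, here is a compact sketch (hypothetical names, not existing NLsolve API) of what such a drop method could look like here, reusing the nlsolve rrule defined above:

```julia
using NLsolve, ChainRulesCore

# Hypothetical adjoint-choice types: the rrule is only attached to the method
# that receives IterativeAdjoint, so NoAdjoint simply has no rrule defined.
struct IterativeAdjoint end
struct NoAdjoint end

# User-facing call: the keyword is normalized into a positional argument...
outer_nlsolve(f, x0; adjoint = IterativeAdjoint(), kwargs...) =
    inner_nlsolve(f, x0, adjoint; kwargs...)

# ...and the inner, fully positional method is what actually runs the solver.
inner_nlsolve(f, x0, adjoint; kwargs...) = nlsolve(f, x0; kwargs...)

function ChainRulesCore.rrule(config::RuleConfig{>:HasReverseMode},
                              ::typeof(inner_nlsolve), f, x0, ::IterativeAdjoint; kwargs...)
    # Delegate to the nlsolve rrule defined earlier in this thread.
    result, nlsolve_pullback = rrule(config, nlsolve, f, x0; kwargs...)
    function inner_pullback(Δresult)
        _, ∂f, ∂x0 = nlsolve_pullback(Δresult)
        return (NoTangent(), ∂f, ∂x0, NoTangent())
    end
    return result, inner_pullback
end
```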
Yes of course, the linear solver should follow the nonlinear one. Anderson and Broyden should go to GMRES (which is the linearized version of anderson anyway), and the Newton solvers should go to whatever is used for solving the update equations. Another possibility is to just call nlsolve recursively on the linear equation. |
We can also make this optional, either by having the user call explicitly an |
I don't think it needs to be separate, just switchable.
That would make a lot of sense and naturally make it use the same linear solver. |
Yeah that sounds like a sensible default, although a bit suboptimal (GMRES is more stable than anderson, the newton method might not be able to figure out it doesn't need any stabilization, etc). |
Well the default should probably still be GMRES on the fixed point methods, and only do the recursive iteration on Newton. That would likely get pretty close to optimal, since someone would only choose Newton if they need it (in theory). |
Cool, easy, then we can do this. We can also make methods that are only defined if you have access to a forwards mode AD. |
That's what we call https://diffeq.sciml.ai/stable/analysis/sensitivity/#Sensitivity-Algorithms |
yep, and we can make those work without a direct dependency on ForwardDiff. |
hello! just wondering about the progress on this ticket. |
#205 (comment) works, and you can fine tune it to suit your needs. Putting this into the actual nlsolve requires more thought about API, solvers, tolerances, defaults, etc. |
BTW, someone should check if Zygote.jl just fails currently on NLsolve.jl. If it does, then you might as well take the working adjoint and slap it on there and do a quick merge. It's not numerically robust, but it's better than failing and step 1 to making something better. |
Just checking in again to say that I'm not really sure how these things would be implemented and used in practice, but if @oxinabox can help or at least hold my hand I'm happy to include this feature. |
I am happy to hold your hand through this. |
Yeah, the snippet above should be fine, with the caveat that the linear solve should be replaced by a recursive nlsolve call. |
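A sketch of that caveat, reusing the names JT and Δx from the rrule above: the gmres call can be replaced by a nested nlsolve on the linear residual, so the adjoint solve stays within the same solver family as the forward solve.

```julia
# Inside the pullback above, instead of
#     Δfx = gmres(L, -Δx)
# solve Jᵀ Δfx = -Δx by a recursive nlsolve on the linear residual
# (in practice one would forward the same method/options as the outer solve):
Δfx = nlsolve(v -> JT(v) + Δx, zero(Δx)).zero
```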
I updated the gist example at https://gist.github.com/niklasschmitz/b00223b9e9ba2a37ed09539a264bf423#gistcomment-3830191 |
Note that this example is adapted for the matrix-free case (where the jacobian is not computed explicitly). In the case where an explicit jacobian is provided, it should be used as JT (instead of being computed through AD), and the objective function v -> J^T v + Δx should be passed the explicit jacobian. @pkofod can you help with the API here? I'm a bit fuzzy on how to get that information with the nlsolversbase wrapper types & co. |
OK, so this is pretty speculative as the reverse differentiation packages are not there yet, but let's dream for a moment. It would be awesome to be able to just use reverse-mode differentiation on code like
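For illustration, a hypothetical sketch of the kind of code meant (F, H and α are the placeholder names used in the text below; the toy definitions are only illustrative, any functions would do):

```julia
using NLsolve

F(x, α) = x.^3 .- α            # toy residual; x(α) is defined by F(x, α) = 0
H(x, α) = sum(x) + sum(α)      # toy scalar output

function G(α)
    x = nlsolve(x -> F(x, α), ones(length(α))).zero
    return H(x, α)
end

G([2.0, 3.0])   # we would like gradient(G, α) by reverse-mode AD
```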
and take the gradient of G wrt α. Of course, both F and H are examples, and can be arbitrary functions. So how to get the gradient of G? One can of course forward-diff through G, which is not too hard to support from the perspective of nlsolve (although I haven't tried). But that's pretty inefficient if α is high-dimensional. One can try reverse-diffing through G, but that's pretty heavy since it has to basically record all the iterations. A better idea is to exploit the mathematical structure of the problem, and in particular the relationship dx/dα = -(∂F/∂x)^-1 ∂F/∂α (differentiate F(x(α),α)=0 wrt α), assuming nlsolve has converged perfectly. Reverse-mode autodiff requires the user to compute (dx/dα)^T δx, which is -(∂F/∂α)^T (∂F/∂x)^-T δx. If the jacobian is not provided (Broyden or Anderson), this can be done by using an iterative solver such as GMRES, where the individual matvecs with (∂F/∂x)^T are performed with reverse diff.
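Spelling out the implicit-function step used above:

```latex
% Differentiate F(x(\alpha), \alpha) = 0 with respect to \alpha:
\frac{\partial F}{\partial x}\,\frac{dx}{d\alpha} + \frac{\partial F}{\partial \alpha} = 0
\quad\Longrightarrow\quad
\frac{dx}{d\alpha} = -\left(\frac{\partial F}{\partial x}\right)^{-1}\frac{\partial F}{\partial \alpha},
\qquad
\left(\frac{dx}{d\alpha}\right)^{T}\delta x
  = -\left(\frac{\partial F}{\partial \alpha}\right)^{T}\left(\frac{\partial F}{\partial x}\right)^{-T}\delta x .
```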
The action point for nlsolve here is to write a reverse ChainRule (https://github.com/JuliaDiff/ChainRules.jl) for nlsolve. This might be tricky because nlsolve takes a function as argument, but we might get by with just calling a diff function on F recursively. CC @jrevels to check this isn't a completely stupid idea. Of course, this isn't necessarily specific to nlsolve; the same ideas apply to optim (writing ∇F = 0) and diffeq (adjoint equations) for instance.