feat: extend gradient support for cached nlls
avik-pal committed Nov 1, 2024
1 parent c2438c7 commit 85aa7db
Showing 8 changed files with 124 additions and 101 deletions.
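
The practical effect: a NonlinearLeastSquaresProblem solved through the cached init/solve! interface can now sit inside a ForwardDiff differentiation, which previously only worked when calling solve directly (the __init overload below was restricted to DualNonlinearProblem). A minimal sketch of the newly supported pattern, with a hypothetical residual function and parameter values that are not part of this commit (the added test at the bottom exercises the same path with a cable-length objective):

using NonlinearSolve, ForwardDiff

# Hypothetical residual: the least-squares solution u(p) depends smoothly on p.
residual(u, p) = [u[1]^2 - p[1], u[2] - p[2] * u[1]]

function loss(p)
    prob = NonlinearLeastSquaresProblem(NonlinearFunction(residual), [1.0, 1.0], p)
    cache = init(prob)   # cached interface; gradients through this path are what this commit adds
    sol = solve!(cache)
    return sum(sol.u)
end

ForwardDiff.gradient(loss, [2.0, 3.0])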
2 changes: 1 addition & 1 deletion Project.toml
@@ -89,7 +89,7 @@ NLSolvers = "0.5"
NLsolve = "4.5"
NaNMath = "1"
NonlinearProblemLibrary = "0.1.2"
NonlinearSolveBase = "1"
NonlinearSolveBase = "1.2"
NonlinearSolveFirstOrder = "1"
NonlinearSolveQuasiNewton = "1"
NonlinearSolveSpectralMethods = "1"
2 changes: 1 addition & 1 deletion lib/NonlinearSolveBase/Project.toml
@@ -1,7 +1,7 @@
name = "NonlinearSolveBase"
uuid = "be0214bd-f91f-a760-ac4e-3421ce2b2da0"
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
version = "1.1.0"
version = "1.2.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
98 changes: 8 additions & 90 deletions lib/NonlinearSolveBase/ext/NonlinearSolveBaseForwardDiffExt.jl
@@ -6,7 +6,6 @@ using CommonSolve: solve
using DifferentiationInterface: DifferentiationInterface
using FastClosures: @closure
using ForwardDiff: ForwardDiff, Dual
using LinearAlgebra: mul!
using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
NonlinearProblem, NonlinearLeastSquaresProblem, remake

@@ -24,7 +23,10 @@ Utils.value(x::Dual) = ForwardDiff.value(x)
Utils.value(x::AbstractArray{<:Dual}) = Utils.value.(x)

function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
prob::Union{IntervalNonlinearProblem, NonlinearProblem, ImmutableNonlinearProblem},
prob::Union{
IntervalNonlinearProblem, NonlinearProblem,
ImmutableNonlinearProblem, NonlinearLeastSquaresProblem
},
alg, args...; kwargs...
)
p = Utils.value(prob.p)
@@ -35,98 +37,14 @@ function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
newprob = remake(prob; p, u0 = Utils.value(prob.u0))
end

sol = solve(newprob, alg, args...; kwargs...)

uu = sol.u
Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, prob.f, uu, p)
Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, prob.f, uu, p)
z = -Jᵤ \ Jₚ
pp = prob.p
sumfun = ((z, p),) -> map(Base.Fix2(*, ForwardDiff.partials(p)), z)

if uu isa Number
partials = sum(sumfun, zip(z, pp))
elseif p isa Number
partials = sumfun((z, pp))
else
partials = sum(sumfun, zip(eachcol(z), pp))
end

return sol, partials
end

function NonlinearSolveBase.nonlinearsolve_forwarddiff_solve(
prob::NonlinearLeastSquaresProblem, alg, args...; kwargs...
)
p = Utils.value(prob.p)
newprob = remake(prob; p, u0 = Utils.value(prob.u0))
sol = solve(newprob, alg, args...; kwargs...)
uu = sol.u

# First check for custom `vjp` then custom `Jacobian` and if nothing is provided use
# nested autodiff as the last resort
if SciMLBase.has_vjp(prob.f)
if SciMLBase.isinplace(prob)
vjp_fn = @closure (du, u, p) -> begin
resid = Utils.safe_similar(du, length(sol.resid))
prob.f(resid, u, p)
prob.f.vjp(du, resid, u, p)
du .*= 2
return nothing
end
else
vjp_fn = @closure (u, p) -> begin
resid = prob.f(u, p)
return reshape(2 .* prob.f.vjp(resid, u, p), size(u))
end
end
elseif SciMLBase.has_jac(prob.f)
if SciMLBase.isinplace(prob)
vjp_fn = @closure (du, u, p) -> begin
J = Utils.safe_similar(du, length(sol.resid), length(u))
prob.f.jac(J, u, p)
resid = Utils.safe_similar(du, length(sol.resid))
prob.f(resid, u, p)
mul!(reshape(du, 1, :), vec(resid)', J, 2, false)
return nothing
end
else
vjp_fn = @closure (u, p) -> begin
return reshape(2 .* vec(prob.f(u, p))' * prob.f.jac(u, p), size(u))
end
end
else
# For small problems, nesting ForwardDiff is actually quite fast
autodiff = length(uu) + length(sol.resid) ≤ 50 ?
NonlinearSolveBase.select_reverse_mode_autodiff(prob, nothing) :
AutoForwardDiff()

if SciMLBase.isinplace(prob)
vjp_fn = @closure (du, u, p) -> begin
resid = Utils.safe_similar(du, length(sol.resid))
prob.f(resid, u, p)
# Using `Constant` leads to dual ordering issues
ff = @closure (du, u) -> prob.f(du, u, p)
resid2 = copy(resid)
DI.pullback!(ff, resid2, (du,), autodiff, u, (resid,))
@. du *= 2
return nothing
end
else
vjp_fn = @closure (u, p) -> begin
v = prob.f(u, p)
# Using `Constant` leads to dual ordering issues
ff = Base.Fix2(prob.f, p)
res = only(DI.pullback(ff, autodiff, u, (v,)))
ArrayInterface.can_setindex(res) || return 2 .* res
@. res *= 2
return res
end
end
end
fn = prob isa NonlinearLeastSquaresProblem ?
NonlinearSolveBase.nlls_generate_vjp_function(prob, sol, uu) : prob.f

Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, vjp_fn, uu, newprob.p)
Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, vjp_fn, uu, newprob.p)
Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, fn, uu, p)
Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, fn, uu, p)
z = -Jᵤ \ Jₚ
pp = prob.p
sumfun = ((z, p),) -> map(Base.Fix2(*, ForwardDiff.partials(p)), z)
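
For reference, the sensitivity computation that both the root-finding and least-squares paths share comes from implicit differentiation: if u(p) satisfies F(u, p) = 0, then Jᵤ · du/dp + Jₚ = 0, hence du/dp = -Jᵤ \ Jₚ, the z computed above. Each ForwardDiff partial of the solution is then z contracted with the partials stored in prob.p, which is what sumfun accumulates (columnwise when u is an array). As a one-dimensional check, F(u, p) = u² - p gives du/dp = -(2u)⁻¹ · (-1) = 1/(2√p).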
4 changes: 2 additions & 2 deletions lib/NonlinearSolveBase/src/NonlinearSolveBase.jl
@@ -5,7 +5,7 @@ using ConcreteStructs: @concrete
using FastClosures: @closure
using Preferences: @load_preference, @set_preferences!

using ADTypes: ADTypes, AbstractADType, AutoSparse, NoSparsityDetector,
using ADTypes: ADTypes, AbstractADType, AutoSparse, AutoForwardDiff, NoSparsityDetector,
KnownJacobianSparsityDetector
using Adapt: WrappedArray
using ArrayInterface: ArrayInterface
@@ -25,7 +25,7 @@ using SciMLJacobianOperators: JacobianOperator, StatefulJacobianOperator
using SciMLOperators: AbstractSciMLOperator, IdentityOperator
using SymbolicIndexingInterface: SymbolicIndexingInterface

using LinearAlgebra: LinearAlgebra, Diagonal, norm, ldiv!, diagind
using LinearAlgebra: LinearAlgebra, Diagonal, norm, ldiv!, diagind, mul!
using Markdown: @doc_str
using Printf: @printf

62 changes: 62 additions & 0 deletions lib/NonlinearSolveBase/src/autodiff.jl
@@ -128,3 +128,65 @@ end
is_finite_differences_backend(ad::AbstractADType) = false
is_finite_differences_backend(::ADTypes.AutoFiniteDiff) = true
is_finite_differences_backend(::ADTypes.AutoFiniteDifferences) = true

function nlls_generate_vjp_function(prob::NonlinearLeastSquaresProblem, sol, uu)
# First check for custom `vjp` then custom `Jacobian` and if nothing is provided use
# nested autodiff as the last resort
if SciMLBase.has_vjp(prob.f)
if SciMLBase.isinplace(prob)
return @closure (du, u, p) -> begin
resid = Utils.safe_similar(du, length(sol.resid))
prob.f(resid, u, p)
prob.f.vjp(du, resid, u, p)
du .*= 2
return nothing
end
else
return @closure (u, p) -> begin
resid = prob.f(u, p)
return reshape(2 .* prob.f.vjp(resid, u, p), size(u))
end
end
elseif SciMLBase.has_jac(prob.f)
if SciMLBase.isinplace(prob)
return @closure (du, u, p) -> begin
J = Utils.safe_similar(du, length(sol.resid), length(u))
prob.f.jac(J, u, p)
resid = Utils.safe_similar(du, length(sol.resid))
prob.f(resid, u, p)
mul!(reshape(du, 1, :), vec(resid)', J, 2, false)
return nothing
end
else
return @closure (u, p) -> begin
return reshape(2 .* vec(prob.f(u, p))' * prob.f.jac(u, p), size(u))
end
end
else
# For small problems, nesting ForwardDiff is actually quite fast
autodiff = length(uu) + length(sol.resid) ≤ 50 ?
select_reverse_mode_autodiff(prob, nothing) : AutoForwardDiff()

if SciMLBase.isinplace(prob)
return @closure (du, u, p) -> begin
resid = Utils.safe_similar(du, length(sol.resid))
prob.f(resid, u, p)
# Using `Constant` leads to dual ordering issues
ff = @closure (du, u) -> prob.f(du, u, p)
resid2 = copy(resid)
DI.pullback!(ff, resid2, (du,), autodiff, u, (resid,))
@. du *= 2
return nothing
end
else
return @closure (u, p) -> begin
v = prob.f(u, p)
# Using `Constant` leads to dual ordering issues
res = only(DI.pullback(Base.Fix2(prob.f, p), autodiff, u, (v,)))
ArrayInterface.can_setindex(res) || return 2 .* res
@. res *= 2
return res
end
end
end
end
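
The factor of 2 in every branch reflects that, for a least-squares problem, the function handed to nonlinearsolve_∂f_∂u / nonlinearsolve_∂f_∂p is not the residual r(u, p) itself but the gradient of the sum-of-squares objective: ∇ᵤ ‖r‖² = 2 Jᵀ r, with J = ∂r/∂u. The three branches build this same quantity from whatever is available: a user-supplied vjp, a user-supplied Jacobian (forming rᵀ J), or nested autodiff as the fallback, using forward mode when length(u) + length(resid) ≤ 50 and reverse mode otherwise.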
1 change: 1 addition & 0 deletions lib/NonlinearSolveBase/src/public.jl
@@ -10,6 +10,7 @@ function nonlinearsolve_forwarddiff_solve end
function nonlinearsolve_dual_solution end
function nonlinearsolve_∂f_∂p end
function nonlinearsolve_∂f_∂u end
function nlls_generate_vjp_function end

# Nonlinear Solve Termination Conditions
abstract type AbstractNonlinearTerminationMode end
12 changes: 7 additions & 5 deletions src/forward_diff.jl
@@ -48,9 +48,8 @@ function InternalAPI.reinit!(
end

for algType in ALL_SOLVER_TYPES
# XXX: Extend to DualNonlinearLeastSquaresProblem
@eval function SciMLBase.__init(
prob::DualNonlinearProblem, alg::$(algType), args...; kwargs...
prob::DualAbstractNonlinearProblem, alg::$(algType), args...; kwargs...
)
p = nodual_value(prob.p)
newprob = SciMLBase.remake(prob; u0 = nodual_value(prob.u0), p)
@@ -64,10 +63,13 @@
function CommonSolve.solve!(cache::NonlinearSolveForwardDiffCache)
sol = solve!(cache.cache)
prob = cache.prob

uu = sol.u
Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, prob.f, uu, cache.values_p)
Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, prob.f, uu, cache.values_p)

fn = prob isa NonlinearLeastSquaresProblem ?
NonlinearSolveBase.nlls_generate_vjp_function(prob, sol, uu) : prob.f

Jₚ = NonlinearSolveBase.nonlinearsolve_∂f_∂p(prob, fn, uu, cache.values_p)
Jᵤ = NonlinearSolveBase.nonlinearsolve_∂f_∂u(prob, fn, uu, cache.values_p)

z_arr = -Jᵤ \ Jₚ

44 changes: 42 additions & 2 deletions test/forward_ad_tests.jl
@@ -171,10 +171,50 @@ end
grad1 = ForwardDiff.gradient(solve_nlprob, [34.0, 87.0])
grad2 = FiniteDiff.finite_difference_gradient(solve_nlprob, [34.0, 87.0])

@test grad1 ≈ grad2 atol = 1e-3
@test grad1≈grad2 atol=1e-3

hess1 = ForwardDiff.hessian(solve_nlprob, [34.0, 87.0])
hess2 = FiniteDiff.finite_difference_hessian(solve_nlprob, [34.0, 87.0])

@test hess1 ≈ hess2 atol = 1e-3
@test hess1≈hess2 atol=1e-3

function solve_nlprob_with_cache(pxpy)
px, py = pxpy
theta1 = pi / 4
theta2 = pi / 4
initial_guess = [theta1; theta2]
l1 = 60
l2 = 60
p = [px; py; l1; l2]
prob = NonlinearLeastSquaresProblem(
NonlinearFunction(objfn, resid_prototype = zeros(2)),
initial_guess, p
)
cache = init(prob; reltol = 1e-12, abstol = 1e-12)
resu = solve!(cache)
th1, th2 = resu.u
cable1_base = [-90; 0; 0]
cable2_base = [-150; 0; 0]
cable3_base = [150; 0; 0]
cable1_top = [l1 * cos(th1) / 2; l1 * sin(th1) / 2; 0]
cable23_top = [l1 * cos(th1) + l2 * cos(th1 + th2) / 2;
l1 * sin(th1) + l2 * sin(th1 + th2) / 2; 0]
c1_length = sqrt((cable1_top[1] - cable1_base[1])^2 +
(cable1_top[2] - cable1_base[2])^2)
c2_length = sqrt((cable23_top[1] - cable2_base[1])^2 +
(cable23_top[2] - cable2_base[2])^2)
c3_length = sqrt((cable23_top[1] - cable3_base[1])^2 +
(cable23_top[2] - cable3_base[2])^2)
return c1_length + c2_length + c3_length
end

grad1 = ForwardDiff.gradient(solve_nlprob_with_cache, [34.0, 87.0])
grad2 = FiniteDiff.finite_difference_gradient(solve_nlprob_with_cache, [34.0, 87.0])

@test grad1≈grad2 atol=1e-3

hess1 = ForwardDiff.hessian(solve_nlprob_with_cache, [34.0, 87.0])
hess2 = FiniteDiff.finite_difference_hessian(solve_nlprob_with_cache, [34.0, 87.0])

@test hess1≈hess2 atol=1e-3
end
