Skip to content
This repository was archived by the owner on May 15, 2025. It is now read-only.

Commit 0f03f75

Browse files
committed
Clean up AD
1 parent 24f83d8 commit 0f03f75

File tree

5 files changed

+67
-77
lines changed

5 files changed

+67
-77
lines changed

src/SimpleNonlinearSolve.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,10 @@ using PrecompileTools: @compile_workload, @setup_workload, @recompile_invalidati
2020
mul!, norm, transpose
2121
using MaybeInplace: @bb, setindex_trait, CanSetindex, CannotSetindex
2222
using Reexport: @reexport
23-
using SciMLBase: SciMLBase, IntervalNonlinearProblem, NonlinearFunction,
24-
NonlinearLeastSquaresProblem, NonlinearProblem, ReturnCode, init,
25-
remake, solve, AbstractNonlinearAlgorithm, build_solution, isinplace,
26-
_unwrap_val
23+
using SciMLBase: SciMLBase, AbstractNonlinearProblem, IntervalNonlinearProblem,
24+
NonlinearFunction, NonlinearLeastSquaresProblem, NonlinearProblem,
25+
ReturnCode, init, remake, solve, AbstractNonlinearAlgorithm,
26+
build_solution, isinplace, _unwrap_val
2727
using StaticArraysCore: StaticArray, SVector, SMatrix, SArray, MArray, Size
2828
end
2929

src/ad.jl

Lines changed: 34 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,15 @@
1-
function SciMLBase.solve(
2-
prob::NonlinearProblem{<:Union{Number, <:AbstractArray}, iip,
3-
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
4-
alg::AbstractSimpleNonlinearSolveAlgorithm,
5-
args...;
6-
kwargs...) where {T, V, P, iip}
7-
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
8-
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
9-
return SciMLBase.build_solution(
10-
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
11-
end
12-
13-
function SciMLBase.solve(
14-
prob::NonlinearLeastSquaresProblem{
15-
<:AbstractArray, iip, <:Union{<:AbstractArray{<:Dual{T, V, P}}}},
16-
alg::AbstractSimpleNonlinearSolveAlgorithm,
17-
args...;
18-
kwargs...) where {T, V, P, iip}
19-
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
20-
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
21-
return SciMLBase.build_solution(
22-
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
1+
for pType in (NonlinearProblem, NonlinearLeastSquaresProblem)
2+
@eval function SciMLBase.solve(
3+
prob::$(pType){<:Union{Number, <:AbstractArray}, iip,
4+
<:Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}}},
5+
alg::AbstractSimpleNonlinearSolveAlgorithm,
6+
args...;
7+
kwargs...) where {T, V, P, iip}
8+
sol, partials = __nlsolve_ad(prob, alg, args...; kwargs...)
9+
dual_soln = __nlsolve_dual_soln(sol.u, partials, prob.p)
10+
return SciMLBase.build_solution(
11+
prob, alg, dual_soln, sol.resid; sol.retcode, sol.stats, sol.original)
12+
end
2313
end
2414

2515
for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
@@ -47,8 +37,7 @@ function __nlsolve_ad(
4737
tspan = value.(prob.tspan)
4838
newprob = IntervalNonlinearProblem(prob.f, tspan, p; prob.kwargs...)
4939
else
50-
u0 = value(prob.u0)
51-
newprob = NonlinearProblem(prob.f, u0, p; prob.kwargs...)
40+
newprob = remake(prob; p, u0=value(prob.u0))
5241
end
5342

5443
sol = solve(newprob, alg, args...; kwargs...)
@@ -73,20 +62,16 @@ function __nlsolve_ad(
7362
end
7463

7564
function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs...)
76-
p = value(prob.p)
77-
u0 = value(prob.u0)
78-
newprob = NonlinearLeastSquaresProblem(prob.f, u0, p; prob.kwargs...)
79-
65+
newprob = remake(prob; p=value(prob.p), u0=value(prob.u0))
8066
sol = solve(newprob, alg, args...; kwargs...)
81-
8267
uu = sol.u
8368

8469
# First check for custom `vjp` then custom `Jacobian` and if nothing is provided use
8570
# nested autodiff as the last resort
8671
if SciMLBase.has_vjp(prob.f)
8772
if isinplace(prob)
8873
_F = @closure (du, u, p) -> begin
89-
resid = similar(du, length(sol.resid))
74+
resid = __similar(du, length(sol.resid))
9075
prob.f(resid, u, p)
9176
prob.f.vjp(du, resid, u, p)
9277
du .*= 2
@@ -101,9 +86,9 @@ function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs..
10186
elseif SciMLBase.has_jac(prob.f)
10287
if isinplace(prob)
10388
_F = @closure (du, u, p) -> begin
104-
J = similar(du, length(sol.resid), length(u))
89+
J = __similar(du, length(sol.resid), length(u))
10590
prob.f.jac(J, u, p)
106-
resid = similar(du, length(sol.resid))
91+
resid = __similar(du, length(sol.resid))
10792
prob.f(resid, u, p)
10893
mul!(reshape(du, 1, :), vec(resid)', J, 2, false)
10994
return nothing
@@ -116,43 +101,38 @@ function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs..
116101
else
117102
if isinplace(prob)
118103
_F = @closure (du, u, p) -> begin
119-
resid = similar(du, length(sol.resid))
120-
res = DiffResults.DiffResult(
121-
resid, similar(du, length(sol.resid), length(u)))
122104
_f = @closure (du, u) -> prob.f(du, u, p)
123-
ForwardDiff.jacobian!(res, _f, resid, u)
124-
mul!(reshape(du, 1, :), vec(DiffResults.value(res))',
125-
DiffResults.jacobian(res), 2, false)
105+
resid = __similar(du, length(sol.resid))
106+
v, J = DI.value_and_jacobian(_f, resid, AutoForwardDiff(), u)
107+
mul!(reshape(du, 1, :), vec(v)', J, 2, false)
126108
return nothing
127109
end
128110
else
129111
# For small problems, nesting ForwardDiff is actually quite fast
112+
_f = Base.Fix2(prob.f, newprob.p)
130113
if __is_extension_loaded(Val(:Zygote)) && (length(uu) + length(sol.resid) ≥ 50)
131-
_F = @closure (u, p) -> __zygote_compute_nlls_vjp(prob.f, u, p)
114+
# TODO: Remove once DI has the value_and_pullback_split defined
115+
_F = @closure (u, p) -> __zygote_compute_nlls_vjp(_f, u, p)
132116
else
133117
_F = @closure (u, p) -> begin
134-
T = promote_type(eltype(u), eltype(p))
135-
res = DiffResults.DiffResult(similar(u, T, size(sol.resid)),
136-
similar(u, T, length(sol.resid), length(u)))
137-
ForwardDiff.jacobian!(res, Base.Fix2(prob.f, p), u)
138-
return reshape(
139-
2 .* vec(DiffResults.value(res))' * DiffResults.jacobian(res),
140-
size(u))
118+
_f = Base.Fix2(prob.f, p)
119+
v, J = DI.value_and_jacobian(_f, AutoForwardDiff(), u)
120+
return reshape(2 .* vec(v)' * J, size(u))
141121
end
142122
end
143123
end
144124
end
145125

146-
f_p = __nlsolve_∂f_∂p(prob, _F, uu, p)
147-
f_x = __nlsolve_∂f_∂u(prob, _F, uu, p)
126+
f_p = __nlsolve_∂f_∂p(prob, _F, uu, newprob.p)
127+
f_x = __nlsolve_∂f_∂u(prob, _F, uu, newprob.p)
148128

149129
z_arr = -f_x \ f_p
150130

151131
pp = prob.p
152132
sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z)
153133
if uu isa Number
154134
partials = sum(sumfun, zip(z_arr, pp))
155-
elseif p isa Number
135+
elseif pp isa Number
156136
partials = sumfun((z_arr, pp))
157137
else
158138
partials = sum(sumfun, zip(eachcol(z_arr), pp))
@@ -164,7 +144,7 @@ end
164144
@inline function __nlsolve_∂f_∂p(prob, f::F, u, p) where {F}
165145
if isinplace(prob)
166146
__f = p -> begin
167-
du = similar(u, promote_type(eltype(u), eltype(p)))
147+
du = __similar(u, promote_type(eltype(u), eltype(p)))
168148
f(du, u, p)
169149
return du
170150
end
@@ -182,16 +162,12 @@ end
182162

183163
@inline function __nlsolve_∂f_∂u(prob, f::F, u, p) where {F}
184164
if isinplace(prob)
185-
du = similar(u)
186-
__f = (du, u) -> f(du, u, p)
187-
ForwardDiff.jacobian(__f, du, u)
165+
__f = @closure (du, u) -> f(du, u, p)
166+
return ForwardDiff.jacobian(__f, __similar(u), u)
188167
else
189168
__f = Base.Fix2(f, p)
190-
if u isa Number
191-
return ForwardDiff.derivative(__f, u)
192-
else
193-
return ForwardDiff.jacobian(__f, u)
194-
end
169+
u isa Number && return ForwardDiff.derivative(__f, u)
170+
return ForwardDiff.jacobian(__f, u)
195171
end
196172
end
197173

src/nlsolve/halley.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,9 @@ function SciMLBase.__solve(prob::NonlinearProblem, alg::SimpleHalley, args...;
3939
@bb xo = copy(x)
4040

4141
if setindex_trait(x) === CanSetindex()
42-
A = similar(x, length(x), length(x))
43-
Aaᵢ = similar(x, length(x))
44-
cᵢ = similar(x)
42+
A = __similar(x, length(x), length(x))
43+
Aaᵢ = __similar(x, length(x))
44+
cᵢ = __similar(x)
4545
else
4646
A = x
4747
Aaᵢ = x

src/nlsolve/lbroyden.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ end
272272
return :(return SVector{$N, $T}(($(getcalls...))))
273273
end
274274

275-
__lbroyden_threshold_cache(x, ::Val{threshold}) where {threshold} = similar(x, threshold)
275+
__lbroyden_threshold_cache(x, ::Val{threshold}) where {threshold} = __similar(x, threshold)
276276
function __lbroyden_threshold_cache(x::StaticArray, ::Val{threshold}) where {threshold}
277277
return zeros(MArray{Tuple{threshold}, eltype(x)})
278278
end
@@ -298,7 +298,7 @@ end
298298
end
299299
end
300300
function __init_low_rank_jacobian(u, fu, ::Val{threshold}) where {threshold}
301-
Vᵀ = similar(u, threshold, length(u))
302-
U = similar(u, length(fu), threshold)
301+
Vᵀ = __similar(u, threshold, length(u))
302+
U = __similar(u, length(fu), threshold)
303303
return U, Vᵀ
304304
end

src/utils.jl

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,13 @@ Return the maximum of `a` and `b` if `x1 > x0`, otherwise return the minimum.
2121
"""
2222
__max_tdir(a, b, x0, x1) = ifelse(x1 > x0, max(a, b), min(a, b))
2323

24-
function __fixed_parameter_function(prob::NonlinearProblem)
24+
function __fixed_parameter_function(prob::AbstractNonlinearProblem)
2525
isinplace(prob) && return @closure (du, u) -> prob.f(du, u, prob.p)
2626
return Base.Fix2(prob.f, prob.p)
2727
end
2828

2929
function value_and_jacobian(
30-
ad, prob::NonlinearProblem, f::F, y, x, cache; J = nothing) where {F}
30+
ad, prob::AbstractNonlinearProblem, f::F, y, x, cache; J = nothing) where {F}
3131
x isa Number && return DI.value_and_derivative(f, ad, x, cache)
3232

3333
if isinplace(prob)
@@ -46,29 +46,30 @@ function value_and_jacobian(
4646
end
4747
end
4848

49-
function jacobian_cache(ad, prob::NonlinearProblem, f::F, y, x) where {F}
49+
function jacobian_cache(ad, prob::AbstractNonlinearProblem, f::F, y, x) where {F}
5050
x isa Number && return (nothing, DI.prepare_derivative(f, ad, x))
5151

5252
if isinplace(prob)
53-
J = similar(y, length(y), length(x))
53+
J = __similar(y, length(y), length(x))
5454
SciMLBase.has_jac(prob.f) && return J, HasAnalyticJacobian()
5555
return J, DI.prepare_jacobian(f, y, ad, x)
5656
else
5757
SciMLBase.has_jac(prob.f) && return nothing, HasAnalyticJacobian()
58-
J = ArrayInterface.can_setindex(x) ? similar(y, length(y), length(x)) : nothing
58+
J = ArrayInterface.can_setindex(x) ? __similar(y, length(y), length(x)) : nothing
5959
return J, DI.prepare_jacobian(f, ad, x)
6060
end
6161
end
6262

63-
function compute_jacobian_and_hessian(ad, prob::NonlinearProblem, f::F, y, x) where {F}
63+
function compute_jacobian_and_hessian(
64+
ad, prob::AbstractNonlinearProblem, f::F, y, x) where {F}
6465
if x isa Number
6566
df = @closure x -> DI.derivative(f, ad, x)
6667
return f(x), df(x), DI.derivative(df, ad, x)
6768
end
6869

6970
if isinplace(prob)
7071
df = @closure x -> begin
71-
res = similar(y, promote_type(eltype(y), eltype(x)))
72+
res = __similar(y, promote_type(eltype(y), eltype(x)))
7273
return DI.jacobian(f, res, ad, x)
7374
end
7475
J, H = DI.value_and_jacobian(df, ad, x)
@@ -83,7 +84,7 @@ end
8384
__init_identity_jacobian(u::Number, fu, α = true) = oftype(u, α)
8485
__init_identity_jacobian!!(J::Number) = one(J)
8586
function __init_identity_jacobian(u, fu, α = true)
86-
J = similar(u, promote_type(eltype(u), eltype(fu)), length(fu), length(u))
87+
J = __similar(u, promote_type(eltype(u), eltype(fu)), length(fu), length(u))
8788
fill!(J, zero(eltype(J)))
8889
J[diagind(J)] .= eltype(J)(α)
8990
return J
@@ -129,7 +130,7 @@ end
129130
T = eltype(x)
130131
return T.(f.resid_prototype)
131132
else
132-
fx = similar(x)
133+
fx = __similar(x)
133134
f(fx, x, p)
134135
return fx
135136
end
@@ -242,3 +243,16 @@ end
242243

243244
# Extension
244245
function __zygote_compute_nlls_vjp end
246+
247+
function __similar(x, args...; kwargs...)
248+
y = similar(x, args...; kwargs...)
249+
return __init_bigfloat_array!!(y)
250+
end
251+
252+
function __init_bigfloat_array!!(x)
253+
if ArrayInterface.can_setindex(x)
254+
eltype(x) <: BigFloat && fill!(x, BigFloat(0))
255+
return x
256+
end
257+
return x
258+
end

0 commit comments

Comments
 (0)