1
- function SciMLBase. solve (
2
- prob:: NonlinearProblem {<: Union{Number, <:AbstractArray} , iip,
3
- <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} },
4
- alg:: AbstractSimpleNonlinearSolveAlgorithm ,
5
- args... ;
6
- kwargs... ) where {T, V, P, iip}
7
- sol, partials = __nlsolve_ad (prob, alg, args... ; kwargs... )
8
- dual_soln = __nlsolve_dual_soln (sol. u, partials, prob. p)
9
- return SciMLBase. build_solution (
10
- prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original)
11
- end
12
-
13
- function SciMLBase. solve (
14
- prob:: NonlinearLeastSquaresProblem {
15
- <: AbstractArray , iip, <: Union{<:AbstractArray{<:Dual{T, V, P}}} },
16
- alg:: AbstractSimpleNonlinearSolveAlgorithm ,
17
- args... ;
18
- kwargs... ) where {T, V, P, iip}
19
- sol, partials = __nlsolve_ad (prob, alg, args... ; kwargs... )
20
- dual_soln = __nlsolve_dual_soln (sol. u, partials, prob. p)
21
- return SciMLBase. build_solution (
22
- prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original)
1
+ for pType in (NonlinearProblem, NonlinearLeastSquaresProblem)
2
+ @eval function SciMLBase. solve (
3
+ prob:: $ (pType){<: Union{Number, <:AbstractArray} , iip,
4
+ <: Union{<:Dual{T, V, P}, <:AbstractArray{<:Dual{T, V, P}}} },
5
+ alg:: AbstractSimpleNonlinearSolveAlgorithm ,
6
+ args... ;
7
+ kwargs... ) where {T, V, P, iip}
8
+ sol, partials = __nlsolve_ad (prob, alg, args... ; kwargs... )
9
+ dual_soln = __nlsolve_dual_soln (sol. u, partials, prob. p)
10
+ return SciMLBase. build_solution (
11
+ prob, alg, dual_soln, sol. resid; sol. retcode, sol. stats, sol. original)
12
+ end
23
13
end
24
14
25
15
for algType in (Bisection, Brent, Alefeld, Falsi, ITP, Ridder)
@@ -47,8 +37,7 @@ function __nlsolve_ad(
47
37
tspan = value .(prob. tspan)
48
38
newprob = IntervalNonlinearProblem (prob. f, tspan, p; prob. kwargs... )
49
39
else
50
- u0 = value (prob. u0)
51
- newprob = NonlinearProblem (prob. f, u0, p; prob. kwargs... )
40
+ newprob = remake (prob; p, u0= value (prob. u0))
52
41
end
53
42
54
43
sol = solve (newprob, alg, args... ; kwargs... )
@@ -73,20 +62,16 @@ function __nlsolve_ad(
73
62
end
74
63
75
64
function __nlsolve_ad (prob:: NonlinearLeastSquaresProblem , alg, args... ; kwargs... )
76
- p = value (prob. p)
77
- u0 = value (prob. u0)
78
- newprob = NonlinearLeastSquaresProblem (prob. f, u0, p; prob. kwargs... )
79
-
65
+ newprob = remake (prob; p= value (prob. p), u0= value (prob. u0))
80
66
sol = solve (newprob, alg, args... ; kwargs... )
81
-
82
67
uu = sol. u
83
68
84
69
# First check for custom `vjp` then custom `Jacobian` and if nothing is provided use
85
70
# nested autodiff as the last resort
86
71
if SciMLBase. has_vjp (prob. f)
87
72
if isinplace (prob)
88
73
_F = @closure (du, u, p) -> begin
89
- resid = similar (du, length (sol. resid))
74
+ resid = __similar (du, length (sol. resid))
90
75
prob. f (resid, u, p)
91
76
prob. f. vjp (du, resid, u, p)
92
77
du .*= 2
@@ -101,9 +86,9 @@ function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs..
101
86
elseif SciMLBase. has_jac (prob. f)
102
87
if isinplace (prob)
103
88
_F = @closure (du, u, p) -> begin
104
- J = similar (du, length (sol. resid), length (u))
89
+ J = __similar (du, length (sol. resid), length (u))
105
90
prob. f. jac (J, u, p)
106
- resid = similar (du, length (sol. resid))
91
+ resid = __similar (du, length (sol. resid))
107
92
prob. f (resid, u, p)
108
93
mul! (reshape (du, 1 , :), vec (resid)' , J, 2 , false )
109
94
return nothing
@@ -116,43 +101,38 @@ function __nlsolve_ad(prob::NonlinearLeastSquaresProblem, alg, args...; kwargs..
116
101
else
117
102
if isinplace (prob)
118
103
_F = @closure (du, u, p) -> begin
119
- resid = similar (du, length (sol. resid))
120
- res = DiffResults. DiffResult (
121
- resid, similar (du, length (sol. resid), length (u)))
122
104
_f = @closure (du, u) -> prob. f (du, u, p)
123
- ForwardDiff . jacobian! (res, _f, resid, u )
124
- mul! ( reshape (du, 1 , :), vec (DiffResults . value (res)) ' ,
125
- DiffResults . jacobian (res) , 2 , false )
105
+ resid = __similar (du, length (sol . resid) )
106
+ v, J = DI . value_and_jacobian (_f, resid, AutoForwardDiff (), u)
107
+ mul! ( reshape (du, 1 , :), vec (v) ' , J , 2 , false )
126
108
return nothing
127
109
end
128
110
else
129
111
# For small problems, nesting ForwardDiff is actually quite fast
112
+ _f = Base. Fix2 (prob. f, newprob. p)
130
113
if __is_extension_loaded (Val (:Zygote )) && (length (uu) + length (sol. resid) ≥ 50 )
131
- _F = @closure (u, p) -> __zygote_compute_nlls_vjp (prob. f, u, p)
114
+ # TODO : Remove once DI has the value_and_pullback_split defined
115
+ _F = @closure (u, p) -> __zygote_compute_nlls_vjp (_f, u, p)
132
116
else
133
117
_F = @closure (u, p) -> begin
134
- T = promote_type (eltype (u), eltype (p))
135
- res = DiffResults. DiffResult (similar (u, T, size (sol. resid)),
136
- similar (u, T, length (sol. resid), length (u)))
137
- ForwardDiff. jacobian! (res, Base. Fix2 (prob. f, p), u)
138
- return reshape (
139
- 2 .* vec (DiffResults. value (res))' * DiffResults. jacobian (res),
140
- size (u))
118
+ _f = Base. Fix2 (prob. f, p)
119
+ v, J = DI. value_and_jacobian (_f, AutoForwardDiff (), u)
120
+ return reshape (2 .* vec (v)' * J, size (u))
141
121
end
142
122
end
143
123
end
144
124
end
145
125
146
- f_p = __nlsolve_∂f_∂p (prob, _F, uu, p)
147
- f_x = __nlsolve_∂f_∂u (prob, _F, uu, p)
126
+ f_p = __nlsolve_∂f_∂p (prob, _F, uu, newprob . p)
127
+ f_x = __nlsolve_∂f_∂u (prob, _F, uu, newprob . p)
148
128
149
129
z_arr = - f_x \ f_p
150
130
151
131
pp = prob. p
152
132
sumfun = ((z, p),) -> map (zᵢ -> zᵢ * ForwardDiff. partials (p), z)
153
133
if uu isa Number
154
134
partials = sum (sumfun, zip (z_arr, pp))
155
- elseif p isa Number
135
+ elseif pp isa Number
156
136
partials = sumfun ((z_arr, pp))
157
137
else
158
138
partials = sum (sumfun, zip (eachcol (z_arr), pp))
164
144
@inline function __nlsolve_∂f_∂p (prob, f:: F , u, p) where {F}
165
145
if isinplace (prob)
166
146
__f = p -> begin
167
- du = similar (u, promote_type (eltype (u), eltype (p)))
147
+ du = __similar (u, promote_type (eltype (u), eltype (p)))
168
148
f (du, u, p)
169
149
return du
170
150
end
@@ -182,16 +162,12 @@ end
182
162
183
163
@inline function __nlsolve_∂f_∂u (prob, f:: F , u, p) where {F}
184
164
if isinplace (prob)
185
- du = similar (u)
186
- __f = (du, u) -> f (du, u, p)
187
- ForwardDiff. jacobian (__f, du, u)
165
+ __f = @closure (du, u) -> f (du, u, p)
166
+ return ForwardDiff. jacobian (__f, __similar (u), u)
188
167
else
189
168
__f = Base. Fix2 (f, p)
190
- if u isa Number
191
- return ForwardDiff. derivative (__f, u)
192
- else
193
- return ForwardDiff. jacobian (__f, u)
194
- end
169
+ u isa Number && return ForwardDiff. derivative (__f, u)
170
+ return ForwardDiff. jacobian (__f, u)
195
171
end
196
172
end
197
173
0 commit comments