Skip to content

Commit b8bed15

Browse files
Merge pull request #28 from utkarsh530/u/buildsolution
Add SciML.build_solution
2 parents f232c8f + f081339 commit b8bed15

File tree

5 files changed

+38
-71
lines changed

5 files changed

+38
-71
lines changed

src/scalar.jl

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
function SciMLBase.solve(prob::NonlinearProblem{<:Number}, alg::NewtonRaphson, args...; xatol = nothing, xrtol = nothing, maxiters = 1000, kwargs...)
22
f = Base.Fix2(prob.f, prob.p)
33
x = float(prob.u0)
4+
fx = float(prob.u0)
45
T = typeof(x)
56
atol = xatol !== nothing ? xatol : oneunit(T) * (eps(one(T)))^(4//5)
67
rtol = xrtol !== nothing ? xrtol : eps(one(T))^(4//5)
@@ -13,15 +14,15 @@ function SciMLBase.solve(prob::NonlinearProblem{<:Number}, alg::NewtonRaphson, a
1314
fx = f(x)
1415
dfx = FiniteDiff.finite_difference_derivative(f, x, alg.diff_type, eltype(x), fx)
1516
end
16-
iszero(fx) && return NewtonSolution(x, DEFAULT)
17+
iszero(fx) && return SciMLBase.build_solution(prob, alg, x, fx; retcode=Symbol(DEFAULT))
1718
Δx = dfx \ fx
1819
x -= Δx
1920
if isapprox(x, xo, atol=atol, rtol=rtol)
20-
return NewtonSolution(x, DEFAULT)
21+
return SciMLBase.build_solution(prob, alg, x, fx; retcode=Symbol(DEFAULT))
2122
end
2223
xo = x
2324
end
24-
return NewtonSolution(x, MAXITERS_EXCEED)
25+
return SciMLBase.build_solution(prob, alg, x, fx; retcode=Symbol(MAXITERS_EXCEED))
2526
end
2627

2728
function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
@@ -32,7 +33,7 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
3233
newprob = NonlinearProblem(f, u0, p; prob.kwargs...)
3334
sol = solve(newprob, alg, args...; kwargs...)
3435

35-
uu = getsolution(sol)
36+
uu = sol.u
3637
if p isa Number
3738
f_p = ForwardDiff.derivative(Base.Fix1(f, uu), p)
3839
else
@@ -50,39 +51,42 @@ end
5051

5152
function SciMLBase.solve(prob::NonlinearProblem{<:Number, iip, <:Dual{T,V,P}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
5253
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
53-
return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode)
54+
return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode)
55+
5456
end
5557
function SciMLBase.solve(prob::NonlinearProblem{<:Number, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::NewtonRaphson, args...; kwargs...) where {iip, T, V, P}
5658
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
57-
return NewtonSolution(Dual{T,V,P}(sol.u, partials), sol.retcode)
59+
return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode)
5860
end
5961

6062
# avoid ambiguities
6163
for Alg in [Bisection]
6264
@eval function SciMLBase.solve(prob::NonlinearProblem{uType, iip, <:Dual{T,V,P}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P}
6365
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
64-
return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode)
66+
return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode,left = Dual{T,V,P}(sol.left, partials), right = Dual{T,V,P}(sol.right, partials))
67+
#return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid)
6568
end
6669
@eval function SciMLBase.solve(prob::NonlinearProblem{uType, iip, <:AbstractArray{<:Dual{T,V,P}}}, alg::$Alg, args...; kwargs...) where {uType, iip, T, V, P}
6770
sol, partials = scalar_nlsolve_ad(prob, alg, args...; kwargs...)
68-
return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode)
71+
return SciMLBase.build_solution(prob, alg, Dual{T,V,P}(sol.u, partials), sol.resid; retcode=sol.retcode,left = Dual{T,V,P}(sol.left, partials), right = Dual{T,V,P}(sol.right, partials))
72+
#return BracketingSolution(Dual{T,V,P}(sol.left, partials), Dual{T,V,P}(sol.right, partials), sol.retcode, sol.resid)
6973
end
7074
end
7175

72-
function SciMLBase.solve(prob::NonlinearProblem, ::Bisection, args...; maxiters = 1000, kwargs...)
76+
function SciMLBase.solve(prob::NonlinearProblem, alg::Bisection, args...; maxiters = 1000, kwargs...)
7377
f = Base.Fix2(prob.f, prob.p)
7478
left, right = prob.u0
7579
fl, fr = f(left), f(right)
7680

7781
if iszero(fl)
78-
return BracketingSolution(left, right, EXACT_SOLUTION_LEFT)
82+
return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(EXACT_SOLUTION_LEFT), left = left, right = right)
7983
end
8084

8185
i = 1
8286
if !iszero(fr)
8387
while i < maxiters
8488
mid = (left + right) / 2
85-
(mid == left || mid == right) && return BracketingSolution(left, right, FLOATING_POINT_LIMIT)
89+
(mid == left || mid == right) && return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(FLOATING_POINT_LIMIT), left = left, right = right)
8690
fm = f(mid)
8791
if iszero(fm)
8892
right = mid
@@ -101,7 +105,7 @@ function SciMLBase.solve(prob::NonlinearProblem, ::Bisection, args...; maxiters
101105

102106
while i < maxiters
103107
mid = (left + right) / 2
104-
(mid == left || mid == right) && return BracketingSolution(left, right, FLOATING_POINT_LIMIT)
108+
(mid == left || mid == right) && return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(FLOATING_POINT_LIMIT), left = left, right = right)
105109
fm = f(mid)
106110
if iszero(fm)
107111
right = mid
@@ -113,23 +117,23 @@ function SciMLBase.solve(prob::NonlinearProblem, ::Bisection, args...; maxiters
113117
i += 1
114118
end
115119

116-
return BracketingSolution(left, right, MAXITERS_EXCEED)
120+
return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(MAXITERS_EXCEED), left = left, right = right)
117121
end
118122

119-
function SciMLBase.solve(prob::NonlinearProblem, ::Falsi, args...; maxiters = 1000, kwargs...)
123+
function SciMLBase.solve(prob::NonlinearProblem, alg::Falsi, args...; maxiters = 1000, kwargs...)
120124
f = Base.Fix2(prob.f, prob.p)
121125
left, right = prob.u0
122126
fl, fr = f(left), f(right)
123127

124128
if iszero(fl)
125-
return BracketingSolution(left, right, EXACT_SOLUTION_LEFT)
129+
return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(EXACT_SOLUTION_LEFT), left = left, right = right)
126130
end
127131

128132
i = 1
129133
if !iszero(fr)
130134
while i < maxiters
131135
if nextfloat_tdir(left, prob.u0...) == right
132-
return BracketingSolution(left, right, FLOATING_POINT_LIMIT)
136+
return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(FLOATING_POINT_LIMIT), left = left, right = right)
133137
end
134138
mid = (fr * left - fl * right) / (fr - fl)
135139
for i in 1:10
@@ -156,7 +160,7 @@ function SciMLBase.solve(prob::NonlinearProblem, ::Falsi, args...; maxiters = 10
156160

157161
while i < maxiters
158162
mid = (left + right) / 2
159-
(mid == left || mid == right) && return BracketingSolution(left, right, FLOATING_POINT_LIMIT)
163+
(mid == left || mid == right) && return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(FLOATING_POINT_LIMIT), left = left, right = right)
160164
fm = f(mid)
161165
if iszero(fm)
162166
right = mid
@@ -171,5 +175,5 @@ function SciMLBase.solve(prob::NonlinearProblem, ::Falsi, args...; maxiters = 10
171175
i += 1
172176
end
173177

174-
return BracketingSolution(left, right, MAXITERS_EXCEED)
178+
return SciMLBase.build_solution(prob, alg, left, fl; retcode=Symbol(MAXITERS_EXCEED), left = left, right = right)
175179
end

src/solve.jl

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ function SciMLBase.solve(prob::NonlinearProblem,
33
kwargs...)
44
solver = init(prob, alg, args...; kwargs...)
55
sol = solve!(solver)
6-
return sol
76
end
87

98
function SciMLBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracketingAlgorithm, args...;
@@ -30,7 +29,7 @@ function SciMLBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractBracket
3029
fl = f(left, p)
3130
fr = f(right, p)
3231
cache = alg_cache(alg, left, right,p, Val(iip))
33-
return BracketingImmutableSolver(1, f, alg, left, right, fl, fr, p, false, maxiters, DEFAULT, cache, iip)
32+
return BracketingImmutableSolver(1, f, alg, left, right, fl, fr, p, false, maxiters, DEFAULT, cache, iip,prob)
3433
end
3534

3635
function SciMLBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewtonAlgorithm, args...;
@@ -55,7 +54,7 @@ function SciMLBase.init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewtonA
5554
fu = f(u, p)
5655
end
5756
cache = alg_cache(alg, f, u, p, Val(iip))
58-
return NewtonImmutableSolver(1, f, alg, u, fu, p, false, maxiters, internalnorm, DEFAULT, tol, cache, iip)
57+
return NewtonImmutableSolver(1, f, alg, u, fu, p, false, maxiters, internalnorm, DEFAULT, tol, cache, iip, prob)
5958
end
6059

6160
function SciMLBase.solve!(solver::AbstractImmutableNonlinearSolver)
@@ -67,8 +66,11 @@ function SciMLBase.solve!(solver::AbstractImmutableNonlinearSolver)
6766
if solver.iter == solver.maxiters
6867
@set! solver.retcode = MAXITERS_EXCEED
6968
end
70-
sol = get_solution(solver)
71-
return sol
69+
if typeof(solver) <: NewtonImmutableSolver
70+
SciMLBase.build_solution(solver.prob, solver.alg, solver.u, solver.fu;retcode=Symbol(solver.retcode))
71+
else
72+
SciMLBase.build_solution(solver.prob, solver.alg, solver.left,solver.fl;retcode=Symbol(solver.retcode),left = solver.left,right = solver.right)
73+
end
7274
end
7375

7476
"""
@@ -96,20 +98,6 @@ function mic_check(solver::NewtonImmutableSolver)
9698
solver
9799
end
98100

99-
"""
100-
get_solution(solver::Union{BracketingImmutableSolver, BracketingSolver})
101-
get_solution(solver::Union{NewtonImmutableSolver, NewtonSolver})
102-
103-
Form solution object from solver types
104-
"""
105-
function get_solution(solver::BracketingImmutableSolver)
106-
return BracketingSolution(solver.left, solver.right, solver.retcode)
107-
end
108-
109-
function get_solution(solver::NewtonImmutableSolver)
110-
return NewtonSolution(solver.u, solver.retcode)
111-
end
112-
113101
"""
114102
reinit!(solver, prob)
115103

src/types.jl

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
FLOATING_POINT_LIMIT
77
end
88

9-
struct BracketingImmutableSolver{fType, algType, uType, resType, pType, cacheType} <: AbstractImmutableNonlinearSolver
9+
struct BracketingImmutableSolver{fType, algType, uType, resType, pType, cacheType, probType} <: AbstractImmutableNonlinearSolver
1010
iter::Int
1111
f::fType
1212
alg::algType
@@ -20,14 +20,15 @@ struct BracketingImmutableSolver{fType, algType, uType, resType, pType, cacheTyp
2020
retcode::Retcode
2121
cache::cacheType
2222
iip::Bool
23+
prob::probType
2324
end
2425

2526
# function BracketingImmutableSolver(iip, iter, f, alg, left, right, fl, fr, p, force_stop, maxiters, retcode, cache)
2627
# BracketingImmutableSolver{iip, typeof(f), typeof(alg),
2728
# typeof(left), typeof(fl), typeof(p), typeof(cache)}(iter, f, alg, left, right, fl, fr, p, force_stop, maxiters, retcode, cache)
2829
# end
2930

30-
struct NewtonImmutableSolver{fType, algType, uType, resType, pType, INType, tolType, cacheType} <: AbstractImmutableNonlinearSolver
31+
struct NewtonImmutableSolver{fType, algType, uType, resType, pType, INType, tolType, cacheType, probType} <: AbstractImmutableNonlinearSolver
3132
iter::Int
3233
f::fType
3334
alg::algType
@@ -41,29 +42,17 @@ struct NewtonImmutableSolver{fType, algType, uType, resType, pType, INType, tolT
4142
tol::tolType
4243
cache::cacheType
4344
iip::Bool
45+
prob::probType
4446
end
4547

4648
# function NewtonImmutableSolver{iip}(iter, f, alg, u, fu, p, force_stop, maxiters, internalnorm, retcode, tol, cache) where iip
4749
# NewtonImmutableSolver{iip, typeof(f), typeof(alg), typeof(u),
4850
# typeof(fu), typeof(p), typeof(internalnorm), typeof(tol), typeof(cache)}(iter, f, alg, u, fu, p, force_stop, maxiters, internalnorm, retcode, tol, cache)
4951
# end
5052

51-
struct BracketingSolution{uType}
52-
left::uType
53-
right::uType
54-
retcode::Retcode
55-
end
56-
57-
struct NewtonSolution{uType}
58-
u::uType
59-
retcode::Retcode
60-
end
6153

6254
function sync_residuals!(solver::BracketingImmutableSolver)
6355
@set! solver.fl = solver.f(solver.left, solver.p)
6456
@set! solver.fr = solver.f(solver.right, solver.p)
6557
solver
66-
end
67-
68-
getsolution(sol::NewtonSolution) = sol.u
69-
getsolution(sol::BracketingSolution) = sol.left
58+
end

src/utils.jl

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -101,21 +101,6 @@ function num_types_in_tuple(sig::UnionAll)
101101
length(Base.unwrap_unionall(sig).parameters)
102102
end
103103

104-
function numargs(f)
105-
typ = Tuple{Any, Val{:analytic}, Vararg}
106-
typ2 = Tuple{Any, Type{Val{:analytic}}, Vararg} # This one is required for overloaded types
107-
typ3 = Tuple{Any, Val{:jac}, Vararg}
108-
typ4 = Tuple{Any, Type{Val{:jac}}, Vararg} # This one is required for overloaded types
109-
typ5 = Tuple{Any, Val{:tgrad}, Vararg}
110-
typ6 = Tuple{Any, Type{Val{:tgrad}}, Vararg} # This one is required for overloaded types
111-
numparam = maximum([(m.sig<:typ || m.sig<:typ2 || m.sig<:typ3 || m.sig<:typ4 || m.sig<:typ5 || m.sig<:typ6) ? 0 : num_types_in_tuple(m.sig) for m in methods(f)])
112-
return (numparam-1) #-1 in v0.5 since it adds f as the first parameter
113-
end
114-
115-
function isinplace(f,inplace_param_number)
116-
numargs(f)>=inplace_param_number
117-
end
118-
119104
### Default Linsolve
120105

121106
# Try to be as smart as possible

test/runtests.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,13 @@ end
2323
f, u0 = (u,p) -> u .* u .- 2, @SVector[1.0, 1.0]
2424
sf, su0 = (u,p) -> u * u - 2, 1.0
2525
sol = benchmark_immutable(f, u0)
26-
@test sol.retcode === NonlinearSolve.DEFAULT
26+
@test sol.retcode === Symbol(NonlinearSolve.DEFAULT)
2727
@test all(sol.u .* sol.u .- 2 .< 1e-9)
2828
sol = benchmark_mutable(f, u0)
29-
@test sol.retcode === NonlinearSolve.DEFAULT
29+
@test sol.retcode === Symbol(NonlinearSolve.DEFAULT)
3030
@test all(sol.u .* sol.u .- 2 .< 1e-9)
3131
sol = benchmark_scalar(sf, su0)
32-
@test sol.retcode === NonlinearSolve.DEFAULT
32+
@test sol.retcode === Symbol(NonlinearSolve.DEFAULT)
3333
@test sol.u * sol.u - 2 < 1e-9
3434

3535
@test (@ballocated benchmark_immutable($f, $u0)) == 0
@@ -117,6 +117,7 @@ probN = NonlinearProblem(f, u0)
117117
@test solve(probN, NewtonRaphson(;autodiff=false); immutable = false).u[end] sqrt(2.0)
118118

119119
for u0 in [1.0, [1, 1.0]]
120+
local f, probN, sol
120121
f = (u, p) -> u .* u .- 2.0
121122
probN = NonlinearProblem(f, u0)
122123
sol = sqrt(2) * u0

0 commit comments

Comments
 (0)