Skip to content

Commit 54c7da2

Browse files
Merge pull request #92 from utkarsh530/u/nlgpufix
Change NonlinearSolve enums to SciMLBase enums
2 parents 58f23c2 + 2e26ff2 commit 54c7da2

File tree

6 files changed

+26
-34
lines changed

6 files changed

+26
-34
lines changed

src/bisection.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ function perform_step(solver::BracketingImmutableSolver, alg::Bisection, cache)
3535

3636
if left == mid || right == mid
3737
@set! solver.force_stop = true
38-
@set! solver.retcode = FLOATING_POINT_LIMIT
38+
@set! solver.retcode = ReturnCode.FloatingPointLimit
3939
return solver
4040
end
4141

src/falsi.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ function perform_step(solver, alg::Falsi, cache)
2121

2222
if right == mid || right == mid
2323
@set! solver.force_stop = true
24-
@set! solver.retcode = FLOATING_POINT_LIMIT
24+
@set! solver.retcode = ReturnCode.FloatingPointLimit
2525
return solver
2626
end
2727

@@ -32,7 +32,7 @@ function perform_step(solver, alg::Falsi, cache)
3232
@set! solver.force_stop = true
3333
@set! solver.left = mid
3434
@set! solver.fl = fm
35-
@set! solver.retcode = EXACT_SOLUTION_LEFT
35+
@set! solver.retcode = ReturnCode.ExactSolutionLeft
3636
else
3737
if sign(fm) == sign(fl)
3838
@set! solver.left = mid

src/scalar.jl

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -26,15 +26,15 @@ function SciMLBase.solve(prob::NonlinearProblem{<:Union{Number, SVector}},
2626
fx)
2727
end
2828
iszero(fx) &&
29-
return SciMLBase.build_solution(prob, alg, x, fx; retcode = Symbol(DEFAULT))
29+
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Default)
3030
Δx = dfx \ fx
3131
x -= Δx
3232
if isapprox(x, xo, atol = atol, rtol = rtol)
33-
return SciMLBase.build_solution(prob, alg, x, fx; retcode = Symbol(DEFAULT))
33+
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.Default)
3434
end
3535
xo = x
3636
end
37-
return SciMLBase.build_solution(prob, alg, x, fx; retcode = Symbol(MAXITERS_EXCEED))
37+
return SciMLBase.build_solution(prob, alg, x, fx; retcode = ReturnCode.MaxIters)
3838
end
3939

4040
function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
@@ -109,7 +109,7 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::Bisection, args...; maxite
109109

110110
if iszero(fl)
111111
return SciMLBase.build_solution(prob, alg, left, fl;
112-
retcode = Symbol(EXACT_SOLUTION_LEFT), left = left,
112+
retcode = ReturnCode.ExactSolutionLeft, left = left,
113113
right = right)
114114
end
115115

@@ -119,7 +119,7 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::Bisection, args...; maxite
119119
mid = (left + right) / 2
120120
(mid == left || mid == right) &&
121121
return SciMLBase.build_solution(prob, alg, left, fl;
122-
retcode = Symbol(FLOATING_POINT_LIMIT),
122+
retcode = ReturnCode.FloatingPointLimit,
123123
left = left, right = right)
124124
fm = f(mid)
125125
if iszero(fm)
@@ -141,7 +141,7 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::Bisection, args...; maxite
141141
mid = (left + right) / 2
142142
(mid == left || mid == right) &&
143143
return SciMLBase.build_solution(prob, alg, left, fl;
144-
retcode = Symbol(FLOATING_POINT_LIMIT),
144+
retcode = ReturnCode.FloatingPointLimit,
145145
left = left, right = right)
146146
fm = f(mid)
147147
if iszero(fm)
@@ -154,7 +154,7 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::Bisection, args...; maxite
154154
i += 1
155155
end
156156

157-
return SciMLBase.build_solution(prob, alg, left, fl; retcode = Symbol(MAXITERS_EXCEED),
157+
return SciMLBase.build_solution(prob, alg, left, fl; retcode = ReturnCode.MaxIters,
158158
left = left, right = right)
159159
end
160160

@@ -166,7 +166,7 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::Falsi, args...; maxiters =
166166

167167
if iszero(fl)
168168
return SciMLBase.build_solution(prob, alg, left, fl;
169-
retcode = Symbol(EXACT_SOLUTION_LEFT), left = left,
169+
retcode = ReturnCode.ExactSolutionLeft, left = left,
170170
right = right)
171171
end
172172

@@ -175,7 +175,7 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::Falsi, args...; maxiters =
175175
while i < maxiters
176176
if nextfloat_tdir(left, prob.u0...) == right
177177
return SciMLBase.build_solution(prob, alg, left, fl;
178-
retcode = Symbol(FLOATING_POINT_LIMIT),
178+
retcode = ReturnCode.FloatingPointLimit,
179179
left = left, right = right)
180180
end
181181
mid = (fr * left - fl * right) / (fr - fl)
@@ -205,7 +205,7 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::Falsi, args...; maxiters =
205205
mid = (left + right) / 2
206206
(mid == left || mid == right) &&
207207
return SciMLBase.build_solution(prob, alg, left, fl;
208-
retcode = Symbol(FLOATING_POINT_LIMIT),
208+
retcode = ReturnCode.FloatingPointLimit,
209209
left = left, right = right)
210210
fm = f(mid)
211211
if iszero(fm)
@@ -221,6 +221,6 @@ function SciMLBase.solve(prob::NonlinearProblem, alg::Falsi, args...; maxiters =
221221
i += 1
222222
end
223223

224-
return SciMLBase.build_solution(prob, alg, left, fl; retcode = Symbol(MAXITERS_EXCEED),
224+
return SciMLBase.build_solution(prob, alg, left, fl; retcode = ReturnCode.MaxIters,
225225
left = left, right = right)
226226
end

src/solve.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip},
2929
fr = f(right, p)
3030
cache = alg_cache(alg, left, right, p, Val(iip))
3131
return BracketingImmutableSolver(1, f, alg, left, right, fl, fr, p, false, maxiters,
32-
DEFAULT, cache, iip, prob)
32+
ReturnCode.Default, cache, iip, prob)
3333
end
3434

3535
function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewtonAlgorithm,
@@ -54,7 +54,7 @@ function SciMLBase.__init(prob::NonlinearProblem{uType, iip}, alg::AbstractNewto
5454
end
5555
cache = alg_cache(alg, f, u, p, Val(iip))
5656
return NewtonImmutableSolver(1, f, alg, u, fu, p, false, maxiters, internalnorm,
57-
DEFAULT, tol, cache, iip, prob)
57+
Retcode.Default, tol, cache, iip, prob)
5858
end
5959

6060
function SciMLBase.solve!(solver::AbstractImmutableNonlinearSolver)
@@ -64,14 +64,14 @@ function SciMLBase.solve!(solver::AbstractImmutableNonlinearSolver)
6464
@set! solver.iter += 1
6565
end
6666
if solver.iter == solver.maxiters
67-
@set! solver.retcode = MAXITERS_EXCEED
67+
@set! solver.retcode = ReturnCode.MaxIters
6868
end
6969
if typeof(solver) <: NewtonImmutableSolver
7070
SciMLBase.build_solution(solver.prob, solver.alg, solver.u, solver.fu;
71-
retcode = Symbol(solver.retcode))
71+
retcode = solver.retcode)
7272
else
7373
SciMLBase.build_solution(solver.prob, solver.alg, solver.left, solver.fl;
74-
retcode = Symbol(solver.retcode), left = solver.left,
74+
retcode = solver.retcode, left = solver.left,
7575
right = solver.right)
7676
end
7777
end
@@ -89,10 +89,10 @@ function mic_check(solver::BracketingImmutableSolver)
8989
(flr > fzero) && error("Non bracketing interval passed in bracketing method.")
9090
if fl == fzero
9191
@set! solver.force_stop = true
92-
@set! solver.retcode = EXACT_SOLUTION_LEFT
92+
@set! solver.retcode = Retcode.ExactSolutionLeft
9393
elseif fr == fzero
9494
@set! solver.force_stop = true
95-
@set! solver.retcode = EXACT_SOLUTION_RIGHT
95+
@set! solver.retcode = Retcode.ExactionSolutionRight
9696
end
9797
solver
9898
end

src/types.jl

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,3 @@
1-
@enum Retcode::Int begin
2-
DEFAULT
3-
EXACT_SOLUTION_LEFT
4-
EXACT_SOLUTION_RIGHT
5-
MAXITERS_EXCEED
6-
FLOATING_POINT_LIMIT
7-
end
8-
91
struct BracketingImmutableSolver{fType, algType, uType, resType, pType, cacheType, probType
102
} <: AbstractImmutableNonlinearSolver
113
iter::Int
@@ -18,7 +10,7 @@ struct BracketingImmutableSolver{fType, algType, uType, resType, pType, cacheTyp
1810
p::pType
1911
force_stop::Bool
2012
maxiters::Int
21-
retcode::Retcode
13+
retcode::SciMLBase.ReturnCode.T
2214
cache::cacheType
2315
iip::Bool
2416
prob::probType
@@ -40,7 +32,7 @@ struct NewtonImmutableSolver{fType, algType, uType, resType, pType, INType, tolT
4032
force_stop::Bool
4133
maxiters::Int
4234
internalnorm::INType
43-
retcode::Retcode
35+
retcode::SciMLBase.ReturnCode.T
4436
tol::tolType
4537
cache::cacheType
4638
iip::Bool

test/basictests.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,13 +30,13 @@ end
3030
const csu0 = 1.0
3131

3232
sol = benchmark_immutable(ff, cu0)
33-
@test sol.retcode === Symbol(NonlinearSolve.DEFAULT)
33+
@test sol.retcode === ReturnCode.Default
3434
@test all(sol.u .* sol.u .- 2 .< 1e-9)
3535
sol = benchmark_mutable(ff, cu0)
36-
@test sol.retcode === Symbol(NonlinearSolve.DEFAULT)
36+
@test sol.retcode === ReturnCode.Default
3737
@test all(sol.u .* sol.u .- 2 .< 1e-9)
3838
sol = benchmark_scalar(sf, csu0)
39-
@test sol.retcode === Symbol(NonlinearSolve.DEFAULT)
39+
@test sol.retcode === ReturnCode.Default
4040
@test sol.u * sol.u - 2 < 1e-9
4141

4242
@test (@ballocated benchmark_immutable(ff, cu0)) == 0

0 commit comments

Comments
 (0)