Skip to content

Commit b9ae9f8

Browse files
committed
Default the default norm to nothing for upstream selection
1 parent 0e026f3 commit b9ae9f8

File tree

2 files changed

+39
-24
lines changed

2 files changed

+39
-24
lines changed

src/termination_conditions.jl

Lines changed: 27 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,13 @@ const TERM_DOCS = Dict(
5555
:AbsNorm => doc"``\| \frac{\partial u}{\partial t} \| \leq abstol``"
5656
)
5757

58-
const __TERM_INTERNALNORM_DOCS = "where `internalnorm` is the norm to use for the \
59-
termination condition. Special handling is done for \
60-
`norm(_, 2)`, `norm(_, Inf)`, and `maximum(abs, _)`."
58+
const __TERM_INTERNALNORM_DOCS = """
59+
where `internalnorm` is the norm to use for the termination condition. Special handling is
60+
done for `norm(_, 2)`, `norm(_, Inf)`, and `maximum(abs, _)`.
61+
62+
Default is left as `nothing`, which allows upstream frameworks to choose the correct norm
63+
based on the problem type. If directly using the `init` API, a proper norm must be
64+
provided"""
6165

6266
for name in (:Rel, :Abs)
6367
struct_name = Symbol(name, :TerminationMode)
@@ -85,16 +89,14 @@ for name in (:Norm, :RelNorm, :AbsNorm)
8589
8690
## Constructor
8791
88-
$($struct_name)(internalnorm = Base.Fix1(maximum, abs))
92+
$($struct_name)(internalnorm = nothing)
8993
9094
$($__TERM_INTERNALNORM_DOCS).
9195
"""
9296
@concrete struct $(struct_name){F} <: AbstractNonlinearTerminationMode
9397
internalnorm
9498

95-
function $(struct_name)(f::F = Base.Fix1(maximum, abs)) where {F}
96-
return new{__norm_type(f), F}(f)
97-
end
99+
$(struct_name)(f::F = nothing) where {F} = new{__norm_type(f), F}(f)
98100
end
99101
end
100102
end
@@ -118,10 +120,9 @@ for norm_type in (:Rel, :Abs), safety in (:Safe, :SafeBest)
118120
119121
## Constructor
120122
121-
$($struct_name)(internalnorm = Base.Fix1(maximum, abs);
122-
protective_threshold = nothing, patience_steps = 100,
123-
patience_objective_multiplier = 3, min_max_factor = 1.3,
124-
max_stalled_steps = nothing)
123+
$($struct_name)(internalnorm = nothing; protective_threshold = nothing,
124+
patience_steps = 100, patience_objective_multiplier = 3,
125+
min_max_factor = 1.3, max_stalled_steps = nothing)
125126
126127
$($__TERM_INTERNALNORM_DOCS).
127128
"""
@@ -133,10 +134,9 @@ for norm_type in (:Rel, :Abs), safety in (:Safe, :SafeBest)
133134
min_max_factor
134135
max_stalled_steps::T
135136

136-
function $(struct_name)(f::F = Base.Fix1(maximum, abs);
137-
protective_threshold = nothing, patience_steps = 100,
138-
patience_objective_multiplier = 3, min_max_factor = 1.3,
139-
max_stalled_steps = nothing) where {F}
137+
function $(struct_name)(f::F = nothing; protective_threshold = nothing,
138+
patience_steps = 100, patience_objective_multiplier = 3,
139+
min_max_factor = 1.3, max_stalled_steps = nothing) where {F}
140140
return new{__norm_type(f), typeof(max_stalled_steps), F,
141141
typeof(protective_threshold), typeof(patience_objective_multiplier),
142142
typeof(min_max_factor)}(f, protective_threshold, patience_steps,
@@ -199,10 +199,10 @@ function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T
199199
(ArrayInterface.can_setindex(u) ? copy(u) : u) : nothing
200200
if mode isa AbstractSafeNonlinearTerminationMode
201201
if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode
202-
initial_objective = mode.internalnorm(du)
202+
initial_objective = __apply_termination_internalnorm(mode.internalnorm, du)
203203
u0_norm = nothing
204204
else
205-
initial_objective = mode.internalnorm(du) /
205+
initial_objective = __apply_termination_internalnorm(mode.internalnorm, du) /
206206
(__add_and_norm(mode.internalnorm, du, u) + eps(TT))
207207
u0_norm = mode.max_stalled_steps === nothing ? nothing : norm(u, 2)
208208
end
@@ -255,9 +255,11 @@ function SciMLBase.reinit!(cache::NonlinearTerminationModeCache{dep_retcode}, du
255255
mode = get_termination_mode(cache)
256256
if mode isa AbstractSafeNonlinearTerminationMode
257257
if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode
258-
initial_objective = cache.mode.internalnorm(du)
258+
initial_objective = __apply_termination_internalnorm(
259+
cache.mode.internalnorm, du)
259260
else
260-
initial_objective = cache.mode.internalnorm(du) /
261+
initial_objective = __apply_termination_internalnorm(
262+
cache.mode.internalnorm, du) /
261263
(__add_and_norm(cache.mode.internalnorm, du, u) + eps(TT))
262264
cache.max_stalled_steps !== nothing && (cache.u0_norm = norm(u_, 2))
263265
end
@@ -292,10 +294,10 @@ function (cache::NonlinearTerminationModeCache{dep_retcode})(
292294
mode::AbstractSafeNonlinearTerminationMode,
293295
du, u, uprev, args...) where {dep_retcode}
294296
if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode
295-
objective = mode.internalnorm(du)
297+
objective = __apply_termination_internalnorm(mode.internalnorm, du)
296298
criteria = cache.abstol
297299
else
298-
objective = mode.internalnorm(du) /
300+
objective = __apply_termination_internalnorm(mode.internalnorm, du) /
299301
(__add_and_norm(mode.internalnorm, du, u) + eps(cache.abstol))
300302
criteria = cache.reltol
301303
end
@@ -427,19 +429,20 @@ function check_convergence(::AbsTerminationMode, duₙ, uₙ, uₙ₋₁, abstol
427429
end
428430

429431
function check_convergence(mode::NormTerminationMode, duₙ, uₙ, uₙ₋₁, abstol, reltol)
430-
du_norm = mode.internalnorm(duₙ)
432+
du_norm = __apply_termination_internalnorm(mode.internalnorm, duₙ)
431433
return (du_norm abstol) ||
432434
(du_norm reltol * __add_and_norm(mode.internalnorm, duₙ, uₙ))
433435
end
434436
function check_convergence(
435437
mode::Union{
436438
RelNormTerminationMode, RelSafeTerminationMode, RelSafeBestTerminationMode},
437439
duₙ, uₙ, uₙ₋₁, abstol, reltol)
438-
return mode.internalnorm(duₙ) reltol * __add_and_norm(mode.internalnorm, duₙ, uₙ)
440+
return __apply_termination_internalnorm(mode.internalnorm, duₙ)
441+
reltol * __add_and_norm(mode.internalnorm, duₙ, uₙ)
439442
end
440443
function check_convergence(
441444
mode::Union{AbsNormTerminationMode, AbsSafeTerminationMode,
442445
AbsSafeBestTerminationMode},
443446
duₙ, uₙ, uₙ₋₁, abstol, reltol)
444-
return mode.internalnorm(duₙ) abstol
447+
return __apply_termination_internalnorm(mode.internalnorm, duₙ) abstol
445448
end

src/utils.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,18 @@ function __nonlinearsolve_is_approx(x, y; atol = false,
102102
return d max(atol, rtol * max(maximum(abs, x), maximum(abs, y)))
103103
end
104104

105+
@inline function __add_and_norm(::Nothing, x, y)
106+
Base.depwarn("Not specifying the internal norm of termination conditions has been \
107+
deprecated. Using inf-norm currently.", :__add_and_norm)
108+
return __maximum_abs(+, x, y)
109+
end
105110
@inline __add_and_norm(::typeof(Base.Fix1(maximum, abs)), x, y) = __maximum_abs(+, x, y)
106111
@inline __add_and_norm(::typeof(Base.Fix2(norm, Inf)), x, y) = __maximum_abs(+, x, y)
107112
@inline __add_and_norm(f::F, x, y) where {F} = __norm_op(f, +, x, y)
113+
114+
@inline function __apply_termination_internalnorm(::Nothing, u)
115+
Base.depwarn("Not specifying the internal norm of termination conditions has been \
116+
deprecated. Using inf-norm currently.", :__apply_termination_internalnorm)
117+
return __apply_termination_internalnorm(Base.Fix1(maximum, abs), u)
118+
end
119+
@inline __apply_termination_internalnorm(f::F, u) where {F} = f(u)

0 commit comments

Comments
 (0)