@@ -55,9 +55,13 @@ const TERM_DOCS = Dict(
5555 :AbsNorm => doc " ``\| \f rac{\p artial u}{\p artial t} \| \l eq abstol``"
5656)
5757
58- const __TERM_INTERNALNORM_DOCS = " where `internalnorm` is the norm to use for the \
59- termination condition. Special handling is done for \
60- `norm(_, 2)`, `norm(_, Inf)`, and `maximum(abs, _)`."
58+ const __TERM_INTERNALNORM_DOCS = """
59+ where `internalnorm` is the norm to use for the termination condition. Special handling is
60+ done for `norm(_, 2)`, `norm(_, Inf)`, and `maximum(abs, _)`.
61+
62+ Default is left as `nothing`, which allows upstream frameworks to choose the correct norm
63+ based on the problem type. If directly using the `init` API, a proper norm must be
64+ provided"""
6165
6266for name in (:Rel , :Abs )
6367 struct_name = Symbol (name, :TerminationMode )
@@ -85,16 +89,14 @@ for name in (:Norm, :RelNorm, :AbsNorm)
8589
8690 ## Constructor
8791
88- $($ struct_name) (internalnorm = Base.Fix1(maximum, abs) )
92+ $($ struct_name) (internalnorm = nothing )
8993
9094 $($ __TERM_INTERNALNORM_DOCS) .
9195 """
9296 @concrete struct $ (struct_name){F} <: AbstractNonlinearTerminationMode
9397 internalnorm
9498
95- function $ (struct_name)(f:: F = Base. Fix1 (maximum, abs)) where {F}
96- return new {__norm_type(f), F} (f)
97- end
99+ $ (struct_name)(f:: F = nothing ) where {F} = new {__norm_type(f), F} (f)
98100 end
99101 end
100102end
@@ -118,10 +120,9 @@ for norm_type in (:Rel, :Abs), safety in (:Safe, :SafeBest)
118120
119121 ## Constructor
120122
121- $($ struct_name) (internalnorm = Base.Fix1(maximum, abs);
122- protective_threshold = nothing, patience_steps = 100,
123- patience_objective_multiplier = 3, min_max_factor = 1.3,
124- max_stalled_steps = nothing)
123+ $($ struct_name) (internalnorm = nothing; protective_threshold = nothing,
124+ patience_steps = 100, patience_objective_multiplier = 3,
125+ min_max_factor = 1.3, max_stalled_steps = nothing)
125126
126127 $($ __TERM_INTERNALNORM_DOCS) .
127128 """
@@ -133,10 +134,9 @@ for norm_type in (:Rel, :Abs), safety in (:Safe, :SafeBest)
133134 min_max_factor
134135 max_stalled_steps:: T
135136
136- function $ (struct_name)(f:: F = Base. Fix1 (maximum, abs);
137- protective_threshold = nothing , patience_steps = 100 ,
138- patience_objective_multiplier = 3 , min_max_factor = 1.3 ,
139- max_stalled_steps = nothing ) where {F}
137+ function $ (struct_name)(f:: F = nothing ; protective_threshold = nothing ,
138+ patience_steps = 100 , patience_objective_multiplier = 3 ,
139+ min_max_factor = 1.3 , max_stalled_steps = nothing ) where {F}
140140 return new{__norm_type (f), typeof (max_stalled_steps), F,
141141 typeof (protective_threshold), typeof (patience_objective_multiplier),
142142 typeof (min_max_factor)}(f, protective_threshold, patience_steps,
@@ -199,10 +199,10 @@ function SciMLBase.init(du::Union{AbstractArray{T}, T}, u::Union{AbstractArray{T
199199 (ArrayInterface. can_setindex (u) ? copy (u) : u) : nothing
200200 if mode isa AbstractSafeNonlinearTerminationMode
201201 if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode
202- initial_objective = mode. internalnorm ( du)
202+ initial_objective = __apply_termination_internalnorm ( mode. internalnorm, du)
203203 u0_norm = nothing
204204 else
205- initial_objective = mode. internalnorm ( du) /
205+ initial_objective = __apply_termination_internalnorm ( mode. internalnorm, du) /
206206 (__add_and_norm (mode. internalnorm, du, u) + eps (TT))
207207 u0_norm = mode. max_stalled_steps === nothing ? nothing : norm (u, 2 )
208208 end
@@ -255,9 +255,11 @@ function SciMLBase.reinit!(cache::NonlinearTerminationModeCache{dep_retcode}, du
255255 mode = get_termination_mode (cache)
256256 if mode isa AbstractSafeNonlinearTerminationMode
257257 if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode
258- initial_objective = cache. mode. internalnorm (du)
258+ initial_objective = __apply_termination_internalnorm (
259+ cache. mode. internalnorm, du)
259260 else
260- initial_objective = cache. mode. internalnorm (du) /
261+ initial_objective = __apply_termination_internalnorm (
262+ cache. mode. internalnorm, du) /
261263 (__add_and_norm (cache. mode. internalnorm, du, u) + eps (TT))
262264 cache. max_stalled_steps != = nothing && (cache. u0_norm = norm (u_, 2 ))
263265 end
@@ -292,10 +294,10 @@ function (cache::NonlinearTerminationModeCache{dep_retcode})(
292294 mode:: AbstractSafeNonlinearTerminationMode ,
293295 du, u, uprev, args... ) where {dep_retcode}
294296 if mode isa AbsSafeTerminationMode || mode isa AbsSafeBestTerminationMode
295- objective = mode. internalnorm ( du)
297+ objective = __apply_termination_internalnorm ( mode. internalnorm, du)
296298 criteria = cache. abstol
297299 else
298- objective = mode. internalnorm ( du) /
300+ objective = __apply_termination_internalnorm ( mode. internalnorm, du) /
299301 (__add_and_norm (mode. internalnorm, du, u) + eps (cache. abstol))
300302 criteria = cache. reltol
301303 end
@@ -427,19 +429,20 @@ function check_convergence(::AbsTerminationMode, duₙ, uₙ, uₙ₋₁, abstol
427429end
428430
429431function check_convergence (mode:: NormTerminationMode , duₙ, uₙ, uₙ₋₁, abstol, reltol)
430- du_norm = mode. internalnorm ( duₙ)
432+ du_norm = __apply_termination_internalnorm ( mode. internalnorm, duₙ)
431433 return (du_norm ≤ abstol) ||
432434 (du_norm ≤ reltol * __add_and_norm (mode. internalnorm, duₙ, uₙ))
433435end
434436function check_convergence (
435437 mode:: Union {
436438 RelNormTerminationMode, RelSafeTerminationMode, RelSafeBestTerminationMode},
437439 duₙ, uₙ, uₙ₋₁, abstol, reltol)
438- return mode. internalnorm (duₙ) ≤ reltol * __add_and_norm (mode. internalnorm, duₙ, uₙ)
440+ return __apply_termination_internalnorm (mode. internalnorm, duₙ) ≤
441+ reltol * __add_and_norm (mode. internalnorm, duₙ, uₙ)
439442end
440443function check_convergence (
441444 mode:: Union {AbsNormTerminationMode, AbsSafeTerminationMode,
442445 AbsSafeBestTerminationMode},
443446 duₙ, uₙ, uₙ₋₁, abstol, reltol)
444- return mode. internalnorm ( duₙ) ≤ abstol
447+ return __apply_termination_internalnorm ( mode. internalnorm, duₙ) ≤ abstol
445448end
0 commit comments