Skip to content

Commit 7346594

Browse files
committed
Fix the testing
1 parent b9ae9f8 commit 7346594

File tree

3 files changed

+29
-7
lines changed

3 files changed

+29
-7
lines changed

src/termination_conditions.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ struct SimpleNonlinearSolveTerminationMode <: AbstractNonlinearTerminationMode
4242
end
4343
end
4444

45+
@inline set_termination_mode_internalnorm(mode, ::F) where {F} = mode
46+
4547
@inline __norm_type(::typeof(Base.Fix2(norm, Inf))) = :Inf
4648
@inline __norm_type(::typeof(Base.Fix1(maximum, abs))) = :Inf
4749
@inline __norm_type(::typeof(Base.Fix2(norm, 2))) = :L2
@@ -98,6 +100,11 @@ for name in (:Norm, :RelNorm, :AbsNorm)
98100

99101
$(struct_name)(f::F = nothing) where {F} = new{__norm_type(f), F}(f)
100102
end
103+
104+
@inline function set_termination_mode_internalnorm(
105+
::$(struct_name), internalnorm::F) where {F}
106+
return $(struct_name)(internalnorm)
107+
end
101108
end
102109
end
103110

@@ -143,6 +150,13 @@ for norm_type in (:Rel, :Abs), safety in (:Safe, :SafeBest)
143150
patience_objective_multiplier, min_max_factor, max_stalled_steps)
144151
end
145152
end
153+
154+
@inline function set_termination_mode_internalnorm(
155+
mode::$(struct_name), internalnorm::F) where {F}
156+
return $(struct_name)(internalnorm; mode.protective_threshold,
157+
mode.patience_steps, mode.patience_objective_multiplier,
158+
mode.min_max_factor, mode.max_stalled_steps)
159+
end
146160
end
147161
end
148162

src/utils.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,10 @@ end
8383

8484
@inline function __norm_op(::typeof(Base.Fix2(norm, 2)), op::F, x, y) where {F}
8585
if __fast_scalar_indexing(x, y)
86-
return sqrt(sum(@closure((xᵢyᵢ)->(op(xᵢ, yᵢ)^2)), zip(x, y)))
86+
return sqrt(sum(@closure((xᵢyᵢ)->begin
87+
xᵢ, yᵢ = xᵢyᵢ
88+
return op(xᵢ, yᵢ)^2
89+
end), zip(x, y)))
8790
else
8891
return sqrt(mapreduce(@closure((xᵢ, yᵢ)->(op(xᵢ, yᵢ)^2)), +, x, y))
8992
end
@@ -104,7 +107,8 @@ end
104107

105108
@inline function __add_and_norm(::Nothing, x, y)
106109
Base.depwarn("Not specifying the internal norm of termination conditions has been \
107-
deprecated. Using inf-norm currently.", :__add_and_norm)
110+
deprecated. Using inf-norm currently.",
111+
:__add_and_norm)
108112
return __maximum_abs(+, x, y)
109113
end
110114
@inline __add_and_norm(::typeof(Base.Fix1(maximum, abs)), x, y) = __maximum_abs(+, x, y)
@@ -113,7 +117,8 @@ end
113117

114118
@inline function __apply_termination_internalnorm(::Nothing, u)
115119
Base.depwarn("Not specifying the internal norm of termination conditions has been \
116-
deprecated. Using inf-norm currently.", :__apply_termination_internalnorm)
120+
deprecated. Using inf-norm currently.",
121+
:__apply_termination_internalnorm)
117122
return __apply_termination_internalnorm(Base.Fix1(maximum, abs), u)
118123
end
119124
@inline __apply_termination_internalnorm(f::F, u) where {F} = f(u)

test/termination_conditions.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,22 @@
1-
using BenchmarkTools, DiffEqBase, Test
1+
using BenchmarkTools, DiffEqBase, LinearAlgebra, Test
22

33
du = rand(4)
44
u = rand(4)
55
uprev = rand(4)
66

77
const TERMINATION_CONDITIONS = [
88
SteadyStateDiffEqTerminationMode(), SimpleNonlinearSolveTerminationMode(),
9-
NormTerminationMode(), RelTerminationMode(), RelNormTerminationMode(),
9+
RelTerminationMode(), NormTerminationMode(), RelNormTerminationMode(),
1010
AbsTerminationMode(), AbsNormTerminationMode(), RelSafeTerminationMode(),
1111
AbsSafeTerminationMode(), RelSafeBestTerminationMode(), AbsSafeBestTerminationMode()
1212
]
1313

1414
@testset "Termination Conditions: Allocations" begin
1515
@testset "Mode: $(tcond)" for tcond in TERMINATION_CONDITIONS
16-
@test (@ballocated DiffEqBase.check_convergence($tcond, $du, $u, $uprev, 1e-3,
17-
1e-3)) == 0
16+
for nfn in (Base.Fix1(maximum, abs), Base.Fix2(norm, 2), Base.Fix2(norm, Inf))
17+
tcond = DiffEqBase.set_termination_mode_internalnorm(tcond, nfn)
18+
@test (@ballocated DiffEqBase.check_convergence($tcond, $du, $u, $uprev, 1e-3,
19+
1e-3)) == 0
20+
end
1821
end
1922
end

0 commit comments

Comments
 (0)