Skip to content

Commit 838ad80

Browse files
test: add tests for if-lifting
1 parent 7deb72b commit 838ad80

File tree

2 files changed

+111
-0
lines changed

2 files changed

+111
-0
lines changed

test/if_lifting.jl

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
using ModelingToolkit, OrdinaryDiffEq
2+
using ModelingToolkit: t_nounits as t, D_nounits as D, IfLifting, no_if_lift
3+
4+
@testset "Simple `abs(x)`" begin
5+
@mtkmodel SimpleAbs begin
6+
@variables begin
7+
x(t)
8+
y(t)
9+
end
10+
@equations begin
11+
D(x) ~ abs(y)
12+
y ~ sin(t)
13+
end
14+
end
15+
@named sys = SimpleAbs()
16+
ss1 = structural_simplify(sys)
17+
@test length(equations(ss1)) == 1
18+
ss2 = structural_simplify(sys, additional_passes = [IfLifting])
19+
@test length(equations(ss2)) == 1
20+
@test length(parameters(ss2)) == 1
21+
@test operation(only(equations(ss2)).rhs) === ifelse
22+
23+
discvar = only(parameters(ss2))
24+
prob2 = ODEProblem(ss2, [x => 0.0], (0.0, 5.0))
25+
sol2 = solve(prob2, Tsit5())
26+
@test count(isapprox(pi), sol2.t) == 2
27+
@test any(isapprox(pi), sol2.discretes[1].t)
28+
@test !sol2[discvar][1]
29+
@test sol2[discvar][end]
30+
31+
_t = pi + 1.0
32+
# x(t) = 1 - cos(t) in [0, pi)
33+
# x(t) = 3 + cos(t) in [pi, 2pi)
34+
_trueval = 3 + cos(_t)
35+
@test !isapprox(sol1(_t)[1], _trueval; rtol = 1e-3)
36+
@test isapprox(sol2(_t)[1], _trueval; rtol = 1e-3)
37+
end
38+
39+
@testset "Big test case" begin
40+
@mtkmodel BigModel begin
41+
@variables begin
42+
x(t)
43+
y(t)
44+
z(t)
45+
c(t)::Bool
46+
w(t)
47+
q(t)
48+
r(t)
49+
end
50+
@parameters begin
51+
p
52+
end
53+
@equations begin
54+
# ifelse, max, min
55+
D(x) ~ ifelse(c, max(x, y), min(x, y))
56+
# discrete observed
57+
c ~ x <= y
58+
# observed should also get if-lifting
59+
y ~ abs(sin(t))
60+
# should be ignored
61+
D(z) ~ no_if_lift(ifelse(x < y, x, y))
62+
# ignore time-independent ifelse
63+
D(w) ~ ifelse(p < 3, 1.0, 2.0)
64+
# all the boolean operators
65+
D(q) ~ ifelse((x < 1) & ((y < 0.5) | ifelse(y > 0.8, c, !c)), 1.0, 2.0)
66+
# don't touch time-independent condition, but modify time-dependent branches
67+
D(r) ~ ifelse(p < 2, abs(x), max(y, 0.9))
68+
end
69+
end
70+
71+
@named sys = BigModel()
72+
ss = structural_simplify(sys, additional_passes = [IfLifting])
73+
74+
ps = parameters(ss)
75+
@test length(ps) == 9
76+
eqs = equations(ss)
77+
obs = observed(ss)
78+
79+
@testset "no_if_lift is untouched" begin
80+
idx = findfirst(eq -> isequal(eq.lhs, D(ss.z)), eqs)
81+
eq = eqs[idx]
82+
@test isequal(eq.rhs, no_if_lift(ifelse(ss.x < ss.y, ss.x, ss.y)))
83+
end
84+
@testset "time-independent ifelse is untouched" begin
85+
idx = findfirst(eq -> isequal(eq.lhs, D(ss.w)), eqs)
86+
eq = eqs[idx]
87+
@test operation(arguments(eq.rhs)[1]) === Base.:<
88+
end
89+
@testset "time-dependent branch of time-independent condition is modified" begin
90+
idx = findfirst(eq -> isequal(eq.lhs, D(ss.r)), eqs)
91+
eq = eqs[idx]
92+
@test operation(eq.rhs) === ifelse
93+
args = arguments(eq.rhs)
94+
@test operation(args[1]) == Base.:<
95+
@test operation(args[2]) === ifelse
96+
condvars = ModelingToolkit.vars(arguments(args[2])[1])
97+
@test length(condvars) == 1 && any(isequal(only(condvars)), ps)
98+
@test operation(args[3]) === ifelse
99+
condvars = ModelingToolkit.vars(arguments(args[3])[1])
100+
@test length(condvars) == 1 && any(isequal(only(condvars)), ps)
101+
end
102+
@testset "Observed variables are modified" begin
103+
idx = findfirst(eq -> isequal(eq.lhs, ss.c), obs)
104+
eq = obs[idx]
105+
@test operation(eq.rhs) === Base.:! && any(isequal(only(arguments(eq.rhs))), ps)
106+
idx = findfirst(eq -> isequal(eq.lhs, ss.y), obs)
107+
eq = obs[idx]
108+
@test operation(eq.rhs) === ifelse
109+
end
110+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,7 @@ end
8383
@safetestset "JumpSystem Test" include("jumpsystem.jl")
8484
@safetestset "print_tree" include("print_tree.jl")
8585
@safetestset "Constraints Test" include("constraints.jl")
86+
@safetestset "IfLifting Test" include("if_lifting.jl")
8687
end
8788
end
8889

0 commit comments

Comments
 (0)