Skip to content

Commit f35c378

Browse files
Merge pull request #3672 from AayushSabharwal/as/infinite-recursion
[backport-v9] fix: fix infinite recursion in `full_equations`
2 parents 250e7be + 0e829d9 commit f35c378

File tree

2 files changed

+36
-18
lines changed

2 files changed

+36
-18
lines changed

src/structural_transformation/symbolics_tearing.jl

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -748,12 +748,14 @@ function update_simplified_system!(
748748
unknowns = [unknowns; extra_unknowns]
749749
@set! sys.unknowns = unknowns
750750

751-
obs, subeqs, deps = cse_and_array_hacks(
752-
sys, obs, solved_eqs, unknowns, neweqs; cse = cse_hack, array = array_hack)
751+
obs = cse_and_array_hacks(
752+
sys, obs, unknowns, neweqs; cse = cse_hack, array = array_hack)
753753

754+
deps = Vector{Int}[i == 1 ? Int[] : collect(1:(i - 1))
755+
for i in 1:length(solved_eqs)]
754756
@set! sys.eqs = neweqs
755757
@set! sys.observed = obs
756-
@set! sys.substitutions = Substitutions(subeqs, deps)
758+
@set! sys.substitutions = Substitutions(solved_eqs, deps)
757759

758760
# Only makes sense for time-dependent
759761
# TODO: generalize to SDE
@@ -850,7 +852,7 @@ if all `p[i]` are present and the unscalarized form is used in any equation (obs
850852
not) we first count the number of times the scalarized form of each observed variable
851853
occurs in observed equations (and unknowns if it's split).
852854
"""
853-
function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, array = true)
855+
function cse_and_array_hacks(sys, obs, unknowns, neweqs; cse = true, array = true)
854856
# HACK 1
855857
# mapping of rhs to temporary CSE variable
856858
# `f(...) => tmpvar` in above example
@@ -878,7 +880,6 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr
878880
tempeq = tempvar ~ rhs_arr
879881
rhs_to_tempvar[rhs_arr] = tempvar
880882
push!(obs, tempeq)
881-
push!(subeqs, tempeq)
882883
end
883884

884885
# getindex_wrapper is used because `observed2graph` treats `x` and `x[i]` as different,
@@ -887,10 +888,6 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr
887888
neweq = lhs ~ getindex_wrapper(
888889
rhs_to_tempvar[rhs_arr], Tuple(arguments(rhs)[2:end]))
889890
obs[i] = neweq
890-
subeqi = findfirst(isequal(eq), subeqs)
891-
if subeqi !== nothing
892-
subeqs[subeqi] = neweq
893-
end
894891
end
895892
# end HACK 1
896893

@@ -920,7 +917,6 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr
920917
tempeq = tempvar ~ rhs_arr
921918
rhs_to_tempvar[rhs_arr] = tempvar
922919
push!(obs, tempeq)
923-
push!(subeqs, tempeq)
924920
end
925921
# don't need getindex_wrapper, but do it anyway to know that this
926922
# hack took place
@@ -960,15 +956,8 @@ function cse_and_array_hacks(sys, obs, subeqs, unknowns, neweqs; cse = true, arr
960956
push!(obs_arr_eqs, arrvar ~ rhs)
961957
end
962958
append!(obs, obs_arr_eqs)
963-
append!(subeqs, obs_arr_eqs)
964-
965-
# need to re-sort subeqs
966-
subeqs = ModelingToolkit.topsort_equations(subeqs, [eq.lhs for eq in subeqs])
967-
968-
deps = Vector{Int}[i == 1 ? Int[] : collect(1:(i - 1))
969-
for i in 1:length(subeqs)]
970959

971-
return obs, subeqs, deps
960+
return obs
972961
end
973962

974963
function is_getindexed_array(rhs)

test/odesystem.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1755,3 +1755,32 @@ end
17551755
sol = solve(prob, Tsit5())
17561756
@test SciMLBase.successful_retcode(sol)
17571757
end
1758+
1759+
@testset "`full_equations` doesn't recurse infinitely" begin
1760+
code = """
1761+
using ModelingToolkit
1762+
using ModelingToolkit: t_nounits as t, D_nounits as D
1763+
@variables x(t)[1:3]=[0,0,1]
1764+
@variables u1(t)=0 u2(t)=0
1765+
y₁, y₂, y₃ = x
1766+
k₁, k₂, k₃ = 1,1,1
1767+
eqs = [
1768+
D(y₁) ~ -k₁*y₁ + k₃*y₂*y₃ + u1
1769+
D(y₂) ~ k₁*y₁ - k₃*y₂*y₃ - k₂*y₂^2 + u2
1770+
y₁ + y₂ + y₃ ~ 1
1771+
]
1772+
1773+
@named sys = ODESystem(eqs, t)
1774+
1775+
inputs = [u1, u2]
1776+
outputs = [y₁, y₂, y₃]
1777+
ss, = structural_simplify(sys, (inputs, []))
1778+
full_equations(ss)
1779+
"""
1780+
1781+
cmd = `$(Base.julia_cmd()) --project=$(@__DIR__) -e $code`
1782+
proc = run(cmd, stdin, stdout, stderr; wait = false)
1783+
sleep(120)
1784+
@test !process_running(proc)
1785+
kill(proc, Base.SIGKILL)
1786+
end

0 commit comments

Comments
 (0)