Skip to content

Commit d568df3

Browse files
fix: fix unscalarized array passed to discrete_parameters
1 parent b8e2ba1 commit d568df3

File tree

2 files changed

+46
-1
lines changed

2 files changed

+46
-1
lines changed

src/systems/callbacks.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -904,7 +904,12 @@ function compile_equational_affect(
904904
obseqs, Dict([p => unPre(p) for p in parameters(affsys)]))
905905
rhss = map(x -> x.rhs, update_eqs)
906906
lhss = map(x -> x.lhs, update_eqs)
907-
is_p = [lhs in Set(ps_to_update) for lhs in lhss]
907+
update_ps_set = Set(ps_to_update)
908+
is_p = map(lhss) do lhs
909+
lhs in update_ps_set ||
910+
iscall(lhs) && operation(lhs) === getindex &&
911+
arguments(lhs)[1] in update_ps_set
912+
end
908913
is_u = [lhs in Set(dvs_to_update) for lhs in lhss]
909914
dvs = unknowns(sys)
910915
ps = parameters(sys)

test/symbolic_events.jl

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1495,3 +1495,43 @@ end
14951495
@test v_sol[v] [vini, -vini, vini]
14961496
@test M_sol[M] [Mini, -Mini, Mini]
14971497
end
1498+
1499+
@testset "Issue#3990: Scalarized array passed to `discrete_parameters` of symbolic affect" begin
1500+
N = 2
1501+
@parameters v(t)[1:N]
1502+
@parameters M(t)[1:N, 1:N]
1503+
1504+
@variables x(t)
1505+
1506+
Mini = rand(N, N) ./ (N^2)
1507+
vini = vec(sum(Mini, dims = 1))
1508+
1509+
v_eq = [D(x) ~ x * Symbolics.scalarize(sum(v))]
1510+
M_eq = [D(x) ~ x * Symbolics.scalarize(sum(M))]
1511+
1512+
v_event = ModelingToolkit.SymbolicDiscreteCallback(
1513+
1.0,
1514+
[v ~ -Pre(v)],
1515+
discrete_parameters = collect(v)
1516+
)
1517+
1518+
M_event = ModelingToolkit.SymbolicDiscreteCallback(
1519+
1.0,
1520+
[M ~ -Pre(M)],
1521+
discrete_parameters = vec(collect(M))
1522+
)
1523+
1524+
@mtkcompile v_sys = System(v_eq, t; discrete_events = v_event)
1525+
@mtkcompile M_sys = System(M_eq, t; discrete_events = M_event)
1526+
1527+
u0p0_map = Dict(x => 1.0, M => Mini, v => vini)
1528+
1529+
v_prob = ODEProblem(v_sys, u0p0_map, (0.0, 2.5))
1530+
M_prob = ODEProblem(M_sys, u0p0_map, (0.0, 2.5))
1531+
1532+
v_sol = solve(v_prob, Tsit5())
1533+
M_sol = solve(M_prob, Tsit5())
1534+
1535+
@test v_sol[v] [vini, -vini, vini]
1536+
@test M_sol[M] [Mini, -Mini, Mini]
1537+
end

0 commit comments

Comments
 (0)