Skip to content

Commit 84dbc1c

Browse files
fix: pass operating point to ImplicitDiscreteProblem in generate_equational_affect
1 parent 58e917f commit 84dbc1c

File tree

8 files changed

+55
-9
lines changed

8 files changed

+55
-9
lines changed

src/problems/daeproblem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ end
7272
eval_module, check_compatibility, implicit_dae = true, expression, kwargs...)
7373

7474
kwargs = process_kwargs(sys; expression, callback, eval_expression, eval_module,
75-
kwargs...)
75+
op, kwargs...)
7676

7777
diffvars = collect_differential_variables(sys)
7878
sts = unknowns(sys)

src/problems/ddeproblem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ end
6666
end
6767

6868
kwargs = process_kwargs(
69-
sys; expression, callback, eval_expression, eval_module, kwargs...)
69+
sys; expression, callback, eval_expression, eval_module, op, kwargs...)
7070
args = (; f, u0, h, tspan, p)
7171

7272
return maybe_codegen_scimlproblem(expression, DDEProblem{iip}, args; kwargs...)

src/problems/jumpproblem.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,8 @@
8080
end
8181

8282
# handle events, making sure to reset aggregators in the generated affect functions
83-
cbs = process_events(sys; callback, eval_expression, eval_module, reset_jumps = true)
83+
cbs = process_events(
84+
sys; callback, eval_expression, eval_module, op, reset_jumps = true)
8485

8586
if rng !== nothing
8687
kwargs = (; kwargs..., rng)

src/problems/odeproblem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ end
7575
eval_module, expression, check_compatibility, kwargs...)
7676

7777
kwargs = process_kwargs(
78-
sys; expression, callback, eval_expression, eval_module, kwargs...)
78+
sys; expression, callback, eval_expression, eval_module, op, kwargs...)
7979

8080
ptype = getmetadata(sys, ProblemTypeCtx, StandardODEProblem())
8181
args = (; f, u0, tspan, p, ptype)

src/problems/sddeproblem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ end
6868
end
6969

7070
noise, noise_rate_prototype = calculate_noise_and_rate_prototype(sys, u0; sparsenoise)
71-
kwargs = process_kwargs(sys; callback, eval_expression, eval_module, kwargs...)
71+
kwargs = process_kwargs(sys; callback, eval_expression, eval_module, op, kwargs...)
7272

7373
if expression == Val{true}
7474
g = :(f.g)

src/problems/sdeproblem.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ end
7878

7979
noise, noise_rate_prototype = calculate_noise_and_rate_prototype(sys, u0; sparsenoise)
8080
kwargs = process_kwargs(sys; expression, callback, eval_expression, eval_module,
81-
kwargs...)
81+
op, kwargs...)
8282

8383
args = (; f, u0, tspan, p)
8484
kwargs = (; noise, noise_rate_prototype, kwargs...)

src/systems/callbacks.jl

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -798,16 +798,34 @@ function add_integrator_header(
798798
expr.body)
799799
end
800800

801+
function default_operating_point(affsys::AffectSystem)
802+
sys = system(affsys)
803+
804+
op = Dict(unknowns(sys) .=> 0.0)
805+
for p in parameters(sys)
806+
T = symtype(p)
807+
if T <: Number
808+
op[p] = false
809+
elseif T <: Array{<:Real} && is_sized_array_symbolic(p)
810+
op[p] = zeros(size(p))
811+
end
812+
end
813+
return op
814+
end
815+
801816
"""
802817
Compile an affect defined by a set of equations. Systems with algebraic equations will solve implicit discrete problems to obtain their next state. Systems without will generate functions that perform explicit updates.
803818
"""
804819
function compile_equational_affect(
805820
aff::Union{AffectSystem, Vector{Equation}}, sys; reset_jumps = false,
806-
eval_expression = false, eval_module = @__MODULE__, kwargs...)
821+
eval_expression = false, eval_module = @__MODULE__, op = nothing, kwargs...)
807822
if aff isa AbstractVector
808823
aff = make_affect(
809824
aff; iv = get_iv(sys), warn_no_algebraic = false)
810825
end
826+
if op === nothing
827+
op = default_operating_point(aff)
828+
end
811829
affsys = system(aff)
812830
ps_to_update = discretes(aff)
813831
dvs_to_update = setdiff(unknowns(aff), getfield.(observed(sys), :lhs))
@@ -872,10 +890,10 @@ function compile_equational_affect(
872890
p_getter = getsym(affsys, ps_to_update)
873891

874892
affprob = ImplicitDiscreteProblem(
875-
affsys, Pair[unknowns(affsys) .=> 0; parameters(affsys) .=> 0],
893+
affsys, op,
876894
(0, 0);
877895
build_initializeprob = false, check_length = false, eval_expression,
878-
eval_module, check_compatibility = false)
896+
eval_module, check_compatibility = false, kwargs...)
879897

880898
function implicit_affect!(integ)
881899
new_u0 = affu_getter(integ)

test/symbolic_events.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1348,3 +1348,30 @@ end
13481348
@test SciMLBase.successful_retcode(sol)
13491349
@test sol[inner.p][end] 1.0
13501350
end
1351+
1352+
mutable struct ParamTest
1353+
y::Any
1354+
end
1355+
1356+
@testset "callable parameter and symbolic affect" begin
1357+
(pt::ParamTest)(x) = pt.y - x
1358+
1359+
p1 = ParamTest(1)
1360+
tp1 = typeof(p1)
1361+
@parameters (p_1::tp1)(..) = p1
1362+
@parameters p2(t) = 1.0
1363+
@variables x(t) = 0.0
1364+
@variables x2(t)
1365+
event = [0.5] => [p2 ~ Pre(t)]
1366+
1367+
eq = [
1368+
D(x) ~ p2,
1369+
x2 ~ p_1(x)
1370+
]
1371+
@mtkcompile sys = ODESystem(eq, t, [x, x2], [p_1, p2], discrete_events = [event])
1372+
1373+
prob = ODEProblem(sys, [], (0.0, 1.0))
1374+
sol = solve(prob)
1375+
@test SciMLBase.successful_retcode(sol)
1376+
@test sol[x, end]1.0 atol=1e-6
1377+
end

0 commit comments

Comments
 (0)