Skip to content

Add support for an external synchronous compiler to discrete and hybrid systems #3399

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 7 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1207,6 +1207,7 @@ function namespace_expr(
O
end
end

_nonum(@nospecialize x) = x isa Num ? x.val : x

"""
Expand Down
11 changes: 10 additions & 1 deletion src/systems/clock_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ function infer_clocks!(ci::ClockInference)
c = BitSet(c′)
idxs = intersect(c, inferred)
isempty(idxs) && continue
if !allequal(var_domain[i] for i in idxs)
if !allequal(iscontinuous(var_domain[i]) for i in idxs)
display(fullvars[c′])
throw(ClockInferenceException("Clocks are not consistent in connected component $(fullvars[c′])"))
end
Expand Down Expand Up @@ -155,6 +155,9 @@ function split_system(ci::ClockInference{S}) where {S}
cid_to_var = Vector{Int}[]
# cid_counter = number of clocks
cid_counter = Ref(0)

# populates clock_to_id and id_to_clock
# checks if there is a continuous_id (for some reason? clock to id does this too)
for (i, d) in enumerate(eq_domain)
cid = let cid_counter = cid_counter, id_to_clock = id_to_clock,
continuous_id = continuous_id
Expand All @@ -174,9 +177,13 @@ function split_system(ci::ClockInference{S}) where {S}
resize_or_push!(cid_to_eq, i, cid)
end
continuous_id = continuous_id[]
# for each clock partition what are the input (indexes/vars)
input_idxs = map(_ -> Int[], 1:cid_counter[])
inputs = map(_ -> Any[], 1:cid_counter[])
# var_domain corresponds to fullvars/all variables in the system
nvv = length(var_domain)
# put variables into the right clock partition
# keep track of inputs to each partition
for i in 1:nvv
d = var_domain[i]
cid = get(clock_to_id, d, 0)
Expand All @@ -190,6 +197,7 @@ function split_system(ci::ClockInference{S}) where {S}
resize_or_push!(cid_to_var, i, cid)
end

# breaks the system up into a continous and 0 or more discrete systems
tss = similar(cid_to_eq, S)
for (id, ieqs) in enumerate(cid_to_eq)
ts_i = system_subset(ts, ieqs)
Expand All @@ -199,6 +207,7 @@ function split_system(ci::ClockInference{S}) where {S}
end
tss[id] = ts_i
end
# put the continous system at the back
if continuous_id != 0
tss[continuous_id], tss[end] = tss[end], tss[continuous_id]
inputs[continuous_id], inputs[end] = inputs[end], inputs[continuous_id]
Expand Down
4 changes: 3 additions & 1 deletion src/systems/imperative_affect.jl
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,9 @@ function compile_functional_affect(
upd_vals = user_affect(upd_component_array, obs_component_array, ctx, integ)

# write the new values back to the integrator
_generated_writeback(integ, upd_funs, upd_vals)
if !isnothing(upd_vals)
_generated_writeback(integ, upd_funs, upd_vals)
end

reset_jumps && reset_aggregated_jumps!(integ)
end
Expand Down
33 changes: 33 additions & 0 deletions src/systems/state_machines.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,3 +153,36 @@ entry

When used in a finite state machine, this operator returns `true` if the queried state is active and false otherwise.
""" activeState

function vars!(vars, O::Transition; op = Differential)
vars!(vars, O.from)
vars!(vars, O.to)
vars!(vars, O.cond; op)
return vars
end
function vars!(vars, O::InitialState; op = Differential)
vars!(vars, O.s; op)
return vars
end
function vars!(vars, O::StateMachineOperator; op = Differential)
error("Unhandled state machine operator")
end

function namespace_expr(
O::Transition, sys, n = nameof(sys); ivs = independent_variables(sys))
return Transition(
O.from === nothing ? O.from : renamespace(sys, O.from),
O.to === nothing ? O.to : renamespace(sys, O.to),
O.cond === nothing ? O.cond : namespace_expr(O.cond, sys),
O.immediate, O.reset, O.synchronize, O.priority
)
end

function namespace_expr(
O::InitialState, sys, n = nameof(sys); ivs = independent_variables(sys))
return InitialState(O.s === nothing ? O.s : renamespace(sys, O.s))
end

function namespace_expr(O::StateMachineOperator, sys, n = nameof(sys); kwargs...)
error("Unhandled state machine operator")
end
15 changes: 11 additions & 4 deletions src/systems/systems.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ function mtkcompile(
isscheduled(sys) && throw(RepeatedStructuralSimplificationError())
newsys′ = __mtkcompile(sys; simplify,
allow_symbolic, allow_parameter, conservative, fully_determined,
inputs, outputs, disturbance_inputs,
inputs, outputs, disturbance_inputs, additional_passes,
kwargs...)
if newsys′ isa Tuple
@assert length(newsys′) == 2
Expand Down Expand Up @@ -75,12 +75,13 @@ function __mtkcompile(sys::AbstractSystem; simplify = false,
return simplify_optimization_system(sys; kwargs..., sort_eqs, simplify)
end

sys, statemachines = extract_top_level_statemachines(sys)
sys = expand_connections(sys)
state = TearingState(sys; sort_eqs)
state = TearingState(sys)
append!(state.statemachines, statemachines)

@unpack structure, fullvars = state
@unpack graph, var_to_diff, var_types = structure
eqs = equations(state)
brown_vars = Int[]
new_idxs = zeros(Int, length(var_types))
idx = 0
Expand All @@ -98,7 +99,8 @@ function __mtkcompile(sys::AbstractSystem; simplify = false,
Is = Int[]
Js = Int[]
vals = Num[]
new_eqs = copy(eqs)
make_eqs_zero_equals!(state)
new_eqs = copy(equations(state))
dvar2eq = Dict{Any, Int}()
for (v, dv) in enumerate(var_to_diff)
dv === nothing && continue
Expand Down Expand Up @@ -291,3 +293,8 @@ function map_variables_to_equations(sys::AbstractSystem; rename_dummy_derivative

return mapping
end

"""
Mark whether an extra pass `p` can support compiling discrete systems.
"""
discrete_compile_pass(p) = false
154 changes: 134 additions & 20 deletions src/systems/systemstructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -203,18 +203,37 @@ end
mutable struct TearingState{T <: AbstractSystem} <: AbstractTearingState{T}
"""The system of equations."""
sys::T
original_eqs::Vector{Equation}
"""The set of variables of the system."""
fullvars::Vector{BasicSymbolic}
structure::SystemStructure
extra_eqs::Vector
param_derivative_map::Dict{BasicSymbolic, Any}
statemachines::Vector{T}
end

TransformationState(sys::AbstractSystem) = TearingState(sys)
function system_subset(ts::TearingState, ieqs::Vector{Int})
eqs = equations(ts)
@set! ts.original_eqs = ts.original_eqs[ieqs]
@set! ts.sys.eqs = eqs[ieqs]
@set! ts.structure = system_subset(ts.structure, ieqs)
if all(eq -> eq.rhs isa StateMachineOperator, get_eqs(ts.sys))
names = Symbol[]
for eq in get_eqs(ts.sys)
if eq.lhs isa Transition
push!(names, first(namespace_hierarchy(nameof(eq.rhs.from))))
push!(names, first(namespace_hierarchy(nameof(eq.rhs.to))))
elseif eq.lhs isa InitialState
push!(names, first(namespace_hierarchy(nameof(eq.rhs.s))))
else
error("Unhandled state machine operator")
end
end
@set! ts.statemachines = filter(x -> nameof(x) in names, ts.statemachines)
else
@set! ts.statemachines = eltype(ts.statemachines)[]
end
ts
end

Expand Down Expand Up @@ -268,14 +287,58 @@ function symbolic_contains(var, set)
all(x -> x in set, Symbolics.scalarize(var))
end

"""
$(TYPEDSIGNATURES)

Descend through the system hierarchy and look for statemachines. Remove equations from
the inner statemachine systems. Return the new `sys` and an array of top-level
statemachines.
"""
function extract_top_level_statemachines(sys::AbstractSystem)
eqs = get_eqs(sys)

if !isempty(eqs) && all(eq -> eq.lhs isa StateMachineOperator, eqs)
# top-level statemachine
with_removed = @set sys.systems = map(remove_child_equations, get_systems(sys))
return with_removed, [sys]
elseif !isempty(eqs) && any(eq -> eq.lhs isa StateMachineOperator, eqs)
# error: can't mix
error("Mixing statemachine equations and standard equations in a top-level statemachine is not allowed.")
else
# descend
subsystems = get_systems(sys)
newsubsystems = eltype(subsystems)[]
statemachines = eltype(subsystems)[]
for subsys in subsystems
newsubsys, sub_statemachines = extract_top_level_statemachines(subsys)
push!(newsubsystems, newsubsys)
append!(statemachines, sub_statemachines)
end
@set! sys.systems = newsubsystems
return sys, statemachines
end
end

"""
$(TYPEDSIGNATURES)

Return `sys` with all equations (including those in subsystems) removed.
"""
function remove_child_equations(sys::AbstractSystem)
@set! sys.eqs = eltype(get_eqs(sys))[]
@set! sys.systems = map(remove_child_equations, get_systems(sys))
return sys
end

function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
# flatten system
sys = flatten(sys)
sys = process_parameter_equations(sys)
ivs = independent_variables(sys)
iv = length(ivs) == 1 ? ivs[1] : nothing
# flatten array equations
eqs = flatten_equations(equations(sys))
# scalarize array equations, without scalarizing arguments to registered functions
original_eqs = flatten_equations(copy(equations(sys)))
eqs = copy(original_eqs)
neqs = length(eqs)
param_derivative_map = Dict{BasicSymbolic, Any}()
# * Scalarize unknowns
Expand Down Expand Up @@ -331,9 +394,16 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
# change the equation if the RHS is `missing` so the rest of this loop works
eq = 0.0 ~ coalesce(eq.rhs, 0.0)
end
rhs = quick_cancel ? quick_cancel_expr(eq.rhs) : eq.rhs
if !_iszero(eq.lhs)
is_statemachine_equation = false
if eq.lhs isa StateMachineOperator
is_statemachine_equation = true
eq = eq
rhs = eq.rhs
elseif _iszero(eq.lhs)
rhs = quick_cancel ? quick_cancel_expr(eq.rhs) : eq.rhs
else
lhs = quick_cancel ? quick_cancel_expr(eq.lhs) : eq.lhs
rhs = quick_cancel ? quick_cancel_expr(eq.rhs) : eq.rhs
eq = 0 ~ rhs - lhs
end
empty!(varsbuf)
Expand Down Expand Up @@ -397,8 +467,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
addvar!(v, VARIABLE)
end
end

if isalgeq
if isalgeq || is_statemachine_equation
eqs[i] = eq
else
eqs[i] = eqs[i].lhs ~ rhs
Expand All @@ -415,6 +484,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
end
end
eqs = eqs[eqs_to_retain]
original_eqs = original_eqs[eqs_to_retain]
neqs = length(eqs)
symbolic_incidence = symbolic_incidence[eqs_to_retain]

Expand All @@ -423,6 +493,7 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)
# depending on order due to NP-completeness of tearing.
sortidxs = Base.sortperm(eqs, by = string)
eqs = eqs[sortidxs]
original_eqs = original_eqs[sortidxs]
symbolic_incidence = symbolic_incidence[sortidxs]
end

Expand Down Expand Up @@ -513,11 +584,10 @@ function TearingState(sys; quick_cancel = false, check = true, sort_eqs = true)

eq_to_diff = DiffGraph(nsrcs(graph))

ts = TearingState(sys, fullvars,
ts = TearingState(sys, original_eqs, fullvars,
SystemStructure(complete(var_to_diff), complete(eq_to_diff),
complete(graph), nothing, var_types, false),
Any[], param_derivative_map)

Any[], param_derivative_map, typeof(sys)[])
return ts
end

Expand Down Expand Up @@ -696,29 +766,73 @@ function Base.show(io::IO, mime::MIME"text/plain", ms::MatchedSystemStructure)
printstyled(io, " SelectedState")
end

function make_eqs_zero_equals!(ts::TearingState)
neweqs = map(enumerate(get_eqs(ts.sys))) do kvp
i, eq = kvp
isalgeq = true
for j in 𝑠neighbors(ts.structure.graph, i)
isalgeq &= invview(ts.structure.var_to_diff)[j] === nothing
end
if isalgeq
return 0 ~ eq.rhs - eq.lhs
else
return eq
end
end
copyto!(get_eqs(ts.sys), neweqs)
end

function mtkcompile!(state::TearingState; simplify = false,
check_consistency = true, fully_determined = true, warn_initialize_determined = true,
inputs = Any[], outputs = Any[],
disturbance_inputs = Any[],
kwargs...)
# split_system returns one or two systems and the inputs for each
# mod clock inference to be binary
# if it's continous keep going, if not then error unless given trait impl in additional passes
ci = ModelingToolkit.ClockInference(state)
ci = ModelingToolkit.infer_clocks!(ci)
time_domains = merge(Dict(state.fullvars .=> ci.var_domain),
Dict(default_toterm.(state.fullvars) .=> ci.var_domain))
tss, clocked_inputs, continuous_id, id_to_clock = ModelingToolkit.split_system(ci)
if continuous_id == 0
# do a trait check here - handle fully discrete system
additional_passes = get(kwargs, :additional_passes, nothing)
if !isnothing(additional_passes) && any(discrete_compile_pass, additional_passes)
# take the first discrete compilation pass given for now
discrete_pass_idx = findfirst(discrete_compile_pass, additional_passes)
discrete_compile = additional_passes[discrete_pass_idx]
deleteat!(additional_passes, discrete_pass_idx)
return discrete_compile(tss, clocked_inputs, ci)
end
throw(HybridSystemNotSupportedException("""
Discrete systems with multiple clocks are not supported with the standard \
MTK compiler.
"""))
end
if length(tss) > 1
if continuous_id == 0
throw(HybridSystemNotSupportedException("""
Discrete systems with multiple clocks are not supported with the standard \
MTK compiler.
"""))
else
throw(HybridSystemNotSupportedException("""
Hybrid continuous-discrete systems are currently not supported with \
the standard MTK compiler. This system requires JuliaSimCompiler.jl, \
see https://help.juliahub.com/juliasimcompiler/stable/
"""))
make_eqs_zero_equals!(tss[continuous_id])
# simplify as normal
sys = _mtkcompile!(tss[continuous_id]; simplify,
inputs = [inputs; clocked_inputs[continuous_id]], outputs, disturbance_inputs,
check_consistency, fully_determined,
kwargs...)
additional_passes = get(kwargs, :additional_passes, nothing)
if !isnothing(additional_passes) && any(discrete_compile_pass, additional_passes)
discrete_pass_idx = findfirst(discrete_compile_pass, additional_passes)
discrete_compile = additional_passes[discrete_pass_idx]
deleteat!(additional_passes, discrete_pass_idx)
# in the case of a hybrid system, the discrete_compile pass should take the currents of sys.discrete_subsystems
# and modifies discrete_subsystems to bea tuple of the io and anything else, while adding or manipulating the rest of sys as needed
return discrete_compile(
sys, tss[[i for i in eachindex(tss) if i != continuous_id]],
clocked_inputs, ci, id_to_clock)
end
throw(HybridSystemNotSupportedException("""
Hybrid continuous-discrete systems are currently not supported with \
the standard MTK compiler. This system requires JuliaSimCompiler.jl, \
see https://help.juliahub.com/juliasimcompiler/stable/
"""))
end
if get_is_discrete(state.sys) ||
continuous_id == 1 && any(Base.Fix2(isoperator, Shift), state.fullvars)
Expand Down
Loading
Loading