Skip to content

feat: rewrite clock inference to support polyadic synchronous operators #3808

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 9, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion src/clock.jl
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
@data InferredClock begin
Inferred
InferredDiscrete
InferredDiscrete(Int)
end

const InferredTimeDomain = InferredClock.Type
using .InferredClock: Inferred, InferredDiscrete

function InferredClock.InferredDiscrete()
return InferredDiscrete(0)
end

Base.Broadcast.broadcastable(x::InferredTimeDomain) = Ref(x)

struct VariableTimeDomain end
Expand Down
21 changes: 16 additions & 5 deletions src/discretedomain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,15 @@ are not transparent but `Sample` and `Hold` are. Defaults to `false` if not impl
is_transparent_operator(x) = is_transparent_operator(typeof(x))
is_transparent_operator(::Type) = false

"""
$(TYPEDSIGNATURES)

Trait to be implemented for operators which determines whether they are synchronous operators.
Synchronous operators must implement `input_timedomain` and `output_timedomain`.
"""
is_synchronous_operator(x) = is_synchronous_operator(typeof(x))
is_synchronous_operator(::Type) = false

"""
function SampleTime()

Expand Down Expand Up @@ -52,6 +61,7 @@ struct Shift <: Operator
end
Shift(steps::Int) = new(nothing, steps)
normalize_to_differential(s::Shift) = Differential(s.t)^s.steps
is_synchronous_operator(::Type{Shift}) = true
Base.nameof(::Shift) = :Shift
SymbolicUtils.isbinop(::Shift) = false

Expand Down Expand Up @@ -138,6 +148,7 @@ struct Sample <: Operator
Sample(clock::Union{TimeDomain, InferredTimeDomain} = InferredDiscrete()) = new(clock)
end

is_synchronous_operator(::Type{Sample}) = true
is_transparent_operator(::Type{Sample}) = true

function Sample(arg::Real)
Expand Down Expand Up @@ -193,6 +204,7 @@ struct Hold <: Operator
end

is_transparent_operator(::Type{Hold}) = true
is_synchronous_operator(::Type{Hold}) = true

(D::Hold)(x) = Term{symtype(x)}(D, Any[x])
(D::Hold)(x::Num) = Num(D(value(x)))
Expand Down Expand Up @@ -314,12 +326,13 @@ Base.:-(k::ShiftIndex, i::Int) = k + (-i)
input_timedomain(op::Operator)

Return the time-domain type (`ContinuousClock()` or `InferredDiscrete()`) that `op` operates on.
Should return a tuple containing the time domain type for each argument to the operator.
"""
function input_timedomain(s::Shift, arg = nothing)
if has_time_domain(arg)
return get_time_domain(arg)
end
InferredDiscrete()
(InferredDiscrete(),)
end

"""
Expand All @@ -334,22 +347,20 @@ function output_timedomain(s::Shift, arg = nothing)
InferredDiscrete()
end

input_timedomain(::Sample, _ = nothing) = ContinuousClock()
input_timedomain(::Sample, _ = nothing) = (ContinuousClock(),)
output_timedomain(s::Sample, _ = nothing) = s.clock

function input_timedomain(h::Hold, arg = nothing)
if has_time_domain(arg)
return get_time_domain(arg)
end
InferredDiscrete() # the Hold accepts any discrete
(InferredDiscrete(),) # the Hold accepts any discrete
end
output_timedomain(::Hold, _ = nothing) = ContinuousClock()

sampletime(op::Sample, _ = nothing) = sampletime(op.clock)
sampletime(op::ShiftIndex, _ = nothing) = sampletime(op.clock)

changes_domain(op) = isoperator(op, Union{Sample, Hold})

function output_timedomain(x)
if isoperator(x, Operator)
return output_timedomain(operation(x), arguments(x)[])
Expand Down
187 changes: 154 additions & 33 deletions src/systems/clock_inference.jl
Original file line number Diff line number Diff line change
@@ -1,10 +1,17 @@
@data ClockVertex begin
Variable(Int)
Equation(Int)
Clock(SciMLBase.AbstractClock)
end

struct ClockInference{S}
"""Tearing state."""
ts::S
"""The time domain (discrete clock, continuous) of each equation."""
eq_domain::Vector{TimeDomain}
"""The output time domain (discrete clock, continuous) of each variable."""
var_domain::Vector{TimeDomain}
inference_graph::HyperGraph{ClockVertex.Type}
"""The set of variables with concrete domains."""
inferred::BitSet
end
Expand All @@ -22,7 +29,21 @@ function ClockInference(ts::TransformationState)
var_domain[i] = d
end
end
ClockInference(ts, eq_domain, var_domain, inferred)
inference_graph = HyperGraph{ClockVertex.Type}()
for i in 1:nsrcs(graph)
add_vertex!(inference_graph, ClockVertex.Equation(i))
end
for i in 1:ndsts(graph)
varvert = ClockVertex.Variable(i)
add_vertex!(inference_graph, varvert)
v = ts.fullvars[i]
d = get_time_domain(v)
is_concrete_time_domain(d) || continue
dvert = ClockVertex.Clock(d)
add_vertex!(inference_graph, dvert)
add_edge!(inference_graph, (varvert, dvert))
end
ClockInference(ts, eq_domain, var_domain, inference_graph, inferred)
end

struct NotInferredTimeDomain end
Expand Down Expand Up @@ -75,47 +96,147 @@ end
Update the equation-to-time domain mapping by inferring the time domain from the variables.
"""
function infer_clocks!(ci::ClockInference)
@unpack ts, eq_domain, var_domain, inferred = ci
@unpack ts, eq_domain, var_domain, inferred, inference_graph = ci
@unpack var_to_diff, graph = ts.structure
fullvars = get_fullvars(ts)
isempty(inferred) && return ci
# TODO: add a graph type to do this lazily
var_graph = SimpleGraph(ndsts(graph))
for eq in 𝑠vertices(graph)
vvs = 𝑠neighbors(graph, eq)
if !isempty(vvs)
fv, vs = Iterators.peel(vvs)
for v in vs
add_edge!(var_graph, fv, v)
end
end

var_to_idx = Dict(fullvars .=> eachindex(fullvars))

# all shifted variables have the same clock as the unshifted variant
for (i, v) in enumerate(fullvars)
iscall(v) || continue
operation(v) isa Shift || continue
unshifted = only(arguments(v))
add_edge!(inference_graph, (ClockVertex.Variable(i), ClockVertex.Variable(var_to_idx[unshifted])))
end
for v in vertices(var_to_diff)
if (v′ = var_to_diff[v]) !== nothing
add_edge!(var_graph, v, v′)

# preallocated buffers:
# variables in each equation
varsbuf = Set()
# variables in each argument to an operator
arg_varsbuf = Set()
# hyperedge for each equation
hyperedge = Set{ClockVertex.Type}()
# hyperedge for each argument to an operator
arg_hyperedge = Set{ClockVertex.Type}()
# mapping from `i` in `InferredDiscrete(i)` to the vertices in that inferred partition
relative_hyperedges = Dict{Int, Set{ClockVertex.Type}}()

for (ieq, eq) in enumerate(equations(ts))
empty!(varsbuf)
empty!(hyperedge)
# get variables in equation
vars!(varsbuf, eq; op = Symbolics.Operator)
# add the equation to the hyperedge
push!(hyperedge, ClockVertex.Equation(ieq))
for var in varsbuf
idx = get(var_to_idx, var, nothing)
# if this is just a single variable, add it to the hyperedge
if idx isa Int
push!(hyperedge, ClockVertex.Variable(idx))
# we don't immediately `continue` here because this variable might be a
# `Sample` or similar and we want the clock information from it if it is.
end
# now we only care about synchronous operators
iscall(var) || continue
op = operation(var)
is_synchronous_operator(op) || continue

# arguments and corresponding time domains
args = arguments(var)
tdomains = input_timedomain(op)
nargs = length(args)
ndoms = length(tdomains)
if nargs != ndoms
throw(ArgumentError("""
Operator $op applied to $nargs arguments $args but only returns $ndoms \
domains $tdomains from `input_timedomain`.
"""))
end

# each relative clock mapping is only valid per operator application
empty!(relative_hyperedges)
for (arg, domain) in zip(args, tdomains)
empty!(arg_varsbuf)
empty!(arg_hyperedge)
# get variables in argument
vars!(arg_varsbuf, arg; op = Union{Differential, Shift})
# get hyperedge for involved variables
for v in arg_varsbuf
vidx = get(var_to_idx, v, nothing)
vidx === nothing && continue
push!(arg_hyperedge, ClockVertex.Variable(vidx))
end

Moshi.Match.@match domain begin
# If the time domain for this argument is a clock, then all variables in this edge have that clock.
x::SciMLBase.AbstractClock => begin
# add the clock to the edge
push!(arg_hyperedge, ClockVertex.Clock(x))
# add the edge to the graph
add_edge!(inference_graph, arg_hyperedge)
end
# We only know that this time domain is inferred. Treat it as a unique domain, all we know is that the
# involved variables have the same clock.
InferredClock.Inferred() => add_edge!(inference_graph, arg_hyperedge)
# All `InferredDiscrete` with the same `i` have the same clock (including output domain) so we don't
# add the edge, and instead add this to the `relative_hyperedges` mapping.
InferredClock.InferredDiscrete(i) => begin
relative_edge = get!(() -> Set{ClockVertex.Type}(), relative_hyperedges, i)
union!(relative_edge, arg_hyperedge)
end
end
end

outdomain = output_timedomain(op)
Moshi.Match.@match outdomain begin
x::SciMLBase.AbstractClock => begin
push!(hyperedge, ClockVertex.Clock(x))
end
InferredClock.Inferred() => nothing
InferredClock.InferredDiscrete(i) => begin
buffer = get(relative_hyperedges, i, nothing)
if buffer !== nothing
union!(hyperedge, buffer)
delete!(relative_hyperedges, i)
end
end
end

for (_, relative_edge) in relative_hyperedges
add_edge!(inference_graph, relative_edge)
end
end

add_edge!(inference_graph, hyperedge)
end
cc = connected_components(var_graph)
for c′ in cc
c = BitSet(c′)
idxs = intersect(c, inferred)
isempty(idxs) && continue
if !allequal(iscontinuous(var_domain[i]) for i in idxs)
display(fullvars[c′])
throw(ClockInferenceException("Clocks are not consistent in connected component $(fullvars[c′])"))

clock_partitions = connectionsets(inference_graph)
for partition in clock_partitions
clockidxs = findall(vert -> Moshi.Data.isa_variant(vert, ClockVertex.Clock), partition)
if isempty(clockidxs)
vidxs = Int[vert.:1 for vert in partition if Moshi.Data.isa_variant(vert, ClockVertex.Variable)]
throw(ArgumentError("""
Found clock partion with no associated clock. Involved variables: $(fullvars[vidxs]).
"""))
end
vd = var_domain[first(idxs)]
for v in c′
var_domain[v] = vd
if length(clockidxs) > 1
vidxs = Int[vert.:1 for vert in partition if Moshi.Data.isa_variant(vert, ClockVertex.Variable)]
clks = [vert.:1 for vert in view(partition, clockidxs)]
throw(ArgumentError("""
Found clock partition with multiple associated clocks. Involved variables: \
$(fullvars[vidxs]). Involved clocks: $(clks).
"""))
end
end

for v in 𝑑vertices(graph)
vd = var_domain[v]
eqs = 𝑑neighbors(graph, v)
isempty(eqs) && continue
for eq in eqs
eq_domain[eq] = vd
clock = partition[only(clockidxs)].:1
for vert in partition
Moshi.Match.@match vert begin
ClockVertex.Variable(i) => (var_domain[i] = clock)
ClockVertex.Equation(i) => (eq_domain[i] = clock)
ClockVertex.Clock(_) => nothing
end
end
end

Expand Down
Loading
Loading