Skip to content

Commit feebc16

Browse files
authored
Merge pull request #3808 from SciML/as/clock-inference
feat: rewrite clock inference to support polyadic synchronous operators
2 parents cd98296 + c4cc348 commit feebc16

File tree

4 files changed

+189
-50
lines changed

4 files changed

+189
-50
lines changed

src/clock.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
@data InferredClock begin
22
Inferred
3-
InferredDiscrete
3+
InferredDiscrete(Int)
44
end
55

66
const InferredTimeDomain = InferredClock.Type
77
using .InferredClock: Inferred, InferredDiscrete
88

9+
function InferredClock.InferredDiscrete()
10+
return InferredDiscrete(0)
11+
end
12+
913
Base.Broadcast.broadcastable(x::InferredTimeDomain) = Ref(x)
1014

1115
struct VariableTimeDomain end

src/discretedomain.jl

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,15 @@ are not transparent but `Sample` and `Hold` are. Defaults to `false` if not impl
1010
is_transparent_operator(x) = is_transparent_operator(typeof(x))
1111
is_transparent_operator(::Type) = false
1212

13+
"""
14+
$(TYPEDSIGNATURES)
15+
16+
Trait to be implemented for operators which determines whether they are synchronous operators.
17+
Synchronous operators must implement `input_timedomain` and `output_timedomain`.
18+
"""
19+
is_synchronous_operator(x) = is_synchronous_operator(typeof(x))
20+
is_synchronous_operator(::Type) = false
21+
1322
"""
1423
function SampleTime()
1524
@@ -52,6 +61,7 @@ struct Shift <: Operator
5261
end
5362
Shift(steps::Int) = new(nothing, steps)
5463
normalize_to_differential(s::Shift) = Differential(s.t)^s.steps
64+
is_synchronous_operator(::Type{Shift}) = true
5565
Base.nameof(::Shift) = :Shift
5666
SymbolicUtils.isbinop(::Shift) = false
5767

@@ -138,6 +148,7 @@ struct Sample <: Operator
138148
Sample(clock::Union{TimeDomain, InferredTimeDomain} = InferredDiscrete()) = new(clock)
139149
end
140150

151+
is_synchronous_operator(::Type{Sample}) = true
141152
is_transparent_operator(::Type{Sample}) = true
142153

143154
function Sample(arg::Real)
@@ -193,6 +204,7 @@ struct Hold <: Operator
193204
end
194205

195206
is_transparent_operator(::Type{Hold}) = true
207+
is_synchronous_operator(::Type{Hold}) = true
196208

197209
(D::Hold)(x) = Term{symtype(x)}(D, Any[x])
198210
(D::Hold)(x::Num) = Num(D(value(x)))
@@ -314,12 +326,13 @@ Base.:-(k::ShiftIndex, i::Int) = k + (-i)
314326
input_timedomain(op::Operator)
315327
316328
Return the time-domain type (`ContinuousClock()` or `InferredDiscrete()`) that `op` operates on.
329+
Should return a tuple containing the time domain type for each argument to the operator.
317330
"""
318331
function input_timedomain(s::Shift, arg = nothing)
319332
if has_time_domain(arg)
320333
return get_time_domain(arg)
321334
end
322-
InferredDiscrete()
335+
(InferredDiscrete(),)
323336
end
324337

325338
"""
@@ -334,22 +347,20 @@ function output_timedomain(s::Shift, arg = nothing)
334347
InferredDiscrete()
335348
end
336349

337-
input_timedomain(::Sample, _ = nothing) = ContinuousClock()
350+
input_timedomain(::Sample, _ = nothing) = (ContinuousClock(),)
338351
output_timedomain(s::Sample, _ = nothing) = s.clock
339352

340353
function input_timedomain(h::Hold, arg = nothing)
341354
if has_time_domain(arg)
342355
return get_time_domain(arg)
343356
end
344-
InferredDiscrete() # the Hold accepts any discrete
357+
(InferredDiscrete(),) # the Hold accepts any discrete
345358
end
346359
output_timedomain(::Hold, _ = nothing) = ContinuousClock()
347360

348361
sampletime(op::Sample, _ = nothing) = sampletime(op.clock)
349362
sampletime(op::ShiftIndex, _ = nothing) = sampletime(op.clock)
350363

351-
changes_domain(op) = isoperator(op, Union{Sample, Hold})
352-
353364
function output_timedomain(x)
354365
if isoperator(x, Operator)
355366
return output_timedomain(operation(x), arguments(x)[])

src/systems/clock_inference.jl

Lines changed: 154 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,17 @@
1+
@data ClockVertex begin
2+
Variable(Int)
3+
Equation(Int)
4+
Clock(SciMLBase.AbstractClock)
5+
end
6+
17
struct ClockInference{S}
28
"""Tearing state."""
39
ts::S
410
"""The time domain (discrete clock, continuous) of each equation."""
511
eq_domain::Vector{TimeDomain}
612
"""The output time domain (discrete clock, continuous) of each variable."""
713
var_domain::Vector{TimeDomain}
14+
inference_graph::HyperGraph{ClockVertex.Type}
815
"""The set of variables with concrete domains."""
916
inferred::BitSet
1017
end
@@ -22,7 +29,21 @@ function ClockInference(ts::TransformationState)
2229
var_domain[i] = d
2330
end
2431
end
25-
ClockInference(ts, eq_domain, var_domain, inferred)
32+
inference_graph = HyperGraph{ClockVertex.Type}()
33+
for i in 1:nsrcs(graph)
34+
add_vertex!(inference_graph, ClockVertex.Equation(i))
35+
end
36+
for i in 1:ndsts(graph)
37+
varvert = ClockVertex.Variable(i)
38+
add_vertex!(inference_graph, varvert)
39+
v = ts.fullvars[i]
40+
d = get_time_domain(v)
41+
is_concrete_time_domain(d) || continue
42+
dvert = ClockVertex.Clock(d)
43+
add_vertex!(inference_graph, dvert)
44+
add_edge!(inference_graph, (varvert, dvert))
45+
end
46+
ClockInference(ts, eq_domain, var_domain, inference_graph, inferred)
2647
end
2748

2849
struct NotInferredTimeDomain end
@@ -75,47 +96,147 @@ end
7596
Update the equation-to-time domain mapping by inferring the time domain from the variables.
7697
"""
7798
function infer_clocks!(ci::ClockInference)
78-
@unpack ts, eq_domain, var_domain, inferred = ci
99+
@unpack ts, eq_domain, var_domain, inferred, inference_graph = ci
79100
@unpack var_to_diff, graph = ts.structure
80101
fullvars = get_fullvars(ts)
81102
isempty(inferred) && return ci
82-
# TODO: add a graph type to do this lazily
83-
var_graph = SimpleGraph(ndsts(graph))
84-
for eq in 𝑠vertices(graph)
85-
vvs = 𝑠neighbors(graph, eq)
86-
if !isempty(vvs)
87-
fv, vs = Iterators.peel(vvs)
88-
for v in vs
89-
add_edge!(var_graph, fv, v)
90-
end
91-
end
103+
104+
var_to_idx = Dict(fullvars .=> eachindex(fullvars))
105+
106+
# all shifted variables have the same clock as the unshifted variant
107+
for (i, v) in enumerate(fullvars)
108+
iscall(v) || continue
109+
operation(v) isa Shift || continue
110+
unshifted = only(arguments(v))
111+
add_edge!(inference_graph, (ClockVertex.Variable(i), ClockVertex.Variable(var_to_idx[unshifted])))
92112
end
93-
for v in vertices(var_to_diff)
94-
if (v′ = var_to_diff[v]) !== nothing
95-
add_edge!(var_graph, v, v′)
113+
114+
# preallocated buffers:
115+
# variables in each equation
116+
varsbuf = Set()
117+
# variables in each argument to an operator
118+
arg_varsbuf = Set()
119+
# hyperedge for each equation
120+
hyperedge = Set{ClockVertex.Type}()
121+
# hyperedge for each argument to an operator
122+
arg_hyperedge = Set{ClockVertex.Type}()
123+
# mapping from `i` in `InferredDiscrete(i)` to the vertices in that inferred partition
124+
relative_hyperedges = Dict{Int, Set{ClockVertex.Type}}()
125+
126+
for (ieq, eq) in enumerate(equations(ts))
127+
empty!(varsbuf)
128+
empty!(hyperedge)
129+
# get variables in equation
130+
vars!(varsbuf, eq; op = Symbolics.Operator)
131+
# add the equation to the hyperedge
132+
push!(hyperedge, ClockVertex.Equation(ieq))
133+
for var in varsbuf
134+
idx = get(var_to_idx, var, nothing)
135+
# if this is just a single variable, add it to the hyperedge
136+
if idx isa Int
137+
push!(hyperedge, ClockVertex.Variable(idx))
138+
# we don't immediately `continue` here because this variable might be a
139+
# `Sample` or similar and we want the clock information from it if it is.
140+
end
141+
# now we only care about synchronous operators
142+
iscall(var) || continue
143+
op = operation(var)
144+
is_synchronous_operator(op) || continue
145+
146+
# arguments and corresponding time domains
147+
args = arguments(var)
148+
tdomains = input_timedomain(op)
149+
nargs = length(args)
150+
ndoms = length(tdomains)
151+
if nargs != ndoms
152+
throw(ArgumentError("""
153+
Operator $op applied to $nargs arguments $args but only returns $ndoms \
154+
domains $tdomains from `input_timedomain`.
155+
"""))
156+
end
157+
158+
# each relative clock mapping is only valid per operator application
159+
empty!(relative_hyperedges)
160+
for (arg, domain) in zip(args, tdomains)
161+
empty!(arg_varsbuf)
162+
empty!(arg_hyperedge)
163+
# get variables in argument
164+
vars!(arg_varsbuf, arg; op = Union{Differential, Shift})
165+
# get hyperedge for involved variables
166+
for v in arg_varsbuf
167+
vidx = get(var_to_idx, v, nothing)
168+
vidx === nothing && continue
169+
push!(arg_hyperedge, ClockVertex.Variable(vidx))
170+
end
171+
172+
Moshi.Match.@match domain begin
173+
# If the time domain for this argument is a clock, then all variables in this edge have that clock.
174+
x::SciMLBase.AbstractClock => begin
175+
# add the clock to the edge
176+
push!(arg_hyperedge, ClockVertex.Clock(x))
177+
# add the edge to the graph
178+
add_edge!(inference_graph, arg_hyperedge)
179+
end
180+
# We only know that this time domain is inferred. Treat it as a unique domain, all we know is that the
181+
# involved variables have the same clock.
182+
InferredClock.Inferred() => add_edge!(inference_graph, arg_hyperedge)
183+
# All `InferredDiscrete` with the same `i` have the same clock (including output domain) so we don't
184+
# add the edge, and instead add this to the `relative_hyperedges` mapping.
185+
InferredClock.InferredDiscrete(i) => begin
186+
relative_edge = get!(() -> Set{ClockVertex.Type}(), relative_hyperedges, i)
187+
union!(relative_edge, arg_hyperedge)
188+
end
189+
end
190+
end
191+
192+
outdomain = output_timedomain(op)
193+
Moshi.Match.@match outdomain begin
194+
x::SciMLBase.AbstractClock => begin
195+
push!(hyperedge, ClockVertex.Clock(x))
196+
end
197+
InferredClock.Inferred() => nothing
198+
InferredClock.InferredDiscrete(i) => begin
199+
buffer = get(relative_hyperedges, i, nothing)
200+
if buffer !== nothing
201+
union!(hyperedge, buffer)
202+
delete!(relative_hyperedges, i)
203+
end
204+
end
205+
end
206+
207+
for (_, relative_edge) in relative_hyperedges
208+
add_edge!(inference_graph, relative_edge)
209+
end
96210
end
211+
212+
add_edge!(inference_graph, hyperedge)
97213
end
98-
cc = connected_components(var_graph)
99-
for c′ in cc
100-
c = BitSet(c′)
101-
idxs = intersect(c, inferred)
102-
isempty(idxs) && continue
103-
if !allequal(iscontinuous(var_domain[i]) for i in idxs)
104-
display(fullvars[c′])
105-
throw(ClockInferenceException("Clocks are not consistent in connected component $(fullvars[c′])"))
214+
215+
clock_partitions = connectionsets(inference_graph)
216+
for partition in clock_partitions
217+
clockidxs = findall(vert -> Moshi.Data.isa_variant(vert, ClockVertex.Clock), partition)
218+
if isempty(clockidxs)
219+
vidxs = Int[vert.:1 for vert in partition if Moshi.Data.isa_variant(vert, ClockVertex.Variable)]
220+
throw(ArgumentError("""
221+
Found clock partion with no associated clock. Involved variables: $(fullvars[vidxs]).
222+
"""))
106223
end
107-
vd = var_domain[first(idxs)]
108-
for v in c′
109-
var_domain[v] = vd
224+
if length(clockidxs) > 1
225+
vidxs = Int[vert.:1 for vert in partition if Moshi.Data.isa_variant(vert, ClockVertex.Variable)]
226+
clks = [vert.:1 for vert in view(partition, clockidxs)]
227+
throw(ArgumentError("""
228+
Found clock partition with multiple associated clocks. Involved variables: \
229+
$(fullvars[vidxs]). Involved clocks: $(clks).
230+
"""))
110231
end
111-
end
112232

113-
for v in 𝑑vertices(graph)
114-
vd = var_domain[v]
115-
eqs = 𝑑neighbors(graph, v)
116-
isempty(eqs) && continue
117-
for eq in eqs
118-
eq_domain[eq] = vd
233+
clock = partition[only(clockidxs)].:1
234+
for vert in partition
235+
Moshi.Match.@match vert begin
236+
ClockVertex.Variable(i) => (var_domain[i] = clock)
237+
ClockVertex.Equation(i) => (eq_domain[i] = clock)
238+
ClockVertex.Clock(_) => nothing
239+
end
119240
end
120241
end
121242

0 commit comments

Comments
 (0)