1
+ @data ClockVertex begin
2
+ Variable (Int)
3
+ Equation (Int)
4
+ Clock (SciMLBase. AbstractClock)
5
+ end
6
+
1
7
struct ClockInference{S}
2
8
""" Tearing state."""
3
9
ts:: S
4
10
""" The time domain (discrete clock, continuous) of each equation."""
5
11
eq_domain:: Vector{TimeDomain}
6
12
""" The output time domain (discrete clock, continuous) of each variable."""
7
13
var_domain:: Vector{TimeDomain}
14
+ inference_graph:: HyperGraph{ClockVertex.Type}
8
15
""" The set of variables with concrete domains."""
9
16
inferred:: BitSet
10
17
end
@@ -22,7 +29,21 @@ function ClockInference(ts::TransformationState)
22
29
var_domain[i] = d
23
30
end
24
31
end
25
- ClockInference (ts, eq_domain, var_domain, inferred)
32
+ inference_graph = HyperGraph {ClockVertex.Type} ()
33
+ for i in 1 : nsrcs (graph)
34
+ add_vertex! (inference_graph, ClockVertex. Equation (i))
35
+ end
36
+ for i in 1 : ndsts (graph)
37
+ varvert = ClockVertex. Variable (i)
38
+ add_vertex! (inference_graph, varvert)
39
+ v = ts. fullvars[i]
40
+ d = get_time_domain (v)
41
+ is_concrete_time_domain (d) || continue
42
+ dvert = ClockVertex. Clock (d)
43
+ add_vertex! (inference_graph, dvert)
44
+ add_edge! (inference_graph, (varvert, dvert))
45
+ end
46
+ ClockInference (ts, eq_domain, var_domain, inference_graph, inferred)
26
47
end
27
48
28
49
struct NotInferredTimeDomain end
75
96
Update the equation-to-time domain mapping by inferring the time domain from the variables.
76
97
"""
77
98
function infer_clocks! (ci:: ClockInference )
78
- @unpack ts, eq_domain, var_domain, inferred = ci
99
+ @unpack ts, eq_domain, var_domain, inferred, inference_graph = ci
79
100
@unpack var_to_diff, graph = ts. structure
80
101
fullvars = get_fullvars (ts)
81
102
isempty (inferred) && return ci
82
- # TODO : add a graph type to do this lazily
83
- var_graph = SimpleGraph (ndsts (graph))
84
- for eq in 𝑠vertices (graph)
85
- vvs = 𝑠neighbors (graph, eq)
86
- if ! isempty (vvs)
87
- fv, vs = Iterators. peel (vvs)
88
- for v in vs
89
- add_edge! (var_graph, fv, v)
90
- end
91
- end
103
+
104
+ var_to_idx = Dict (fullvars .=> eachindex (fullvars))
105
+
106
+ # all shifted variables have the same clock as the unshifted variant
107
+ for (i, v) in enumerate (fullvars)
108
+ iscall (v) || continue
109
+ operation (v) isa Shift || continue
110
+ unshifted = only (arguments (v))
111
+ add_edge! (inference_graph, (ClockVertex. Variable (i), ClockVertex. Variable (var_to_idx[unshifted])))
92
112
end
93
- for v in vertices (var_to_diff)
94
- if (v′ = var_to_diff[v]) != = nothing
95
- add_edge! (var_graph, v, v′)
113
+
114
+ # preallocated buffers:
115
+ # variables in each equation
116
+ varsbuf = Set ()
117
+ # variables in each argument to an operator
118
+ arg_varsbuf = Set ()
119
+ # hyperedge for each equation
120
+ hyperedge = Set {ClockVertex.Type} ()
121
+ # hyperedge for each argument to an operator
122
+ arg_hyperedge = Set {ClockVertex.Type} ()
123
+ # mapping from `i` in `InferredDiscrete(i)` to the vertices in that inferred partition
124
+ relative_hyperedges = Dict {Int, Set{ClockVertex.Type}} ()
125
+
126
+ for (ieq, eq) in enumerate (equations (ts))
127
+ empty! (varsbuf)
128
+ empty! (hyperedge)
129
+ # get variables in equation
130
+ vars! (varsbuf, eq; op = Symbolics. Operator)
131
+ # add the equation to the hyperedge
132
+ push! (hyperedge, ClockVertex. Equation (ieq))
133
+ for var in varsbuf
134
+ idx = get (var_to_idx, var, nothing )
135
+ # if this is just a single variable, add it to the hyperedge
136
+ if idx isa Int
137
+ push! (hyperedge, ClockVertex. Variable (idx))
138
+ # we don't immediately `continue` here because this variable might be a
139
+ # `Sample` or similar and we want the clock information from it if it is.
140
+ end
141
+ # now we only care about synchronous operators
142
+ iscall (var) || continue
143
+ op = operation (var)
144
+ is_synchronous_operator (op) || continue
145
+
146
+ # arguments and corresponding time domains
147
+ args = arguments (var)
148
+ tdomains = input_timedomain (op)
149
+ nargs = length (args)
150
+ ndoms = length (tdomains)
151
+ if nargs != ndoms
152
+ throw (ArgumentError ("""
153
+ Operator $op applied to $nargs arguments $args but only returns $ndoms \
154
+ domains $tdomains from `input_timedomain`.
155
+ """ ))
156
+ end
157
+
158
+ # each relative clock mapping is only valid per operator application
159
+ empty! (relative_hyperedges)
160
+ for (arg, domain) in zip (args, tdomains)
161
+ empty! (arg_varsbuf)
162
+ empty! (arg_hyperedge)
163
+ # get variables in argument
164
+ vars! (arg_varsbuf, arg; op = Union{Differential, Shift})
165
+ # get hyperedge for involved variables
166
+ for v in arg_varsbuf
167
+ vidx = get (var_to_idx, v, nothing )
168
+ vidx === nothing && continue
169
+ push! (arg_hyperedge, ClockVertex. Variable (vidx))
170
+ end
171
+
172
+ Moshi. Match. @match domain begin
173
+ # If the time domain for this argument is a clock, then all variables in this edge have that clock.
174
+ x:: SciMLBase.AbstractClock => begin
175
+ # add the clock to the edge
176
+ push! (arg_hyperedge, ClockVertex. Clock (x))
177
+ # add the edge to the graph
178
+ add_edge! (inference_graph, arg_hyperedge)
179
+ end
180
+ # We only know that this time domain is inferred. Treat it as a unique domain, all we know is that the
181
+ # involved variables have the same clock.
182
+ InferredClock. Inferred () => add_edge! (inference_graph, arg_hyperedge)
183
+ # All `InferredDiscrete` with the same `i` have the same clock (including output domain) so we don't
184
+ # add the edge, and instead add this to the `relative_hyperedges` mapping.
185
+ InferredClock. InferredDiscrete (i) => begin
186
+ relative_edge = get! (() -> Set {ClockVertex.Type} (), relative_hyperedges, i)
187
+ union! (relative_edge, arg_hyperedge)
188
+ end
189
+ end
190
+ end
191
+
192
+ outdomain = output_timedomain (op)
193
+ Moshi. Match. @match outdomain begin
194
+ x:: SciMLBase.AbstractClock => begin
195
+ push! (hyperedge, ClockVertex. Clock (x))
196
+ end
197
+ InferredClock. Inferred () => nothing
198
+ InferredClock. InferredDiscrete (i) => begin
199
+ buffer = get (relative_hyperedges, i, nothing )
200
+ if buffer != = nothing
201
+ union! (hyperedge, buffer)
202
+ delete! (relative_hyperedges, i)
203
+ end
204
+ end
205
+ end
206
+
207
+ for (_, relative_edge) in relative_hyperedges
208
+ add_edge! (inference_graph, relative_edge)
209
+ end
96
210
end
211
+
212
+ add_edge! (inference_graph, hyperedge)
97
213
end
98
- cc = connected_components (var_graph)
99
- for c′ in cc
100
- c = BitSet (c′)
101
- idxs = intersect (c, inferred)
102
- isempty (idxs) && continue
103
- if ! allequal (iscontinuous (var_domain[i]) for i in idxs)
104
- display (fullvars[c′])
105
- throw (ClockInferenceException (" Clocks are not consistent in connected component $(fullvars[c′]) " ))
214
+
215
+ clock_partitions = connectionsets (inference_graph)
216
+ for partition in clock_partitions
217
+ clockidxs = findall (vert -> Moshi. Data. isa_variant (vert, ClockVertex. Clock), partition)
218
+ if isempty (clockidxs)
219
+ vidxs = Int[vert.:1 for vert in partition if Moshi. Data. isa_variant (vert, ClockVertex. Variable)]
220
+ throw (ArgumentError ("""
221
+ Found clock partion with no associated clock. Involved variables: $(fullvars[vidxs]) .
222
+ """ ))
106
223
end
107
- vd = var_domain[first (idxs)]
108
- for v in c′
109
- var_domain[v] = vd
224
+ if length (clockidxs) > 1
225
+ vidxs = Int[vert.:1 for vert in partition if Moshi. Data. isa_variant (vert, ClockVertex. Variable)]
226
+ clks = [vert.:1 for vert in view (partition, clockidxs)]
227
+ throw (ArgumentError ("""
228
+ Found clock partition with multiple associated clocks. Involved variables: \
229
+ $(fullvars[vidxs]) . Involved clocks: $(clks) .
230
+ """ ))
110
231
end
111
- end
112
232
113
- for v in 𝑑vertices (graph)
114
- vd = var_domain[v]
115
- eqs = 𝑑neighbors (graph, v)
116
- isempty (eqs) && continue
117
- for eq in eqs
118
- eq_domain[eq] = vd
233
+ clock = partition[only (clockidxs)]. :1
234
+ for vert in partition
235
+ Moshi. Match. @match vert begin
236
+ ClockVertex. Variable (i) => (var_domain[i] = clock)
237
+ ClockVertex. Equation (i) => (eq_domain[i] = clock)
238
+ ClockVertex. Clock (_) => nothing
239
+ end
119
240
end
120
241
end
121
242
0 commit comments