Skip to content

Commit f747b8b

Browse files
fix: maintain order of inputs during complete
1 parent 196a967 commit f747b8b

File tree

2 files changed

+46
-40
lines changed

2 files changed

+46
-40
lines changed

src/systems/abstractsystem.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,8 @@ function complete(
652652
# Ideally we'd do `get_ps` but if `flatten = false`
653653
# we don't get all of them. So we call `parameters`.
654654
all_ps = parameters(sys; initial_parameters = true)
655+
# inputs have to be maintained in a specific order
656+
input_vars = inputs(sys)
655657
if !isempty(all_ps)
656658
# reorder parameters by portions
657659
ps_split = reorder_parameters(sys, all_ps)
@@ -670,6 +672,12 @@ function complete(
670672
end
671673
ordered_ps = vcat(
672674
ordered_ps, reduce(vcat, ps_split; init = eltype(ordered_ps)[]))
675+
if isscheduled(sys)
676+
# ensure inputs are sorted
677+
input_idxs = findfirst.(isequal.(input_vars), (ordered_ps,))
678+
@assert all(!isnothing, input_idxs)
679+
@assert issorted(input_idxs)
680+
end
673681
@set! sys.ps = ordered_ps
674682
end
675683
elseif has_index_cache(sys)

src/systems/index_cache.jl

Lines changed: 38 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -97,8 +97,8 @@ function IndexCache(sys::AbstractSystem)
9797
end
9898
end
9999

100-
tunable_buffers = Dict{Any, Set{BasicSymbolic}}()
101-
initial_param_buffers = Dict{Any, Set{BasicSymbolic}}()
100+
tunable_pars = BasicSymbolic[]
101+
initial_pars = BasicSymbolic[]
102102
constant_buffers = Dict{Any, Set{BasicSymbolic}}()
103103
nonnumeric_buffers = Dict{Any, Set{SymbolicParam}}()
104104

@@ -107,6 +107,10 @@ function IndexCache(sys::AbstractSystem)
107107
buf = get!(buffers, ctype, S())
108108
push!(buf, sym)
109109
end
110+
function insert_by_type!(buffers::Vector{BasicSymbolic}, sym, ctype)
111+
sym = unwrap(sym)
112+
push!(buffers, sym)
113+
end
110114

111115
disc_param_callbacks = Dict{SymbolicParam, Set{Int}}()
112116
events = vcat(continuous_events(sys), discrete_events(sys))
@@ -210,9 +214,9 @@ function IndexCache(sys::AbstractSystem)
210214
ctype <: AbstractArray{Real} ||
211215
ctype <: AbstractArray{<:AbstractFloat})
212216
if iscall(p) && operation(p) isa Initial
213-
initial_param_buffers
217+
initial_pars
214218
else
215-
tunable_buffers
219+
tunable_pars
216220
end
217221
else
218222
constant_buffers
@@ -255,47 +259,41 @@ function IndexCache(sys::AbstractSystem)
255259

256260
tunable_idxs = TunableIndexMap()
257261
tunable_buffer_size = 0
258-
bufferlist = is_initializesystem(sys) ? (tunable_buffers, initial_param_buffers) :
259-
(tunable_buffers,)
260-
for buffers in bufferlist
261-
for (i, (_, buf)) in enumerate(buffers)
262-
for (j, p) in enumerate(buf)
263-
idx = if size(p) == ()
264-
tunable_buffer_size + 1
265-
else
266-
reshape(
267-
(tunable_buffer_size + 1):(tunable_buffer_size + length(p)), size(p))
268-
end
269-
tunable_buffer_size += length(p)
270-
tunable_idxs[p] = idx
271-
tunable_idxs[default_toterm(p)] = idx
272-
if hasname(p) && (!iscall(p) || operation(p) !== getindex)
273-
symbol_to_variable[getname(p)] = p
274-
symbol_to_variable[getname(default_toterm(p))] = p
275-
end
276-
end
262+
if is_initializesystem(sys)
263+
append!(tunable_pars, initial_pars)
264+
empty!(initial_pars)
265+
end
266+
for p in tunable_pars
267+
idx = if size(p) == ()
268+
tunable_buffer_size + 1
269+
else
270+
reshape(
271+
(tunable_buffer_size + 1):(tunable_buffer_size + length(p)), size(p))
272+
end
273+
tunable_buffer_size += length(p)
274+
tunable_idxs[p] = idx
275+
tunable_idxs[default_toterm(p)] = idx
276+
if hasname(p) && (!iscall(p) || operation(p) !== getindex)
277+
symbol_to_variable[getname(p)] = p
278+
symbol_to_variable[getname(default_toterm(p))] = p
277279
end
278280
end
279281

280282
initials_idxs = TunableIndexMap()
281283
initials_buffer_size = 0
282-
if !is_initializesystem(sys)
283-
for (i, (_, buf)) in enumerate(initial_param_buffers)
284-
for (j, p) in enumerate(buf)
285-
idx = if size(p) == ()
286-
initials_buffer_size + 1
287-
else
288-
reshape(
289-
(initials_buffer_size + 1):(initials_buffer_size + length(p)), size(p))
290-
end
291-
initials_buffer_size += length(p)
292-
initials_idxs[p] = idx
293-
initials_idxs[default_toterm(p)] = idx
294-
if hasname(p) && (!iscall(p) || operation(p) !== getindex)
295-
symbol_to_variable[getname(p)] = p
296-
symbol_to_variable[getname(default_toterm(p))] = p
297-
end
298-
end
284+
for p in initial_pars
285+
idx = if size(p) == ()
286+
initials_buffer_size + 1
287+
else
288+
reshape(
289+
(initials_buffer_size + 1):(initials_buffer_size + length(p)), size(p))
290+
end
291+
initials_buffer_size += length(p)
292+
initials_idxs[p] = idx
293+
initials_idxs[default_toterm(p)] = idx
294+
if hasname(p) && (!iscall(p) || operation(p) !== getindex)
295+
symbol_to_variable[getname(p)] = p
296+
symbol_to_variable[getname(default_toterm(p))] = p
299297
end
300298
end
301299

0 commit comments

Comments
 (0)