Skip to content

fix: fix bad input ordering, make downstream tests pass #3804

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion src/deprecations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@ end
const ODESystem = IntermediateDeprecationSystem

function IntermediateDeprecationSystem(args...; kwargs...)
Base.depwarn("`ODESystem(args...; kwargs...)` is deprecated. Use `System(args...; kwargs...) instead`.", :ODESystem)
Base.depwarn(
"`ODESystem(args...; kwargs...)` is deprecated. Use `System(args...; kwargs...) instead`.",
:ODESystem)

return System(args...; kwargs...)
end
Expand Down
2 changes: 1 addition & 1 deletion src/inputoutput.jl
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ function generate_control_function(sys::AbstractSystem, inputs = unbound_inputs(
eval_module = @__MODULE__,
kwargs...)
isempty(inputs) && @warn("No unbound inputs were found in system.")
if !iscomplete(sys)
if !isscheduled(sys)
sys = mtkcompile(sys; inputs, disturbance_inputs)
end
if disturbance_inputs !== nothing
Expand Down
59 changes: 50 additions & 9 deletions src/linearization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,50 @@ function linearize_symbolic(sys::AbstractSystem, inputs,
(; A, B, C, D, f_x, f_z, g_x, g_z, f_u, g_u, h_x, h_z, h_u), sys
end

struct IONotFoundError <: Exception
variant::String
sysname::Symbol
not_found::Any
end

function Base.showerror(io::IO, err::IONotFoundError)
println(io,
"The following $(err.variant) provided to `mtkcompile` were not found in the system:")
maybe_namespace_issue = false
for var in err.not_found
println(io, " ", var)
if hasname(var) && startswith(string(getname(var)), string(err.sysname))
maybe_namespace_issue = true
end
end
if maybe_namespace_issue
println(io, """
Some of the missing variables are namespaced with the name of the system \
`$(err.sysname)` passed to `mtkcompile`. This may be indicative of a namespacing \
issue. `mtkcompile` requires that the $(err.variant) provided are not namespaced \
with the name of the root system. This issue can occur when using `getproperty` \
to access the variables passed as $(err.variant). For example:

```julia
@named sys = MyModel()
inputs = [sys.input_var]
mtkcompile(sys; inputs)
```

Here, `mtkcompile` expects the input to be named `input_var`, but since `sys`
performs namespacing, it will be named `sys$(NAMESPACE_SEPARATOR)input_var`. To \
fix this issue, namespacing can be temporarily disabled:

```julia
@named sys = MyModel()
sys_nns = toggle_namespacing(sys, false)
inputs = [sys_nns.input_var]
mtkcompile(sys; inputs)
```
""")
end
end

"""
Modify the variable metadata of system variables to indicate which ones are inputs, outputs, and disturbances. Needed for `inputs`, `outputs`, `disturbances`, `unbound_inputs`, `unbound_outputs` to return the proper subsets.
"""
Expand Down Expand Up @@ -605,19 +649,16 @@ function markio!(state, orig_inputs, inputs, outputs, disturbances; check = true
if check
ikeys = keys(filter(!last, inputset))
if !isempty(ikeys)
error(
"Some specified inputs were not found in system. The following variables were not found ",
ikeys)
throw(IONotFoundError("inputs", nameof(state.sys), ikeys))
end
dkeys = keys(filter(!last, disturbanceset))
if !isempty(dkeys)
error(
"Specified disturbance inputs were not found in system. The following variables were not found ",
ikeys)
throw(IONotFoundError("disturbance inputs", nameof(state.sys), ikeys))
end
okeys = keys(filter(!last, outputset))
if !isempty(okeys)
throw(IONotFoundError("outputs", nameof(state.sys), okeys))
end
(all(values(outputset)) || error(
"Some specified outputs were not found in system. The following Dict indicates the found variables ",
outputset))
end
state, orig_inputs
end
Expand Down
8 changes: 8 additions & 0 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -652,6 +652,8 @@ function complete(
# Ideally we'd do `get_ps` but if `flatten = false`
# we don't get all of them. So we call `parameters`.
all_ps = parameters(sys; initial_parameters = true)
# inputs have to be maintained in a specific order
input_vars = inputs(sys)
if !isempty(all_ps)
# reorder parameters by portions
ps_split = reorder_parameters(sys, all_ps)
Expand All @@ -670,6 +672,12 @@ function complete(
end
ordered_ps = vcat(
ordered_ps, reduce(vcat, ps_split; init = eltype(ordered_ps)[]))
if isscheduled(sys)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't quite understand where the fix to the ordering happens. Will this assert @assert issorted(input_idxs) guarantee that the order expected by the inputs when they are passed as input arguments to generated control functions is the same as the order specified by the user?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fix to ordering is in the IndexCache constructor - that was the problem all along. This check here just enforces the invariant.

# ensure inputs are sorted
input_idxs = findfirst.(isequal.(input_vars), (ordered_ps,))
@assert all(!isnothing, input_idxs)
@assert issorted(input_idxs)
end
@set! sys.ps = ordered_ps
end
elseif has_index_cache(sys)
Expand Down
78 changes: 38 additions & 40 deletions src/systems/index_cache.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ function IndexCache(sys::AbstractSystem)
end
end

tunable_buffers = Dict{Any, Set{BasicSymbolic}}()
initial_param_buffers = Dict{Any, Set{BasicSymbolic}}()
tunable_pars = BasicSymbolic[]
initial_pars = BasicSymbolic[]
constant_buffers = Dict{Any, Set{BasicSymbolic}}()
nonnumeric_buffers = Dict{Any, Set{SymbolicParam}}()

Expand All @@ -107,6 +107,10 @@ function IndexCache(sys::AbstractSystem)
buf = get!(buffers, ctype, S())
push!(buf, sym)
end
function insert_by_type!(buffers::Vector{BasicSymbolic}, sym, ctype)
sym = unwrap(sym)
push!(buffers, sym)
end

disc_param_callbacks = Dict{SymbolicParam, Set{Int}}()
events = vcat(continuous_events(sys), discrete_events(sys))
Expand Down Expand Up @@ -210,9 +214,9 @@ function IndexCache(sys::AbstractSystem)
ctype <: AbstractArray{Real} ||
ctype <: AbstractArray{<:AbstractFloat})
if iscall(p) && operation(p) isa Initial
initial_param_buffers
initial_pars
else
tunable_buffers
tunable_pars
end
else
constant_buffers
Expand Down Expand Up @@ -255,47 +259,41 @@ function IndexCache(sys::AbstractSystem)

tunable_idxs = TunableIndexMap()
tunable_buffer_size = 0
bufferlist = is_initializesystem(sys) ? (tunable_buffers, initial_param_buffers) :
(tunable_buffers,)
for buffers in bufferlist
for (i, (_, buf)) in enumerate(buffers)
for (j, p) in enumerate(buf)
idx = if size(p) == ()
tunable_buffer_size + 1
else
reshape(
(tunable_buffer_size + 1):(tunable_buffer_size + length(p)), size(p))
end
tunable_buffer_size += length(p)
tunable_idxs[p] = idx
tunable_idxs[default_toterm(p)] = idx
if hasname(p) && (!iscall(p) || operation(p) !== getindex)
symbol_to_variable[getname(p)] = p
symbol_to_variable[getname(default_toterm(p))] = p
end
end
if is_initializesystem(sys)
append!(tunable_pars, initial_pars)
empty!(initial_pars)
end
for p in tunable_pars
idx = if size(p) == ()
tunable_buffer_size + 1
else
reshape(
(tunable_buffer_size + 1):(tunable_buffer_size + length(p)), size(p))
end
tunable_buffer_size += length(p)
tunable_idxs[p] = idx
tunable_idxs[default_toterm(p)] = idx
if hasname(p) && (!iscall(p) || operation(p) !== getindex)
symbol_to_variable[getname(p)] = p
symbol_to_variable[getname(default_toterm(p))] = p
end
end

initials_idxs = TunableIndexMap()
initials_buffer_size = 0
if !is_initializesystem(sys)
for (i, (_, buf)) in enumerate(initial_param_buffers)
for (j, p) in enumerate(buf)
idx = if size(p) == ()
initials_buffer_size + 1
else
reshape(
(initials_buffer_size + 1):(initials_buffer_size + length(p)), size(p))
end
initials_buffer_size += length(p)
initials_idxs[p] = idx
initials_idxs[default_toterm(p)] = idx
if hasname(p) && (!iscall(p) || operation(p) !== getindex)
symbol_to_variable[getname(p)] = p
symbol_to_variable[getname(default_toterm(p))] = p
end
end
for p in initial_pars
idx = if size(p) == ()
initials_buffer_size + 1
else
reshape(
(initials_buffer_size + 1):(initials_buffer_size + length(p)), size(p))
end
initials_buffer_size += length(p)
initials_idxs[p] = idx
initials_idxs[default_toterm(p)] = idx
if hasname(p) && (!iscall(p) || operation(p) !== getindex)
symbol_to_variable[getname(p)] = p
symbol_to_variable[getname(default_toterm(p))] = p
end
end

Expand Down
4 changes: 2 additions & 2 deletions test/linearize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,12 +151,12 @@ lsys = ModelingToolkit.reorder_unknowns(lsys, desired_order, reverse(desired_ord
@test lsys.D == [4400 -4400]

## Test that there is a warning when input is misspecified
@test_throws "Some specified inputs were not found" linearize(pid,
@test_throws ["inputs provided to `mtkcompile`", "not found"] linearize(pid,
[
pid.reference.u,
pid.measurement.u
], [ctr_output.u])
@test_throws "Some specified outputs were not found" linearize(pid,
@test_throws ["outputs provided to `mtkcompile`", "not found"] linearize(pid,
[
reference.u,
measurement.u
Expand Down
Loading