Skip to content

refactor simplification and generate_control_function #1680

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jul 14, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 9 additions & 33 deletions src/inputoutput.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,38 +170,28 @@ The return values also include the remaining states and parameters, in the order
# Example
```
using ModelingToolkit: generate_control_function, varmap_to_vars, defaults
f, dvs, ps = generate_control_function(sys, expression=Val{false}, simplify=true)
f, dvs, ps = generate_control_function(sys, expression=Val{false}, simplify=false)
p = varmap_to_vars(defaults(sys), ps)
x = varmap_to_vars(defaults(sys), dvs)
t = 0
f[1](x, inputs, p, t)
```
"""
function generate_control_function(sys::AbstractODESystem;
function generate_control_function(sys::AbstractODESystem, inputs = unbound_inputs(sys);
implicit_dae = false,
has_difference = false,
simplify = true,
simplify = false,
kwargs...)
ctrls = unbound_inputs(sys)
if isempty(ctrls)
if isempty(inputs)
error("No unbound inputs were found in system.")
end

# One can either connect unbound inputs to new parameters and allow structural_simplify, but then the unbound inputs appear as states :( .
# One can also just remove them from the states and parameters for the purposes of code generation, but then structural_simplify fails :(
# To have the best of both worlds, all unbound inputs must be converted to `@parameters` in which case structural_simplify handles them correctly :)
sys = toparam(sys, ctrls)

if simplify
sys = structural_simplify(sys)
end
sys, diff_idxs, alge_idxs = io_preprocessing(sys, inputs, []; simplify,
check_bound = false, kwargs...)

dvs = states(sys)
ps = parameters(sys)

dvs = setdiff(dvs, ctrls)
ps = setdiff(ps, ctrls)
inputs = map(x -> time_varying_as_func(value(x), sys), ctrls)
ps = setdiff(ps, inputs)
inputs = map(x -> time_varying_as_func(value(x), sys), inputs)

eqs = [eq for eq in equations(sys) if !isdifferenceeq(eq)]
check_operator_variables(eqs, Differential)
Expand All @@ -223,24 +213,10 @@ function generate_control_function(sys::AbstractODESystem;
end
pre, sol_states = get_substitutions_and_solved_states(sys)
f = build_function(rhss, args...; postprocess_fbody = pre, states = sol_states,
kwargs...)
expression = Val{false}, kwargs...)
f, dvs, ps
end

"""
toparam(sys, ctrls::AbstractVector)

Transform all instances of `@varibales` in `ctrls` appearing as states and in equations of `sys` with similarly named `@parameters`. This allows [`structural_simplify`](@ref)(sys) in the presence unbound inputs.
"""
function toparam(sys, ctrls::AbstractVector)
eqs = equations(sys)
subs = Dict(ctrls .=> toparam.(ctrls))
eqs = map(eqs) do eq
substitute(eq.lhs, subs) ~ substitute(eq.rhs, subs)
end
ODESystem(eqs, name = nameof(sys))
end

function inputs_to_parameters!(state::TransformationState, check_bound = true)
@unpack structure, fullvars, sys = state
@unpack var_to_diff, graph, solvable_graph = structure
Expand Down
70 changes: 37 additions & 33 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -948,10 +948,17 @@ function will be applied during the tearing process. It also takes kwargs
`allow_symbolic=false` and `allow_parameter=true` which limits the coefficient
types during tearing.
"""
function structural_simplify(sys::AbstractSystem; simplify = false, kwargs...)
function structural_simplify(sys::AbstractSystem, args...; kwargs...)
sys = expand_connections(sys)
state = TearingState(sys)
state, = inputs_to_parameters!(state)
sys, input_idxs = _structural_simplify(sys, state, args...; kwargs...)
sys
end

function _structural_simplify(sys::AbstractSystem, state; simplify = false,
check_bound = true,
kwargs...)
state, input_idxs = inputs_to_parameters!(state, check_bound)
sys = alias_elimination!(state)
state = TearingState(sys)
check_consistency(state)
Expand All @@ -964,7 +971,31 @@ function structural_simplify(sys::AbstractSystem; simplify = false, kwargs...)
fullstates = [map(eq -> eq.lhs, observed(sys)); states(sys)]
@set! sys.observed = topsort_equations(observed(sys), fullstates)
invalidate_cache!(sys)
return sys
return sys, input_idxs
end

function io_preprocessing(sys::AbstractSystem, inputs,
outputs; simplify = false, kwargs...)
sys = expand_connections(sys)
state = TearingState(sys)
markio!(state, inputs, outputs)
sys, input_idxs = _structural_simplify(sys, state; simplify, check_bound = false,
kwargs...)

eqs = equations(sys)
check_operator_variables(eqs, Differential)
# Sort equations and states such that diff.eqs. match differential states and the rest are algebraic
diffstates = collect_operator_variables(sys, Differential)
eqs = sort(eqs, by = e -> !isoperator(e.lhs, Differential),
alg = Base.Sort.DEFAULT_STABLE)
@set! sys.eqs = eqs
diffstates = [arguments(e.lhs)[1] for e in eqs[1:length(diffstates)]]
sts = [diffstates; setdiff(states(sys), diffstates)]
@set! sys.states = sts
diff_idxs = 1:length(diffstates)
alge_idxs = (length(diffstates) + 1):length(sts)

sys, diff_idxs, alge_idxs, input_idxs
end

"""
Expand Down Expand Up @@ -994,36 +1025,9 @@ See also [`linearize`](@ref) which provides a higher-level interface.
function linearization_function(sys::AbstractSystem, inputs,
outputs; simplify = false,
kwargs...)
sys = expand_connections(sys)
state = TearingState(sys)
markio!(state, inputs, outputs)
state, input_idxs = inputs_to_parameters!(state, false)
sys = alias_elimination!(state)
state = TearingState(sys)
check_consistency(state)
if sys isa ODESystem
sys = dae_order_lowering(dummy_derivative(sys, state))
end
state = TearingState(sys)
find_solvables!(state; kwargs...)
sys = tearing_reassemble(state, tearing(state), simplify = simplify)
fullstates = [map(eq -> eq.lhs, observed(sys)); states(sys)]
@set! sys.observed = topsort_equations(observed(sys), fullstates)
invalidate_cache!(sys)

eqs = equations(sys)
check_operator_variables(eqs, Differential)
# Sort equations and states such that diff.eqs. match differential states and the rest are algebraic
diffstates = collect_operator_variables(sys, Differential)
eqs = sort(eqs, by = e -> !isoperator(e.lhs, Differential),
alg = Base.Sort.DEFAULT_STABLE)
@set! sys.eqs = eqs
diffstates = [arguments(e.lhs)[1] for e in eqs[1:length(diffstates)]]
sts = [diffstates; setdiff(states(sys), diffstates)]
@set! sys.states = sts

diff_idxs = 1:length(diffstates)
alge_idxs = (length(diffstates) + 1):length(sts)
sys, diff_idxs, alge_idxs, input_idxs = io_preprocessing(sys, inputs, outputs; simplify,
kwargs...)
sts = states(sys)
fun = ODEFunction(sys)
lin_fun = let fun = fun,
h = ModelingToolkit.build_explicit_observed_function(sys, outputs)
Expand Down
6 changes: 2 additions & 4 deletions test/input_output_handling.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,8 +108,7 @@ eqs = [
]

@named sys = ODESystem(eqs)
f, dvs, ps = ModelingToolkit.generate_control_function(sys, expression = Val{false},
simplify = true)
f, dvs, ps = ModelingToolkit.generate_control_function(sys, simplify = true)

@test isequal(dvs[], x)
@test isempty(ps)
Expand Down Expand Up @@ -170,8 +169,7 @@ eqs = [connect_sd(sd, mass1, mass2)
@named _model = ODESystem(eqs, t)
@named model = compose(_model, mass1, mass2, sd);

f, dvs, ps = ModelingToolkit.generate_control_function(model, expression = Val{false},
simplify = true)
f, dvs, ps = ModelingToolkit.generate_control_function(model, simplify = true)
@test length(dvs) == 4
@test length(ps) == length(parameters(model))
p = ModelingToolkit.varmap_to_vars(ModelingToolkit.defaults(model), ps)
Expand Down