Skip to content

Control parameter specification for AbstractODESystem, DiscreteSystem #1059

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Jul 2, 2021
Merged
1 change: 1 addition & 0 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ export structural_simplify
export DiscreteSystem, DiscreteProblem

export calculate_jacobian, generate_jacobian, generate_function
export calculate_control_jacobian, generate_control_jacobian
export calculate_tgrad, generate_tgrad
export calculate_gradient, generate_gradient
export calculate_factorized_W, generate_factorized_W
Expand Down
27 changes: 25 additions & 2 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,18 @@ call will be cached in the system object.
"""
function calculate_jacobian end

"""
```julia
calculate_control_jacobian(sys::AbstractSystem)
```

Calculate the jacobian matrix of a system with respect to the system's controls.

Returns a matrix of [`Num`](@ref) instances. The result from the first
call will be cached in the system object.
"""
function calculate_control_jacobian end

"""
```julia
calculate_factorized_W(sys::AbstractSystem)
Expand Down Expand Up @@ -140,10 +152,12 @@ for prop in [
:iv
:states
:ps
:ctrls
:defaults
:observed
:tgrad
:jac
:ctrl_jac
:Wfact
:Wfact_t
:systems
Expand Down Expand Up @@ -301,6 +315,7 @@ end

namespace_variables(sys::AbstractSystem) = states(sys, states(sys))
namespace_parameters(sys::AbstractSystem) = parameters(sys, parameters(sys))
namespace_controls(sys::AbstractSystem) = controls(sys, controls(sys))

function namespace_defaults(sys)
defs = defaults(sys)
Expand Down Expand Up @@ -344,13 +359,21 @@ function states(sys::AbstractSystem)
systems = get_systems(sys)
unique(isempty(systems) ?
sts :
[sts;reduce(vcat,namespace_variables.(systems))])
[sts; reduce(vcat,namespace_variables.(systems))])
end

function parameters(sys::AbstractSystem)
ps = get_ps(sys)
systems = get_systems(sys)
isempty(systems) ? ps : [ps;reduce(vcat,namespace_parameters.(systems))]
isempty(systems) ? ps : [ps; reduce(vcat,namespace_parameters.(systems))]
end

function controls(sys::AbstractSystem)
ctrls = get_ctrls(sys)
systems = get_systems(sys)
isempty(systems) ? ctrls : [ctrls; reduce(vcat,namespace_controls.(systems))]
end

function observed(sys::AbstractSystem)
iv = independent_variable(sys)
obs = get_observed(sys)
Expand Down
2 changes: 1 addition & 1 deletion src/systems/control/controlsystem.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
abstract type AbstractControlSystem <: AbstractSystem end

function namespace_controls(sys::AbstractSystem)
function namespace_controls(sys::AbstractControlSystem)
[rename(x,renamespace(nameof(sys),nameof(x))) for x in controls(sys)]
end

Expand Down
28 changes: 28 additions & 0 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,28 @@ function calculate_jacobian(sys::AbstractODESystem;
return jac
end

function calculate_control_jacobian(sys::AbstractODESystem;
sparse=false, simplify=false)
cache = get_ctrl_jac(sys)[]
if cache isa Tuple && cache[2] == (sparse, simplify)
return cache[1]
end

rhs = [eq.rhs for eq ∈ equations(sys)]

iv = get_iv(sys)
ctrls = controls(sys)

if sparse
jac = sparsejacobian(rhs, ctrls, simplify=simplify)
else
jac = jacobian(rhs, ctrls, simplify=simplify)
end

get_ctrl_jac(sys)[] = jac, (sparse, simplify) # cache Jacobian
return jac
end

function generate_tgrad(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys);
simplify=false, kwargs...)
tgrad = calculate_tgrad(sys,simplify=simplify)
Expand All @@ -50,6 +72,12 @@ function generate_jacobian(sys::AbstractODESystem, dvs = states(sys), ps = param
return build_function(jac, dvs, ps, get_iv(sys); kwargs...)
end

function generate_control_jacobian(sys::AbstractODESystem, dvs = states(sys), ps = parameters(sys);
simplify=false, sparse = false, kwargs...)
jac = calculate_control_jacobian(sys;simplify=simplify,sparse=sparse)
return build_function(jac, dvs, ps, get_iv(sys); kwargs...)
end

@noinline function throw_invalid_derivative(dervar, eq)
msg = "The derivative variable must be isolated to the left-hand " *
"side of the equation like `$dervar ~ ...`.\n Got $eq."
Expand Down
22 changes: 18 additions & 4 deletions src/systems/diffeqs/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ struct ODESystem <: AbstractODESystem
states::Vector
"""Parameter variables. Must not contain the independent variable."""
ps::Vector
"""Control parameters (some subset of `ps`)."""
ctrls::Vector
"""Observed states."""
observed::Vector{Equation}
"""
Time-derivative matrix. Note: this field will not be defined until
Expand All @@ -43,6 +46,11 @@ struct ODESystem <: AbstractODESystem
"""
jac::RefValue{Any}
"""
Control Jacobian matrix. Note: this field will not be defined until
[`calculate_control_jacobian`](@ref) is called on the system.
"""
ctrl_jac::RefValue{Any}
"""
`Wfact` matrix. Note: this field will not be defined until
[`generate_factorized_W`](@ref) is called on the system.
"""
Expand Down Expand Up @@ -74,16 +82,17 @@ struct ODESystem <: AbstractODESystem
"""
connection_type::Any

function ODESystem(deqs, iv, dvs, ps, observed, tgrad, jac, Wfact, Wfact_t, name, systems, defaults, structure, connection_type)
function ODESystem(deqs, iv, dvs, ps, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, structure, connection_type)
check_variables(dvs,iv)
check_parameters(ps,iv)
check_equations(deqs,iv)
new(deqs, iv, dvs, ps, observed, tgrad, jac, Wfact, Wfact_t, name, systems, defaults, structure, connection_type)
new(deqs, iv, dvs, ps, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, structure, connection_type)
end
end

function ODESystem(
deqs::AbstractVector{<:Equation}, iv, dvs, ps;
controls = Num[],
observed = Num[],
systems = ODESystem[],
name=gensym(:ODESystem),
Expand All @@ -92,9 +101,13 @@ function ODESystem(
defaults=_merge(Dict(default_u0), Dict(default_p)),
connection_type=nothing,
)

@assert all(control -> any(isequal.(control, ps)), controls) "All controls must also be parameters."

iv′ = value(scalarize(iv))
dvs′ = value.(scalarize(dvs))
ps′ = value.(scalarize(ps))
ctrl′ = value.(scalarize(controls))

if !(isempty(default_u0) && isempty(default_p))
Base.depwarn("`default_u0` and `default_p` are deprecated. Use `defaults` instead.", :ODESystem, force=true)
Expand All @@ -104,13 +117,14 @@ function ODESystem(

tgrad = RefValue(Vector{Num}(undef, 0))
jac = RefValue{Any}(Matrix{Num}(undef, 0, 0))
ctrl_jac = RefValue{Any}(Matrix{Num}(undef, 0, 0))
Wfact = RefValue(Matrix{Num}(undef, 0, 0))
Wfact_t = RefValue(Matrix{Num}(undef, 0, 0))
sysnames = nameof.(systems)
if length(unique(sysnames)) != length(sysnames)
throw(ArgumentError("System names must be unique."))
end
ODESystem(deqs, iv′, dvs′, ps′, observed, tgrad, jac, Wfact, Wfact_t, name, systems, defaults, nothing, connection_type)
ODESystem(deqs, iv′, dvs′, ps′, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, nothing, connection_type)
end

vars(x::Sym) = Set([x])
Expand Down Expand Up @@ -349,4 +363,4 @@ function convert_system(::Type{<:ODESystem}, sys, t; name=nameof(sys))
neweqs = map(sub, equations(sys))
defs = Dict(sub(k) => sub(v) for (k, v) in defaults(sys))
return ODESystem(neweqs, t, newsts, parameters(sys); defaults=defs, name=name)
end
end
26 changes: 17 additions & 9 deletions src/systems/diffeqs/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,10 @@ struct SDESystem <: AbstractODESystem
states::Vector
"""Parameter variables. Must not contain the independent variable."""
ps::Vector
observed::Vector
"""Control parameters (some subset of `ps`)."""
ctrls::Vector
"""Observed states."""
observed::Vector{Equation}
"""
Time-derivative matrix. Note: this field will not be defined until
[`calculate_tgrad`](@ref) is called on the system.
Expand All @@ -49,6 +52,11 @@ struct SDESystem <: AbstractODESystem
"""
jac::RefValue
"""
Control Jacobian matrix. Note: this field will not be defined until
[`calculate_control_jacobian`](@ref) is called on the system.
"""
ctrl_jac::RefValue{Any}
"""
`Wfact` matrix. Note: this field will not be defined until
[`generate_factorized_W`](@ref) is called on the system.
"""
Expand Down Expand Up @@ -76,16 +84,17 @@ struct SDESystem <: AbstractODESystem
"""
connection_type::Any

function SDESystem(deqs, neqs, iv, dvs, ps, observed, tgrad, jac, Wfact, Wfact_t, name, systems, defaults, connection_type)
function SDESystem(deqs, neqs, iv, dvs, ps, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connection_type)
check_variables(dvs,iv)
check_parameters(ps,iv)
check_equations(deqs,iv)
new(deqs, neqs, iv, dvs, ps, observed, tgrad, jac, Wfact, Wfact_t, name, systems, defaults, connection_type)
new(deqs, neqs, iv, dvs, ps, ctrls, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connection_type)
end
end

function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
observed = [],
controls = Num[],
observed = Num[],
systems = SDESystem[],
default_u0=Dict(),
default_p=Dict(),
Expand All @@ -96,6 +105,8 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;
iv′ = value(iv)
dvs′ = value.(dvs)
ps′ = value.(ps)
ctrl′ = value.(controls)

sysnames = nameof.(systems)
if length(unique(sysnames)) != length(sysnames)
throw(ArgumentError("System names must be unique."))
Expand All @@ -108,9 +119,10 @@ function SDESystem(deqs::AbstractVector{<:Equation}, neqs, iv, dvs, ps;

tgrad = RefValue(Vector{Num}(undef, 0))
jac = RefValue{Any}(Matrix{Num}(undef, 0, 0))
ctrl_jac = RefValue{Any}(Matrix{Num}(undef, 0, 0))
Wfact = RefValue(Matrix{Num}(undef, 0, 0))
Wfact_t = RefValue(Matrix{Num}(undef, 0, 0))
SDESystem(deqs, neqs, iv′, dvs′, ps′, observed, tgrad, jac, Wfact, Wfact_t, name, systems, defaults, connection_type)
SDESystem(deqs, neqs, iv′, dvs′, ps′, ctrl′, observed, tgrad, jac, ctrl_jac, Wfact, Wfact_t, name, systems, defaults, connection_type)
end

function generate_diffusion_function(sys::SDESystem, dvs = states(sys), ps = parameters(sys); kwargs...)
Expand Down Expand Up @@ -157,10 +169,6 @@ function stochastic_integral_transform(sys::SDESystem, correction_factor)
SDESystem(deqs,get_noiseeqs(sys),get_iv(sys),states(sys),parameters(sys))
end





"""
```julia
function DiffEqBase.SDEFunction{iip}(sys::SDESystem, dvs = sys.states, ps = sys.ps;
Expand Down
15 changes: 10 additions & 5 deletions src/systems/discrete_system/discrete_system.jl
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ struct DiscreteSystem <: AbstractSystem
states::Vector
"""Parameter variables. Must not contain the independent variable."""
ps::Vector
"""Control parameters (some subset of `ps`)."""
ctrls::Vector
"""Observed states."""
observed::Vector{Equation}
"""
Name: the name of the system
Expand All @@ -49,10 +52,10 @@ struct DiscreteSystem <: AbstractSystem
in `DiscreteSystem`.
"""
default_p::Dict
function DiscreteSystem(discreteEqs, iv, dvs, ps, observed, name, systems, default_u0, default_p)
check_variables(dvs, iv)
check_parameters(ps, iv)
new(discreteEqs, iv, dvs, ps, observed, name, systems, default_u0, default_p)
function DiscreteSystem(discreteEqs, iv, dvs, ps, ctrls, observed, name, systems, default_u0, default_p)
check_variables(dvs,iv)
check_parameters(ps,iv)
new(discreteEqs, iv, dvs, ps, ctrls, observed, name, systems, default_u0, default_p)
end
end

Expand All @@ -63,6 +66,7 @@ Constructs a DiscreteSystem.
"""
function DiscreteSystem(
discreteEqs::AbstractVector{<:Equation}, iv, dvs, ps;
controls = Num[],
observed = Num[],
systems = DiscreteSystem[],
name=gensym(:DiscreteSystem),
Expand All @@ -72,6 +76,7 @@ function DiscreteSystem(
iv′ = value(iv)
dvs′ = value.(dvs)
ps′ = value.(ps)
ctrl′ = value.(controls)

default_u0 isa Dict || (default_u0 = Dict(default_u0))
default_p isa Dict || (default_p = Dict(default_p))
Expand All @@ -82,7 +87,7 @@ function DiscreteSystem(
if length(unique(sysnames)) != length(sysnames)
throw(ArgumentError("System names must be unique."))
end
DiscreteSystem(discreteEqs, iv′, dvs′, ps′, observed, name, systems, default_u0, default_p)
DiscreteSystem(discreteEqs, iv′, dvs′, ps′, ctrl′, observed, name, systems, default_u0, default_p)
end

"""
Expand Down
2 changes: 1 addition & 1 deletion test/direct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ end
@variables a,b
X = [a,b]

spoly(x) = simplify(x, polynorm=true)
spoly(x) = simplify(x, expand=true)
rr = rosenbrock(X)

reference_hes = ModelingToolkit.hessian(rr, X)
Expand Down
4 changes: 2 additions & 2 deletions test/discretesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ eqs = [next_S ~ S-infection,
next_R ~ R+recovery]

# System
sys = DiscreteSystem(eqs,t,[S,I,R],[c,nsteps,δt,β,γ])
sys = DiscreteSystem(eqs,t,[S,I,R],[c,nsteps,δt,β,γ]; controls = [β, γ])

# Problem
u0 = [S => 990.0, I => 10.0, R => 0.0]
Expand Down Expand Up @@ -54,4 +54,4 @@ p = [0.05,10.0,0.25,0.1];
prob_map = DiscreteProblem(sir_map!,u0,tspan,p);
sol_map2 = solve(prob_map,FunctionMap());

@test Array(sol_map) ≈ Array(sol_map2)
@test Array(sol_map) ≈ Array(sol_map2)
20 changes: 19 additions & 1 deletion test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -361,4 +361,22 @@ D =Differential(t)
eqs = [D(x1) ~ -x1]
sys = ODESystem(eqs,t,[x1,x2],[])
@test_throws ArgumentError ODEProblem(sys, [1.0,1.0], (0.0,1.0))
prob = ODEProblem(sys, [1.0,1.0], (0.0,1.0), check_length=false)
prob = ODEProblem(sys, [1.0,1.0], (0.0,1.0), check_length=false)

# check inputs
let
@parameters t f k d
@variables x(t) ẋ(t)
δ = Differential(t)

eqs = [δ(x) ~ ẋ, δ(ẋ) ~ f - k*x - d*ẋ]
sys = ODESystem(eqs, t, [x, ẋ], [f, d, k]; controls = [f])

calculate_control_jacobian(sys)

@test isequal(
calculate_control_jacobian(sys),
reshape(Num[0,1], 2, 1)
)

end
2 changes: 1 addition & 1 deletion test/reduction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ ref_eq = [
@variables x(t) y(t) z(t) a(t) u(t) F(t)
D = Differential(t)

test_equal(a, b) = @test isequal(simplify(a, polynorm=true), simplify(b, polynorm=true))
test_equal(a, b) = @test isequal(simplify(a, expand=true), simplify(b, expand=true))

eqs = [
D(x) ~ σ*(y-x)
Expand Down