Skip to content

feat: support caching of different types of subexpressions in SCCNonlinearProblem #3324

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jan 16, 2025
3 changes: 2 additions & 1 deletion src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,8 @@ import SCCNonlinearSolve
using Reexport
using RecursiveArrayTools
import Graphs: SimpleDiGraph, add_edge!, incidence_matrix
import BlockArrays: BlockedArray, Block, blocksize, blocksizes
import BlockArrays: BlockArray, BlockedArray, Block, blocksize, blocksizes, blockpush!,
undef_blocks, blocks
import CommonSolve
import EnumX

Expand Down
2 changes: 1 addition & 1 deletion src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1919,7 +1919,7 @@ function Base.show(
nrows > 0 && hint && print(io, " see hierarchy($name)")
for i in 1:nrows
sub = subs[i]
name = String(nameof(sub))
local name = String(nameof(sub))
print(io, "\n ", name)
desc = description(sub)
if !isempty(desc)
Expand Down
118 changes: 91 additions & 27 deletions src/systems/nonlinear/nonlinearsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -573,29 +573,37 @@ function DiffEqBase.NonlinearLeastSquaresProblem{iip}(sys::NonlinearSystem, u0ma
NonlinearLeastSquaresProblem{iip}(f, u0, p; filter_kwargs(kwargs)...)
end

const TypeT = Union{DataType, UnionAll}

struct CacheWriter{F}
fn::F
end

function (cw::CacheWriter)(p, sols)
cw.fn(p.caches[1], sols, p...)
cw.fn(p.caches, sols, p...)
end

function CacheWriter(sys::AbstractSystem, exprs, solsyms, obseqs::Vector{Equation};
function CacheWriter(sys::AbstractSystem, buffer_types::Vector{TypeT},
exprs::Dict{TypeT, Vector{Any}}, solsyms, obseqs::Vector{Equation};
eval_expression = false, eval_module = @__MODULE__)
ps = parameters(sys)
rps = reorder_parameters(sys, ps)
obs_assigns = [eq.lhs ← eq.rhs for eq in obseqs]
cmap, cs = get_cmap(sys)
cmap_assigns = [eq.lhs ← eq.rhs for eq in cmap]

outsyms = [Symbol(:out, i) for i in eachindex(buffer_types)]
body = map(eachindex(buffer_types), buffer_types) do i, T
Symbol(:tmp, i) ← SetArray(true, :(out[$i]), get(exprs, T, []))
end
fn = Func(
[:out, DestructuredArgs(DestructuredArgs.(solsyms)),
DestructuredArgs.(rps)...],
[],
SetArray(true, :out, exprs)
Let(body, :())
) |> wrap_assignments(false, obs_assigns)[2] |>
wrap_parameter_dependencies(sys, false)[2] |>
wrap_array_vars(sys, exprs; dvs = nothing, inputs = [])[2] |>
wrap_array_vars(sys, []; dvs = nothing, inputs = [])[2] |>
wrap_assignments(false, cmap_assigns)[2] |> toexpr
return CacheWriter(eval_or_rgf(fn; eval_expression, eval_module))
end
Expand Down Expand Up @@ -677,8 +685,17 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,

explicitfuns = []
nlfuns = []
prevobsidxs = Int[]
cachesize = 0
prevobsidxs = BlockArray(undef_blocks, Vector{Int}, Int[])
# Cache buffer types and corresponding sizes. Stored as a pair of arrays instead of a
# dict to maintain a consistent order of buffers across SCCs
cachetypes = TypeT[]
cachesizes = Int[]
# explicitfun! related information for each SCC
# We need to compute buffer sizes before doing any codegen
scc_cachevars = Dict{TypeT, Vector{Any}}[]
scc_cacheexprs = Dict{TypeT, Vector{Any}}[]
scc_eqs = Vector{Equation}[]
scc_obs = Vector{Equation}[]
for (i, (escc, vscc)) in enumerate(zip(eq_sccs, var_sccs))
# subset unknowns and equations
_dvs = dvs[vscc]
Expand All @@ -690,11 +707,10 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
_obs = obs[obsidxs]

# get all subexpressions in the RHS which we can precompute in the cache
# precomputed subexpressions should not contain `banned_vars`
banned_vars = Set{Any}(vcat(_dvs, getproperty.(_obs, (:lhs,))))
for var in banned_vars
iscall(var) || continue
operation(var) === getindex || continue
push!(banned_vars, arguments(var)[1])
filter!(banned_vars) do var
symbolic_type(var) != ArraySymbolic() || all(x -> var[i] in banned_vars, eachindex(var))
end
state = Dict()
for i in eachindex(_obs)
Expand All @@ -706,37 +722,85 @@ function SciMLBase.SCCNonlinearProblem{iip}(sys::NonlinearSystem, u0map,
_eqs[i].rhs, banned_vars, state)
end

# cached variables and their corresponding expressions
cachevars = Any[obs[i].lhs for i in prevobsidxs]
cacheexprs = Any[obs[i].lhs for i in prevobsidxs]
# map from symtype to cached variables and their expressions
cachevars = Dict{Union{DataType, UnionAll}, Vector{Any}}()
cacheexprs = Dict{Union{DataType, UnionAll}, Vector{Any}}()
# observed of previous SCCs are in the cache
# NOTE: When we get proper CSE, we can substitute these
# and then use `subexpressions_not_involving_vars!`
for i in prevobsidxs
T = symtype(obs[i].lhs)
buf = get!(() -> Any[], cachevars, T)
push!(buf, obs[i].lhs)

buf = get!(() -> Any[], cacheexprs, T)
push!(buf, obs[i].lhs)
end

for (k, v) in state
push!(cachevars, unwrap(v))
push!(cacheexprs, unwrap(k))
k = unwrap(k)
v = unwrap(v)
T = symtype(k)
buf = get!(() -> Any[], cachevars, T)
push!(buf, v)
buf = get!(() -> Any[], cacheexprs, T)
push!(buf, k)
end
cachesize = max(cachesize, length(cachevars))

# update the sizes of cache buffers
for (T, buf) in cachevars
idx = findfirst(isequal(T), cachetypes)
if idx === nothing
push!(cachetypes, T)
push!(cachesizes, 0)
idx = lastindex(cachetypes)
end
cachesizes[idx] = max(cachesizes[idx], length(buf))
end

push!(scc_cachevars, cachevars)
push!(scc_cacheexprs, cacheexprs)
push!(scc_eqs, _eqs)
push!(scc_obs, _obs)
blockpush!(prevobsidxs, obsidxs)
end

for (i, (escc, vscc)) in enumerate(zip(eq_sccs, var_sccs))
_dvs = dvs[vscc]
_eqs = scc_eqs[i]
_prevobsidxs = reduce(vcat, blocks(prevobsidxs)[1:(i - 1)]; init = Int[])
_obs = scc_obs[i]
cachevars = scc_cachevars[i]
cacheexprs = scc_cacheexprs[i]

if isempty(cachevars)
push!(explicitfuns, Returns(nothing))
else
solsyms = getindex.((dvs,), view(var_sccs, 1:(i - 1)))
push!(explicitfuns,
CacheWriter(sys, cacheexprs, solsyms, obs[prevobsidxs];
CacheWriter(sys, cachetypes, cacheexprs, solsyms, obs[_prevobsidxs];
eval_expression, eval_module))
end

cachebufsyms = Tuple(map(cachetypes) do T
get(cachevars, T, [])
end)
f = SCCNonlinearFunction{iip}(
sys, _eqs, _dvs, _obs, (cachevars,); eval_expression, eval_module, kwargs...)
sys, _eqs, _dvs, _obs, cachebufsyms; eval_expression, eval_module, kwargs...)
push!(nlfuns, f)
append!(cachevars, _dvs)
append!(cacheexprs, _dvs)
for i in obsidxs
push!(cachevars, obs[i].lhs)
push!(cacheexprs, obs[i].rhs)
end
append!(prevobsidxs, obsidxs)
end

if cachesize != 0
p = rebuild_with_caches(p, BufferTemplate(eltype(u0), cachesize))
if !isempty(cachetypes)
templates = map(cachetypes, cachesizes) do T, n
# Real refers to `eltype(u0)`
if T == Real
T = eltype(u0)
elseif T <: Array && eltype(T) == Real
T = Array{eltype(u0), ndims(T)}
end
BufferTemplate(T, n)
end
p = rebuild_with_caches(p, templates...)
end

subprobs = []
Expand Down
27 changes: 18 additions & 9 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1108,23 +1108,33 @@ returns the modified `expr`.
"""
function subexpressions_not_involving_vars!(expr, vars, state::Dict{Any, Any})
expr = unwrap(expr)
symbolic_type(expr) == NotSymbolic() && return expr
if symbolic_type(expr) == NotSymbolic()
if is_array_of_symbolics(expr)
return map(expr) do el
subexpressions_not_involving_vars!(el, vars, state)
end
end
return expr
end
any(isequal(expr), vars) && return expr
iscall(expr) || return expr
is_variable_floatingpoint(expr) || return expr
symtype(expr) <: Union{Real, AbstractArray{<:Real}} || return expr
Symbolics.shape(expr) == Symbolics.Unknown() && return expr
haskey(state, expr) && return state[expr]
vs = ModelingToolkit.vars(expr)
intersect!(vs, vars)
if isempty(vs)
op = operation(expr)
args = arguments(expr)
# if this is a `getindex` and the getindex-ed value is a `Sym`
# or it is not a called parameter
# OR
# none of `vars` are involved in `expr`
if op === getindex && (issym(args[1]) || !iscalledparameter(args[1])) ||
(vs = ModelingToolkit.vars(expr); intersect!(vs, vars); isempty(vs))
sym = gensym(:subexpr)
stype = symtype(expr)
var = similar_variable(expr, sym)
state[expr] = var
return var
end
op = operation(expr)
args = arguments(expr)

if (op == (+) || op == (*)) && symbolic_type(expr) !== ArraySymbolic()
indep_args = []
dep_args = []
Expand All @@ -1143,7 +1153,6 @@ function subexpressions_not_involving_vars!(expr, vars, state::Dict{Any, Any})
return op(indep_term, dep_term)
end
newargs = map(args) do arg
symbolic_type(arg) != NotSymbolic() || is_array_of_symbolics(arg) || return arg
subexpressions_not_involving_vars!(arg, vars, state)
end
return maketerm(typeof(expr), op, newargs, metadata(expr))
Expand Down
92 changes: 92 additions & 0 deletions test/scc_nonlinear_problem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -93,12 +93,12 @@
β = 1e-6
R0 = 1000
R = 9000
Ue(t) = 0.1 * sin(200 * π * t)

Check warning on line 96 in test/scc_nonlinear_problem.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"Ue" should be "Use" or "Due".

function transamp(out, du, u, p, t)
g(x) = 1e-6 * (exp(x / 0.026) - 1)
y1, y2, y3, y4, y5, y6, y7, y8 = u
out[1] = -Ue(t) / R0 + y1 / R0 + C[1] * du[1] - C[1] * du[2]

Check warning on line 101 in test/scc_nonlinear_problem.jl

View workflow job for this annotation

GitHub Actions / Spell Check with Typos

"Ue" should be "Use" or "Due".
out[2] = -Ub / R + y2 * 2 / R - (α - 1) * g(y2 - y3) - C[1] * du[1] + C[1] * du[2]
out[3] = -g(y2 - y3) + y3 / R + C[2] * du[3]
out[4] = -Ub / R + y4 / R + α * g(y2 - y3) + C[3] * du[4] - C[3] * du[5]
Expand Down Expand Up @@ -161,3 +161,95 @@
@test SciMLBase.successful_retcode(sccsol)
@test val[] == 1
end

import ModelingToolkitStandardLibrary.Blocks as B
import ModelingToolkitStandardLibrary.Mechanical.Translational as T
import ModelingToolkitStandardLibrary.Hydraulic.IsothermalCompressible as IC

@testset "Caching of subexpressions of different types" begin
liquid_pressure(rho, rho_0, bulk) = (rho / rho_0 - 1) * bulk
gas_pressure(rho, rho_0, p_gas, rho_gas) = rho * ((0 - p_gas) / (rho_0 - rho_gas))
full_pressure(rho, rho_0, bulk, p_gas, rho_gas) = ifelse(
rho >= rho_0, liquid_pressure(rho, rho_0, bulk),
gas_pressure(rho, rho_0, p_gas, rho_gas))

@component function Volume(;
#parameters
area,
direction = +1,
x_int,
name)
pars = @parameters begin
area = area
x_int = x_int
rho_0 = 1000
bulk = 1e9
p_gas = -1000
rho_gas = 1
end

vars = @variables begin
x(t) = x_int
dx(t), [guess = 0]
p(t), [guess = 0]
f(t), [guess = 0]
rho(t), [guess = 0]
m(t), [guess = 0]
dm(t), [guess = 0]
end

systems = @named begin
port = IC.HydraulicPort()
flange = T.MechanicalPort()
end

eqs = [
# connectors
port.p ~ p
port.dm ~ dm
flange.v * direction ~ dx
flange.f * direction ~ -f

# differentials
D(x) ~ dx
D(m) ~ dm

# physics
p ~ full_pressure(rho, rho_0, bulk, p_gas, rho_gas)
f ~ p * area
m ~ rho * x * area]

return ODESystem(eqs, t, vars, pars; name, systems)
end

systems = @named begin
fluid = IC.HydraulicFluid(; bulk_modulus = 1e9)

src1 = IC.Pressure(;)
src2 = IC.Pressure(;)

vol1 = Volume(; area = 0.01, direction = +1, x_int = 0.1)
vol2 = Volume(; area = 0.01, direction = +1, x_int = 0.1)

mass = T.Mass(; m = 10)

sin1 = B.Sine(; frequency = 0.5, amplitude = +0.5e5, offset = 10e5)
sin2 = B.Sine(; frequency = 0.5, amplitude = -0.5e5, offset = 10e5)
end

eqs = [connect(fluid, src1.port)
connect(fluid, src2.port)
connect(src1.port, vol1.port)
connect(src2.port, vol2.port)
connect(vol1.flange, mass.flange, vol2.flange)
connect(src1.p, sin1.output)
connect(src2.p, sin2.output)]

initialization_eqs = [mass.s ~ 0.0
mass.v ~ 0.0]

@mtkbuild sys = ODESystem(eqs, t, [], []; systems, initialization_eqs)
prob = ODEProblem(sys, [], (0, 5))
sol = solve(prob)
@test SciMLBase.successful_retcode(sol)
end
Loading