Skip to content

feat: add SemilinearODEFunction and SemilinearODEProblem #3739

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
PreallocationTools = "d236fae5-4411-538c-8e31-a6e3d9e00b46"
PrecompileTools = "aea7be01-6a6a-4083-8856-8a6e6704d82a"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Expand Down Expand Up @@ -142,6 +143,7 @@ OrdinaryDiffEq = "6.82.0"
OrdinaryDiffEqCore = "1.15.0"
OrdinaryDiffEqDefault = "1.2"
OrdinaryDiffEqNonlinearSolve = "1.5.0"
PreallocationTools = "0.4.27"
PrecompileTools = "1"
Pyomo = "0.1.0"
REPL = "1"
Expand Down
3 changes: 3 additions & 0 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,8 @@ const DQ = DynamicQuantities
import DifferentiationInterface as DI
using ADTypes: AutoForwardDiff
import SciMLPublic: @public
import PreallocationTools
import PreallocationTools: DiffCache

export @derivatives

Expand Down Expand Up @@ -287,6 +289,7 @@ export IntervalNonlinearProblem
export OptimizationProblem, constraints
export SteadyStateProblem
export JumpProblem
export SemilinearODEFunction, SemilinearODEProblem
export alias_elimination, flatten
export connect, domain_connect, @connector, Connection, AnalysisPoint, Flow, Stream,
instream
Expand Down
136 changes: 135 additions & 1 deletion src/problems/odeproblem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,9 +98,143 @@ end
maybe_codegen_scimlproblem(expression, SteadyStateProblem{iip}, args; kwargs...)
end

struct SemilinearODEFunction{iip, spec} end

@fallback_iip_specialize function SemilinearODEFunction{iip, specialize}(
sys::System; u0 = nothing, p = nothing, t = nothing,
semiquadratic_form = nothing, semiquadratic_jacobian = nothing,
eval_expression = false, eval_module = @__MODULE__,
expression = Val{false}, sparse = false, check_compatibility = true,
jac = false, checkbounds = false, cse = true, initialization_data = nothing,
analytic = nothing, kwargs...) where {iip, specialize}
check_complete(sys, SemilinearODEFunction)
check_compatibility && check_compatible_system(SemilinearODEFunction, sys)

if semiquadratic_form === nothing
sys = add_semilinear_parameters(sys)
semiquadratic_form = calculate_split_form(sys; sparse)
end

A, B, x2, C = semiquadratic_form
M = calculate_massmatrix(sys)
_M = concrete_massmatrix(M; sparse, u0)

f1, f2 = generate_semiquadratic_functions(
sys, A, B, x2, C; expression, wrap_gfw = Val{true},
eval_expression, eval_module, kwargs...)

if jac
semiquadratic_jacobian = @something(semiquadratic_jacobian,
calculate_semiquadratic_jacobian(sys, B, x2, C; sparse, massmatrix = _M))
f1jac, x2jac, Cjac = semiquadratic_jacobian
_jac = generate_semiquadratic_jacobian(
sys, B, x2, C, f1jac, x2jac, Cjac; sparse, expression,
wrap_gfw = Val{true}, eval_expression, eval_module, kwargs...)
_W_sparsity = f1jac
W_prototype = calculate_W_prototype(_W_sparsity; u0, sparse)
else
_jac = nothing
W_prototype = nothing
end

observedfun = ObservedFunctionCache(
sys; expression, steady_state = false, eval_expression, eval_module, checkbounds, cse)

f1_args = (; f1)
f1_kwargs = (; jac = _jac)
f1 = maybe_codegen_scimlfn(
expression, ODEFunction{iip, specialize}, f1_args; f1_kwargs...)
args = (; f1, f2)

kwargs = (;
sys = sys,
jac = _jac,
mass_matrix = _M,
jac_prototype = W_prototype,
observed = observedfun,
analytic,
initialization_data)
kwargs = (; sys, observed = observedfun, mass_matrix = _M)

return maybe_codegen_scimlfn(
expression, SplitFunction{iip, specialize}, args; kwargs...)
end

struct SemilinearODEProblem{iip, spec} end

@fallback_iip_specialize function SemilinearODEProblem{iip, spec}(
sys::System, op, tspan; check_compatibility = true,
u0_eltype = nothing, expression = Val{false}, callback = nothing,
jac = false, sparse = false, kwargs...) where {iip, spec}
check_complete(sys, SemilinearODEProblem)
check_compatibility && check_compatible_system(SemilinearODEProblem, sys)

A, B, x2, C = semiquadratic_form = calculate_split_form(sys)

semiquadratic_jacobian = nothing
if jac
f1jac, x2jac, Cjac = semiquadratic_jacobian = calculate_semiquadratic_jacobian(
sys, B, x2, C; sparse)
end

sys = add_semilinear_parameters(sys)
linear_matrix_param = unwrap(getproperty(sys, LINEAR_MATRIX_PARAM_NAME))
bilinear_matrix_param = unwrap(getproperty(sys, BILINEAR_MATRIX_PARAM_NAME))
diffcache = unwrap(getproperty(sys, DIFFCACHE_PARAM_NAME))

floatT = calculate_float_type(op, typeof(op))
_u0_eltype = something(u0_eltype, floatT)

guess = copy(guesses(sys))
guess[linear_matrix_param] = fill(NaN, size(A))
guess[bilinear_matrix_param] = fill(NaN, size(B))
@set! sys.guesses = guess
defs = copy(defaults(sys))
defs[linear_matrix_param] = A
defs[bilinear_matrix_param] = B
cachelen = jac ? length(x2jac) : length(x2)
defs[diffcache] = DiffCache(zeros(DiffEqBase.value(_u0_eltype), cachelen))
@set! sys.defaults = defs

f, u0, p = process_SciMLProblem(SemilinearODEFunction{iip, spec}, sys, op;
t = tspan !== nothing ? tspan[1] : tspan, expression, check_compatibility,
semiquadratic_form, semiquadratic_jacobian, jac, sparse, u0_eltype, kwargs...)

kwargs = process_kwargs(
sys; expression, callback, kwargs...)

ptype = getmetadata(sys, ProblemTypeCtx, StandardODEProblem())
args = (; f, u0, tspan, p)
maybe_codegen_scimlproblem(expression, SplitODEProblem{iip}, args; kwargs...)
end

function add_semilinear_parameters(sys::System)
m = length(equations(sys))
n = length(unknowns(sys))
linear_matrix_param = get_linear_matrix_param((m, n))
bilinear_matrix_param = get_bilinear_matrix_param((m, (n^2 + n) ÷ 2))
@assert !is_parameter(sys, linear_matrix_param)
sys = with_additional_constant_parameter(sys, linear_matrix_param)
@assert !is_parameter(sys, bilinear_matrix_param)
sys = with_additional_constant_parameter(sys, bilinear_matrix_param)
@assert !is_parameter(sys, get_diffcache_param(Float64))
diffcache = get_diffcache_param(Float64)
sys = with_additional_nonnumeric_parameter(sys, diffcache)
var_to_name = copy(get_var_to_name(sys))
var_to_name[LINEAR_MATRIX_PARAM_NAME] = linear_matrix_param
var_to_name[BILINEAR_MATRIX_PARAM_NAME] = bilinear_matrix_param
var_to_name[DIFFCACHE_PARAM_NAME] = diffcache
@set! sys.var_to_name = var_to_name
if get_parent(sys) !== nothing
@set! sys.parent = add_semilinear_parameters(get_parent(sys))
end
return sys
end

function check_compatible_system(
T::Union{Type{ODEFunction}, Type{ODEProblem}, Type{DAEFunction},
Type{DAEProblem}, Type{SteadyStateProblem}},
Type{DAEProblem}, Type{SteadyStateProblem}, Type{SemilinearODEFunction},
Type{SemilinearODEProblem}},
sys::System)
check_time_dependent(sys, T)
check_not_dde(sys)
Expand Down
Loading
Loading