Skip to content

Commit

Permalink
refactor: use initialization_data instead of initializeprob, etc.
Browse files Browse the repository at this point in the history
  • Loading branch information
AayushSabharwal committed Oct 30, 2024
1 parent 10bd663 commit 6159778
Showing 1 changed file with 85 additions and 77 deletions.
162 changes: 85 additions & 77 deletions src/scimlfunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -405,8 +405,8 @@ automatically symbolically generating the Jacobian and more from the
numerically-defined functions.
"""
struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, WP, TPJ,
O, TCV, SYS,
IProb, UIProb, IProbMap, IProbPmap, NLP} <: AbstractODEFunction{iip}
O, TCV,
SYS, ID, NLP} <: AbstractODEFunction{iip}
f::F
mass_matrix::TMM
analytic::Ta
Expand All @@ -423,10 +423,7 @@ struct ODEFunction{iip, specialize, F, TMM, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TW
observed::O
colorvec::TCV
sys::SYS
initializeprob::IProb
update_initializeprob!::UIProb
initializeprobmap::IProbMap
initializeprobpmap::IProbPmap
initialization_data::ID
nlprob::NLP
end

Expand Down Expand Up @@ -530,8 +527,8 @@ information on generating the SplitFunction from this symbolic engine.
"""
struct SplitFunction{
iip, specialize, F1, F2, TMM, C, Ta, Tt, TJ, JVP, VJP, JP, WP, SP, TW, TWt,
TPJ, O, TCV, SYS,
IProb, UIProb, IProbMap, IProbPmap, NLP} <: AbstractODEFunction{iip}
TPJ, O,
TCV, SYS, ID, NLP} <: AbstractODEFunction{iip}
f1::F1
f2::F2
mass_matrix::TMM
Expand All @@ -550,10 +547,7 @@ struct SplitFunction{
observed::O
colorvec::TCV
sys::SYS
initializeprob::IProb
update_initializeprob!::UIProb
initializeprobmap::IProbMap
initializeprobpmap::IProbPmap
initialization_data::ID
nlprob::NLP
end

Expand Down Expand Up @@ -1529,7 +1523,7 @@ automatically symbolically generating the Jacobian and more from the
numerically-defined functions.
"""
struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TPJ, O, TCV,
SYS, IProb, UIProb, IProbMap, IProbPmap} <:
SYS, ID} <:
AbstractDAEFunction{iip}
f::F
analytic::Ta
Expand All @@ -1545,10 +1539,7 @@ struct DAEFunction{iip, specialize, F, Ta, Tt, TJ, JVP, VJP, JP, SP, TW, TWt, TP
observed::O
colorvec::TCV
sys::SYS
initializeprob::IProb
update_initializeprob!::UIProb
initializeprobmap::IProbMap
initializeprobpmap::IProbPmap
initialization_data::ID
end

"""
Expand Down Expand Up @@ -2440,6 +2431,8 @@ function ODEFunction{iip, specialize}(f;
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing,
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing,
nlprob = __has_nlprob(f) ? f.nlprob : nothing,
initialization_data = __has_initialization_data(f) ? f.initialization_data :
nothing
) where {iip,
specialize
}
Expand Down Expand Up @@ -2486,8 +2479,11 @@ function ODEFunction{iip, specialize}(f;
_f = prepare_function(f)

sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym)
initdata = reconstruct_initialization_data(
initialization_data, initializeprob, update_initializeprob!,
initializeprobmap, initializeprobpmap)

@assert typeof(initializeprob) <:
@assert typeof(initdata.initializeprob) <:
Union{Nothing, NonlinearProblem, NonlinearLeastSquaresProblem}

if specialize === NoSpecialize
Expand All @@ -2497,11 +2493,10 @@ function ODEFunction{iip, specialize}(f;
typeof(sparsity), Any, Any, typeof(W_prototype), Any,
Any,
typeof(_colorvec),
typeof(sys), Any, Any, Any, Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
typeof(sys), Any, Any}(_f, mass_matrix, analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap,
initializeprobpmap, nlprob)
observed, _colorvec, sys, initdata, nlprob)
elseif specialize === false
ODEFunction{iip, FunctionWrapperSpecialize,
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
Expand All @@ -2510,16 +2505,11 @@ function ODEFunction{iip, specialize}(f;
typeof(paramjac),
typeof(observed),
typeof(_colorvec),
typeof(sys), typeof(initializeprob),
typeof(update_initializeprob!),
typeof(initializeprobmap),
typeof(initializeprobpmap),
typeof(nlprob)}(_f, mass_matrix,
typeof(sys), typeof(initdata), typeof(nlprob)}(_f, mass_matrix,
analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initializeprob, update_initializeprob!,
initializeprobmap, initializeprobpmap, nlprob)
observed, _colorvec, sys, initdata, nlprob)
else
ODEFunction{iip, specialize,
typeof(_f), typeof(mass_matrix), typeof(analytic), typeof(tgrad),
Expand All @@ -2528,14 +2518,10 @@ function ODEFunction{iip, specialize}(f;
typeof(paramjac),
typeof(observed),
typeof(_colorvec),
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
typeof(initializeprobmap),
typeof(initializeprobpmap),
typeof(nlprob)}(_f, mass_matrix, analytic, tgrad, jac,
jvp, vjp, jac_prototype, sparsity, Wfact,
typeof(sys), typeof(initdata), typeof(nlprob)}(_f, mass_matrix, analytic, tgrad,
jac, jvp, vjp, jac_prototype, sparsity, Wfact,
Wfact_t, W_prototype, paramjac,
observed, _colorvec, sys, initializeprob, update_initializeprob!, initializeprobmap,
initializeprobpmap, nlprob)
observed, _colorvec, sys, initdata, nlprob)
end
end

Expand All @@ -2552,28 +2538,23 @@ function unwrapped_f(f::ODEFunction, newf = unwrapped_f(f.f))
Any, Any, Any, Any, typeof(f.jac_prototype),
typeof(f.sparsity), Any, Any, Any,
Any, typeof(f.colorvec),
typeof(f.sys), Any, Any, Any, Any, Any}(
typeof(f.sys), Any, Any}(
newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
f.Wfact_t, f.W_prototype, f.paramjac,
f.observed, f.colorvec, f.sys, f.initializeprob,
f.update_initializeprob!, f.initializeprobmap,
f.initializeprobpmap, f.nlprob)
f.observed, f.colorvec, f.sys, f.initialization_data, f.nlprob)
else
ODEFunction{isinplace(f), specialization(f), typeof(newf), typeof(f.mass_matrix),
typeof(f.analytic), typeof(f.tgrad),
typeof(f.jac), typeof(f.jvp), typeof(f.vjp), typeof(f.jac_prototype),
typeof(f.sparsity), typeof(f.Wfact), typeof(f.Wfact_t), typeof(f.W_prototype),
typeof(f.paramjac),
typeof(f.observed), typeof(f.colorvec),
typeof(f.sys), typeof(f.initializeprob), typeof(f.update_initializeprob!),
typeof(f.initializeprobmap),
typeof(f.initializeprobpmap),
typeof(f.nlprob)}(newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
typeof(f.sys), typeof(f.initialization_data), typeof(f.nlprob)}(
newf, f.mass_matrix, f.analytic, f.tgrad, f.jac,
f.jvp, f.vjp, f.jac_prototype, f.sparsity, f.Wfact,
f.Wfact_t, f.W_prototype, f.paramjac,
f.observed, f.colorvec, f.sys, f.initializeprob, f.update_initializeprob!,
f.initializeprobmap, f.initializeprobpmap, f.nlprob)
f.observed, f.colorvec, f.sys, f.initialization_data, f.nlprob)
end
end

Expand Down Expand Up @@ -2704,8 +2685,8 @@ end

@add_kwonly function SplitFunction(f1, f2, mass_matrix, cache, analytic, tgrad, jac, jvp,
vjp, jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac,
observed, colorvec, sys, initializeprob, update_initializeprob!,
initializeprobmap, initializeprobpmap, nlprob)
observed, colorvec, sys, initializeprob = nothing, update_initializeprob! = nothing,
initializeprobmap = nothing, initializeprobpmap = nothing, initialization_data = nothing, nlprob)
f1 = ODEFunction(f1)
f2 = ODEFunction(f2)

Expand All @@ -2714,17 +2695,20 @@ end
throw(NonconformingFunctionsError(["f2"]))
end

initdata = reconstruct_initialization_data(
initialization_data, initializeprob, update_initializeprob!,
initializeprobmap, initializeprobpmap)

SplitFunction{isinplace(f2), FullSpecialize, typeof(f1), typeof(f2),
typeof(mass_matrix),
typeof(cache), typeof(analytic), typeof(tgrad), typeof(jac), typeof(jvp),
typeof(vjp), typeof(jac_prototype), typeof(W_prototype), typeof(sparsity),
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed), typeof(colorvec),
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!), typeof(initializeprobmap),
typeof(initializeprobpmap), typeof(nlprob)}(
typeof(sys), typeof(initdata), typeof(nlprob)}(
f1, f2, mass_matrix,
cache, analytic, tgrad, jac, jvp, vjp,
jac_prototype, W_prototype, sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap, nlprob)
initdata, nlprob)
end
function SplitFunction{iip, specialize}(f1, f2;
mass_matrix = __has_mass_matrix(f1) ?
Expand Down Expand Up @@ -2761,37 +2745,39 @@ function SplitFunction{iip, specialize}(f1, f2;
f1.update_initializeprob! : nothing,
initializeprobmap = __has_initializeprobmap(f1) ? f1.initializeprobmap : nothing,
initializeprobpmap = __has_initializeprobpmap(f1) ? f1.initializeprobpmap : nothing,
nlprob = __has_nlprob(f1) ? f1.nlprob : nothing
nlprob = __has_nlprob(f1) ? f1.nlprob : nothing,
initialization_data = __has_initialization_data(f1) ? f1.initialization_data :
nothing
) where {iip,
specialize
}
sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym)
@assert typeof(initializeprob) <:
initdata = reconstruct_initialization_data(
initialization_data, initializeprob, update_initializeprob!,
initializeprobmap, initializeprobpmap)
@assert typeof(initdata.initializeprob) <:
Union{Nothing, NonlinearProblem, NonlinearLeastSquaresProblem}

if specialize === NoSpecialize
SplitFunction{iip, specialize, Any, Any, Any, Any, Any, Any, Any, Any, Any,
Any, Any, Any, Any, Any, Any, Any,
Any, Any, Any, Any, Any, Any, Any}(f1, f2, mass_matrix, _func_cache,
Any, Any, Any, Any}(f1, f2, mass_matrix, _func_cache,
analytic,
tgrad, jac, jvp, vjp, jac_prototype, W_prototype,
sparsity, Wfact, Wfact_t, paramjac,
observed, colorvec, sys, initializeprob.update_initializeprob!, initializeprobmap,
initializeprobpmap, initializeprobpmap, nlprob)
observed, colorvec, sys, initdata, nlprob)
else
SplitFunction{iip, specialize, typeof(f1), typeof(f2), typeof(mass_matrix),
typeof(_func_cache), typeof(analytic),
typeof(tgrad), typeof(jac), typeof(jvp), typeof(vjp),
typeof(jac_prototype), typeof(W_prototype), typeof(sparsity),
typeof(Wfact), typeof(Wfact_t), typeof(paramjac), typeof(observed),
typeof(colorvec),
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
typeof(initializeprobmap),
typeof(initializeprobpmap), typeof(nlprob)}(f1, f2,
typeof(sys), typeof(initdata), typeof(nlprob)}(f1, f2,
mass_matrix, _func_cache, analytic, tgrad, jac,
jvp, vjp, jac_prototype, W_prototype,
sparsity, Wfact, Wfact_t, paramjac, observed, colorvec, sys,
initializeprob, update_initializeprob!, initializeprobmap, initializeprobpmap, nlprob)
initdata, nlprob)
end
end

Expand Down Expand Up @@ -3420,7 +3406,9 @@ function DAEFunction{iip, specialize}(f;
update_initializeprob! = __has_update_initializeprob!(f) ?
f.update_initializeprob! : nothing,
initializeprobmap = __has_initializeprobmap(f) ? f.initializeprobmap : nothing,
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing) where {
initializeprobpmap = __has_initializeprobpmap(f) ? f.initializeprobpmap : nothing,
initialization_data = __has_initialization_data(f) ? f.initialization_data :
nothing) where {
iip,
specialize
}
Expand Down Expand Up @@ -3452,33 +3440,32 @@ function DAEFunction{iip, specialize}(f;

_f = prepare_function(f)
sys = sys_or_symbolcache(sys, syms, paramsyms, indepsym)
initdata = reconstruct_initialization_data(
initialization_data, initializeprob, update_initializeprob!,
initializeprobmap, initializeprobpmap)

@assert typeof(initializeprob) <:
@assert typeof(initdata.initializeprob) <:
Union{Nothing, NonlinearProblem, NonlinearLeastSquaresProblem}

if specialize === NoSpecialize
DAEFunction{iip, specialize, Any, Any, Any,
Any, Any, Any, Any, Any,
Any, Any, Any,
Any, typeof(_colorvec), Any, Any, Any, Any, Any}(_f, analytic, tgrad, jac, jvp,
Any, typeof(_colorvec), Any, Any}(_f, analytic, tgrad, jac, jvp,
vjp, jac_prototype, sparsity,
Wfact, Wfact_t, paramjac, observed,
_colorvec, sys, initializeprob, update_initializeprob!,
initializeprobmap, initializeprobpmap)
_colorvec, sys, initdata)
else
DAEFunction{iip, specialize, typeof(_f), typeof(analytic), typeof(tgrad),
typeof(jac), typeof(jvp), typeof(vjp), typeof(jac_prototype),
typeof(sparsity), typeof(Wfact), typeof(Wfact_t),
typeof(paramjac),
typeof(observed), typeof(_colorvec),
typeof(sys), typeof(initializeprob), typeof(update_initializeprob!),
typeof(initializeprobmap),
typeof(initializeprobpmap)}(
typeof(sys), typeof(initdata)}(
_f, analytic, tgrad, jac, jvp, vjp,
jac_prototype, sparsity, Wfact, Wfact_t,
paramjac, observed,
_colorvec, sys, initializeprob, update_initializeprob!,
initializeprobmap, initializeprobpmap)
_colorvec, sys, initdata)
end
end

Expand Down Expand Up @@ -4397,6 +4384,14 @@ function sys_or_symbolcache(sys, syms, paramsyms, indepsym = nothing)
return sys
end

function reconstruct_initialization_data(
initdata, initprob, update_initprob!, initprobmap, initprobpmap)
if initdata === nothing && initprob !== nothing
initdata = InitializationData(initprob, update_initprob!, initprobmap, initprobpmap)
end
return initprob
end

########## Existence Functions

# Check that field/property exists (may be nothing)
Expand All @@ -4420,11 +4415,20 @@ __has_colorvec(f) = isdefined(f, :colorvec)
__has_sys(f) = isdefined(f, :sys)
__has_analytic_full(f) = isdefined(f, :analytic_full)
__has_resid_prototype(f) = isdefined(f, :resid_prototype)
__has_initializeprob(f) = isdefined(f, :initializeprob)
__has_update_initializeprob!(f) = isdefined(f, :update_initializeprob!)
__has_initializeprobmap(f) = isdefined(f, :initializeprobmap)
__has_initializeprobpmap(f) = isdefined(f, :initializeprobpmap)
__has_nlprob(f) = isdefined(f, :nlprob)
function __has_initializeprob(f)
has_initialization_data(f) && isdefined(f.initialization_data, :initializeprob)
end
function __has_update_initializeprob!(f)
has_initialization_data(f) && isdefined(f.initialization_data, :update_initializeprob!)
end
function __has_initializeprobmap(f)
has_initialization_data(f) && isdefined(f.initialization_data, :initializeprobmap)
end
function __has_initializeprobpmap(f)
has_initialization_data(f) && isdefined(f.initialization_data, :initializeprobpmap)
end
__has_initialization_data(f) = isdefined(f, :initialization_data)

# compatibility
has_invW(f::AbstractSciMLFunction) = false
Expand All @@ -4438,16 +4442,20 @@ has_Wfact_t(f::AbstractSciMLFunction) = __has_Wfact_t(f) && f.Wfact_t !== nothin
has_paramjac(f::AbstractSciMLFunction) = __has_paramjac(f) && f.paramjac !== nothing
has_sys(f::AbstractSciMLFunction) = __has_sys(f) && f.sys !== nothing
function has_initializeprob(f::AbstractSciMLFunction)
__has_initializeprob(f) && f.initializeprob !== nothing
__has_initializeprob(f) && f.initialization_data.initializeprob !== nothing
end
function has_update_initializeprob!(f::AbstractSciMLFunction)
__has_update_initializeprob!(f) && f.update_initializeprob! !== nothing
__has_update_initializeprob!(f) &&
f.initialization_data.update_initializeprob! !== nothing
end
function has_initializeprobmap(f::AbstractSciMLFunction)
__has_initializeprobmap(f) && f.initializeprobmap !== nothing
__has_initializeprobmap(f) && f.initialization_data.initializeprobmap !== nothing
end
function has_initializeprobpmap(f::AbstractSciMLFunction)
__has_initializeprobpmap(f) && f.initializeprobpmap !== nothing
__has_initializeprobpmap(f) && f.initialization_data.initializeprobpmap !== nothing
end
function has_initialization_data(f::AbstractSciMLFunction)
__has_initialization_data(f) && f.initialization_data !== nothing
end

function has_syms(f::AbstractSciMLFunction)
Expand Down

0 comments on commit 6159778

Please sign in to comment.