Skip to content

Commit e7c73e3

Browse files
Merge pull request #9 from SciML/observed_indexing
Allow hooks for observed symbol indexing
2 parents 8314ffb + 730949d commit e7c73e3

9 files changed

+201
-66
lines changed

src/SciMLBase.jl

+1
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using DocStringExtensions
55
using LinearAlgebra
66
using Statistics
77
using Distributed
8+
using StaticArrays
89

910
import Logging, ArrayInterface
1011
import IteratorInterfaceExtensions

src/ensemble/ensemble_analysis.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
module EnsembleAnalysis
22

3-
using SciMLBase, StaticArrays, Statistics, RecursiveArrayTools
3+
using SciMLBase, Statistics, RecursiveArrayTools
44

55
# Getters
66
get_timestep(sim,i) = (getindex(sol,i) for sol in sim)

src/operators/basic_operators.jl

+5-2
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,10 @@ setval!(α::DiffEqScalar, val) = (α.val = val; α)
5555
isconstant::DiffEqScalar) = α.update_func == DEFAULT_UPDATE_FUNC
5656

5757
for op in (:*, :/, :\)
58-
@eval Base.$op::DiffEqScalar, x::Union{AbstractArray,Number}) = $op.val, x)
59-
@eval Base.$op(x::Union{AbstractArray,Number}, α::DiffEqScalar) = $op(x, α.val)
58+
for T in (:AbstractArray, :Number)
59+
@eval Base.$op::DiffEqScalar, x::$T) = $op.val, x)
60+
@eval Base.$op(x::$T, α::DiffEqScalar) = $op(x, α.val)
61+
end
6062
@eval Base.$op(x::DiffEqScalar, y::DiffEqScalar) = $op(x.val, y.val)
6163
end
6264

@@ -108,6 +110,7 @@ Base.iterate(L::DiffEqArrayOperator,args...) = iterate(L.A,args...)
108110
Base.axes(L::DiffEqArrayOperator) = axes(L.A)
109111
Base.IndexStyle(::Type{<:DiffEqArrayOperator{T,AType}}) where {T,AType} = Base.IndexStyle(AType)
110112
Base.copyto!(L::DiffEqArrayOperator, rhs) = (copyto!(L.A, rhs); L)
113+
Base.copyto!(L::DiffEqArrayOperator, rhs::Base.Broadcast.Broadcasted{<:StaticArrays.StaticArrayStyle}) = (copyto!(L.A, rhs); L)
111114
Base.Broadcast.broadcastable(L::DiffEqArrayOperator) = L
112115
Base.ndims(::Type{<:DiffEqArrayOperator{T,AType}}) where {T,AType} = ndims(AType)
113116
ArrayInterface.issingular(L::DiffEqArrayOperator) = ArrayInterface.issingular(L.A)

src/operators/common_defaults.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,9 @@ LinearAlgebra.opnorm(L::AbstractDiffEqLinearOperator, p::Real=2) = opnorm(conver
1010
Base.@propagate_inbounds Base.getindex(L::AbstractDiffEqLinearOperator, I::Vararg{Any,N}) where {N} = convert(AbstractMatrix,L)[I...]
1111
Base.getindex(L::AbstractDiffEqLinearOperator, I::Vararg{Int, N}) where {N} =
1212
convert(AbstractMatrix,L)[I...]
13-
for op in (:*, :/, :\)
14-
@eval Base.$op(L::AbstractDiffEqLinearOperator, x::Union{AbstractArray,Number}) = $op(convert(AbstractMatrix,L), x)
15-
@eval Base.$op(x::Union{AbstractVecOrMat,Number}, L::AbstractDiffEqLinearOperator) = $op(x, convert(AbstractMatrix,L))
13+
for op in (:*, :/, :\), T in (:AbstractArray, :Number)
14+
@eval Base.$op(L::AbstractDiffEqLinearOperator, x::$T) = $op(convert(AbstractMatrix,L), x)
15+
@eval Base.$op(x::$T, L::AbstractDiffEqLinearOperator) = $op(x, convert(AbstractMatrix,L))
1616
end
1717
LinearAlgebra.mul!(Y::AbstractArray, L::AbstractDiffEqLinearOperator, B::AbstractArray) =
1818
mul!(Y, convert(AbstractMatrix,L), B)

src/problems/bvp_problems.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,11 @@ end
4040

4141
# convenience interfaces:
4242
# Allow any previous timeseries solution
43-
function BVProblem(f,bc,sol::T,tspan,p=NullParameters();kwargs...) where {T<:AbstractTimeseriesSolution}
43+
function BVProblem(f::AbstractODEFunction,bc,sol::T,tspan::Tuple,p=NullParameters();kwargs...) where {T<:AbstractTimeseriesSolution}
4444
BVProblem(f,bc,sol.u,tspan,p)
4545
end
4646
# Allow a function of time for the initial guess
47-
function BVProblem(f,bc,initialGuess::T,tspan::AbstractVector,p=NullParameters();kwargs...) where {T}
47+
function BVProblem(f::AbstractODEFunction,bc,initialGuess,tspan::AbstractVector,p=NullParameters();kwargs...)
4848
u0 = [ initialGuess( i ) for i in tspan]
4949
BVProblem(f,bc,u0,(tspan[1],tspan[end]),p)
5050
end

src/problems/discrete_problems.jl

+2-12
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ function DiscreteProblem(f::AbstractDiscreteFunction,u0,tspan::Tuple,p=NullParam
5555
DiscreteProblem{isinplace(f)}(f,u0,tspan,p;kwargs...)
5656
end
5757

58-
function DiscreteProblem(f,u0,tspan::Tuple,p=NullParameters();kwargs...)
58+
function DiscreteProblem(f::Base.Callable,u0,tspan::Tuple,p=NullParameters();kwargs...)
5959
iip = isinplace(f,4)
6060
DiscreteProblem(convert(DiscreteFunction{iip},f),u0,tspan,p;kwargs...)
6161
end
@@ -65,17 +65,7 @@ $(SIGNATURES)
6565
6666
Define a discrete problem with the identity map.
6767
"""
68-
function DiscreteProblem(u0,tspan::Tuple,p::Tuple;kwargs...)
69-
iip = typeof(u0) <: AbstractArray
70-
if iip
71-
f = DISCRETE_INPLACE_DEFAULT
72-
else
73-
f = DISCRETE_OUTOFPLACE_DEFAULT
74-
end
75-
DiscreteProblem(f,u0,tspan,p;kwargs...)
76-
end
77-
78-
function DiscreteProblem(u0,tspan::Tuple,p=NullParameters();kwargs...)
68+
function DiscreteProblem(u0::Union{AbstractArray,Number},tspan::Tuple,p=NullParameters();kwargs...)
7969
iip = typeof(u0) <: AbstractArray
8070
if iip
8171
f = DISCRETE_INPLACE_DEFAULT

src/scimlfunctions.jl

+67-7
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
const RECOMPILE_BY_DEFAULT = true
22

3+
function DEFAULT_OBSERVED(sym,u,p,t)
4+
error("Indexing symbol $sym is unknown.")
5+
end
6+
37
Base.summary(prob::AbstractSciMLFunction) = string(TYPE_COLOR, nameof(typeof(prob)),
48
NO_COLOR, ". In-place: ",
59
TYPE_COLOR, isinplace(prob),
@@ -18,7 +22,7 @@ abstract type AbstractODEFunction{iip} <: AbstractDiffEqFunction{iip} end
1822
"""
1923
$(TYPEDEF)
2024
"""
21-
struct ODEFunction{iip,F,TMM,Ta,Tt,TJ,JVP,VJP,JP,SP,TW,TWt,TPJ,S,TCV} <: AbstractODEFunction{iip}
25+
struct ODEFunction{iip,F,TMM,Ta,Tt,TJ,JVP,VJP,JP,SP,TW,TWt,TPJ,S,S2,O,TCV} <: AbstractODEFunction{iip}
2226
f::F
2327
mass_matrix::TMM
2428
analytic::Ta
@@ -32,6 +36,8 @@ struct ODEFunction{iip,F,TMM,Ta,Tt,TJ,JVP,VJP,JP,SP,TW,TWt,TPJ,S,TCV} <: Abstrac
3236
Wfact_t::TWt
3337
paramjac::TPJ
3438
syms::S
39+
indepsym::S2
40+
observed::O
3541
colorvec::TCV
3642
end
3743

@@ -357,6 +363,8 @@ function ODEFunction{iip,true}(f;
357363
Wfact_t=nothing,
358364
paramjac = nothing,
359365
syms = nothing,
366+
indepsym = nothing,
367+
observed = DEFAULT_OBSERVED,
360368
colorvec = nothing) where iip
361369

362370
if mass_matrix == I && typeof(f) <: Tuple
@@ -380,10 +388,11 @@ function ODEFunction{iip,true}(f;
380388
ODEFunction{iip,
381389
typeof(f), typeof(mass_matrix), typeof(analytic), typeof(tgrad), typeof(jac),
382390
typeof(jvp), typeof(vjp), typeof(jac_prototype), typeof(sparsity), typeof(Wfact),
383-
typeof(Wfact_t), typeof(paramjac), typeof(syms), typeof(_colorvec)}(
391+
typeof(Wfact_t), typeof(paramjac), typeof(syms), typeof(indepsym),
392+
typeof(observed), typeof(_colorvec)}(
384393
f, mass_matrix, analytic, tgrad, jac,
385394
jvp, vjp, jac_prototype, sparsity, Wfact,
386-
Wfact_t, paramjac, syms, _colorvec)
395+
Wfact_t, paramjac, syms, indepsym, observed, _colorvec)
387396
end
388397
function ODEFunction{iip,false}(f;
389398
mass_matrix=I,
@@ -398,6 +407,8 @@ function ODEFunction{iip,false}(f;
398407
Wfact_t=nothing,
399408
paramjac = nothing,
400409
syms = nothing,
410+
indepsym = nothing,
411+
observed = DEFAULT_OBSERVED,
401412
colorvec = nothing) where iip
402413

403414
if jac === nothing && isa(jac_prototype, AbstractDiffEqLinearOperator)
@@ -417,10 +428,10 @@ function ODEFunction{iip,false}(f;
417428
ODEFunction{iip,
418429
Any, Any, Any, Any, Any,
419430
Any, Any, Any, Any, Any,
420-
Any, Any, typeof(syms), typeof(_colorvec)}(
431+
Any, Any, typeof(syms), typeof(indepsym), Any, typeof(_colorvec)}(
421432
f, mass_matrix, analytic, tgrad, jac,
422433
jvp, vjp, jac_prototype, sparsity, Wfact,
423-
Wfact_t, paramjac, syms, _colorvec)
434+
Wfact_t, paramjac, syms, indepsym, observed, _colorvec)
424435
end
425436
ODEFunction{iip}(f; kwargs...) where iip = ODEFunction{iip,RECOMPILE_BY_DEFAULT}(f; kwargs...)
426437
ODEFunction{iip}(f::ODEFunction; kwargs...) where iip = f
@@ -1094,6 +1105,8 @@ __has_Wfact(f) = isdefined(f, :Wfact)
10941105
__has_Wfact_t(f) = isdefined(f, :Wfact_t)
10951106
__has_paramjac(f) = isdefined(f, :paramjac)
10961107
__has_syms(f) = isdefined(f, :syms)
1108+
__has_indepsym(f) = isdefined(f, :indepsym)
1109+
__has_observed(f) = isdefined(f, :observed)
10971110
__has_analytic(f) = isdefined(f, :analytic)
10981111
__has_colorvec(f) = isdefined(f, :colorvec)
10991112

@@ -1108,6 +1121,8 @@ has_Wfact(f::AbstractSciMLFunction) = __has_Wfact(f) && f.Wfact !== nothing
11081121
has_Wfact_t(f::AbstractSciMLFunction) = __has_Wfact_t(f) && f.Wfact_t !== nothing
11091122
has_paramjac(f::AbstractSciMLFunction) = __has_paramjac(f) && f.paramjac !== nothing
11101123
has_syms(f::AbstractSciMLFunction) = __has_syms(f) && f.syms !== nothing
1124+
has_indepsym(f::AbstractSciMLFunction) = __has_indepsym(f) && f.indepsym !== nothing
1125+
has_observed(f::AbstractSciMLFunction) = __has_observed(f) && f.observed !== DEFAULT_OBSERVED && f.observed !== nothing
11111126
has_colorvec(f::AbstractSciMLFunction) = __has_colorvec(f) && f.colorvec !== nothing
11121127

11131128
# TODO: find an appropriate way to check `has_*`
@@ -1203,13 +1218,27 @@ function Base.convert(::Type{ODEFunction}, f)
12031218
else
12041219
syms = nothing
12051220
end
1221+
1222+
if __has_indepsym(f)
1223+
indepsym = f.indepsym
1224+
else
1225+
indepsym = nothing
1226+
end
1227+
1228+
if __has_observed(f)
1229+
observed = f.observed
1230+
else
1231+
observed = DEFAULT_OBSERVED
1232+
end
1233+
12061234
if __has_colorvec(f)
12071235
colorvec = f.colorvec
12081236
else
12091237
colorvec = nothing
12101238
end
12111239
ODEFunction(f;analytic=analytic,tgrad=tgrad,jac=jac,jvp=jvp,vjp=vjp,Wfact=Wfact,
1212-
Wfact_t=Wfact_t,paramjac=paramjac,syms=syms,colorvec=colorvec)
1240+
Wfact_t=Wfact_t,paramjac=paramjac,syms=syms,indepsym=indepsym,
1241+
observed=observed,colorvec=colorvec)
12131242
end
12141243
function Base.convert(::Type{ODEFunction{iip}},f) where iip
12151244
if __has_analytic(f)
@@ -1257,13 +1286,27 @@ function Base.convert(::Type{ODEFunction{iip}},f) where iip
12571286
else
12581287
syms = nothing
12591288
end
1289+
1290+
if __has_indepsym(f)
1291+
indepsym = f.indepsym
1292+
else
1293+
indepsym = nothing
1294+
end
1295+
1296+
if __has_observed(f)
1297+
observed = f.observed
1298+
else
1299+
observed = DEFAULT_OBSERVED
1300+
end
1301+
12601302
if __has_colorvec(f)
12611303
colorvec = f.colorvec
12621304
else
12631305
colorvec = nothing
12641306
end
12651307
ODEFunction{iip,RECOMPILE_BY_DEFAULT}(f;analytic=analytic,tgrad=tgrad,jac=jac,jvp=jvp,vjp=vjp,Wfact=Wfact,
1266-
Wfact_t=Wfact_t,paramjac=paramjac,syms=syms,colorvec=colorvec)
1308+
Wfact_t=Wfact_t,paramjac=paramjac,syms=syms,indepsym=indepsym,
1309+
observed=observed,colorvec=colorvec)
12671310
end
12681311

12691312
function Base.convert(::Type{DiscreteFunction},f)
@@ -1899,3 +1942,20 @@ function Base.convert(::Type{IncrementingODEFunction}, f)
18991942
end
19001943

19011944
(f::IncrementingODEFunction)(args...;kwargs...) = f.f(args...;kwargs...)
1945+
1946+
for S in [
1947+
:ODEFunction
1948+
:DiscreteFunction
1949+
:DAEFunction
1950+
:DDEFunction
1951+
:SDEFunction
1952+
:RODEFunction
1953+
:SDDEFunction
1954+
:NonlinearFunction
1955+
:IncrementingODEFunction
1956+
]
1957+
@eval begin
1958+
Base.convert(::Type{$S}, x::$S) = x
1959+
Base.convert(::Type{$S{iip}}, x::T) where {T<:$S{iip}} where iip = x
1960+
end
1961+
end

0 commit comments

Comments
 (0)