1
1
const RECOMPILE_BY_DEFAULT = true
2
2
3
+ function DEFAULT_OBSERVED (sym,u,p,t)
4
+ error (" Indexing symbol $sym is unknown." )
5
+ end
6
+
3
7
Base. summary (prob:: AbstractSciMLFunction ) = string (TYPE_COLOR, nameof (typeof (prob)),
4
8
NO_COLOR, " . In-place: " ,
5
9
TYPE_COLOR, isinplace (prob),
@@ -18,7 +22,7 @@ abstract type AbstractODEFunction{iip} <: AbstractDiffEqFunction{iip} end
18
22
"""
19
23
$(TYPEDEF)
20
24
"""
21
- struct ODEFunction{iip,F,TMM,Ta,Tt,TJ,JVP,VJP,JP,SP,TW,TWt,TPJ,S,TCV} <: AbstractODEFunction{iip}
25
+ struct ODEFunction{iip,F,TMM,Ta,Tt,TJ,JVP,VJP,JP,SP,TW,TWt,TPJ,S,S2,O, TCV} <: AbstractODEFunction{iip}
22
26
f:: F
23
27
mass_matrix:: TMM
24
28
analytic:: Ta
@@ -32,6 +36,8 @@ struct ODEFunction{iip,F,TMM,Ta,Tt,TJ,JVP,VJP,JP,SP,TW,TWt,TPJ,S,TCV} <: Abstrac
32
36
Wfact_t:: TWt
33
37
paramjac:: TPJ
34
38
syms:: S
39
+ indepsym:: S2
40
+ observed:: O
35
41
colorvec:: TCV
36
42
end
37
43
@@ -357,6 +363,8 @@ function ODEFunction{iip,true}(f;
357
363
Wfact_t= nothing ,
358
364
paramjac = nothing ,
359
365
syms = nothing ,
366
+ indepsym = nothing ,
367
+ observed = DEFAULT_OBSERVED,
360
368
colorvec = nothing ) where iip
361
369
362
370
if mass_matrix == I && typeof (f) <: Tuple
@@ -380,10 +388,11 @@ function ODEFunction{iip,true}(f;
380
388
ODEFunction{iip,
381
389
typeof (f), typeof (mass_matrix), typeof (analytic), typeof (tgrad), typeof (jac),
382
390
typeof (jvp), typeof (vjp), typeof (jac_prototype), typeof (sparsity), typeof (Wfact),
383
- typeof (Wfact_t), typeof (paramjac), typeof (syms), typeof (_colorvec)}(
391
+ typeof (Wfact_t), typeof (paramjac), typeof (syms), typeof (indepsym),
392
+ typeof (observed), typeof (_colorvec)}(
384
393
f, mass_matrix, analytic, tgrad, jac,
385
394
jvp, vjp, jac_prototype, sparsity, Wfact,
386
- Wfact_t, paramjac, syms, _colorvec)
395
+ Wfact_t, paramjac, syms, indepsym, observed, _colorvec)
387
396
end
388
397
function ODEFunction {iip,false} (f;
389
398
mass_matrix= I,
@@ -398,6 +407,8 @@ function ODEFunction{iip,false}(f;
398
407
Wfact_t= nothing ,
399
408
paramjac = nothing ,
400
409
syms = nothing ,
410
+ indepsym = nothing ,
411
+ observed = DEFAULT_OBSERVED,
401
412
colorvec = nothing ) where iip
402
413
403
414
if jac === nothing && isa (jac_prototype, AbstractDiffEqLinearOperator)
@@ -417,10 +428,10 @@ function ODEFunction{iip,false}(f;
417
428
ODEFunction{iip,
418
429
Any, Any, Any, Any, Any,
419
430
Any, Any, Any, Any, Any,
420
- Any, Any, typeof (syms), typeof (_colorvec)}(
431
+ Any, Any, typeof (syms), typeof (indepsym), Any, typeof ( _colorvec)}(
421
432
f, mass_matrix, analytic, tgrad, jac,
422
433
jvp, vjp, jac_prototype, sparsity, Wfact,
423
- Wfact_t, paramjac, syms, _colorvec)
434
+ Wfact_t, paramjac, syms, indepsym, observed, _colorvec)
424
435
end
425
436
ODEFunction {iip} (f; kwargs... ) where iip = ODEFunction {iip,RECOMPILE_BY_DEFAULT} (f; kwargs... )
426
437
ODEFunction {iip} (f:: ODEFunction ; kwargs... ) where iip = f
@@ -1094,6 +1105,8 @@ __has_Wfact(f) = isdefined(f, :Wfact)
1094
1105
__has_Wfact_t (f) = isdefined (f, :Wfact_t )
1095
1106
__has_paramjac (f) = isdefined (f, :paramjac )
1096
1107
__has_syms (f) = isdefined (f, :syms )
1108
+ __has_indepsym (f) = isdefined (f, :indepsym )
1109
+ __has_observed (f) = isdefined (f, :observed )
1097
1110
__has_analytic (f) = isdefined (f, :analytic )
1098
1111
__has_colorvec (f) = isdefined (f, :colorvec )
1099
1112
@@ -1108,6 +1121,8 @@ has_Wfact(f::AbstractSciMLFunction) = __has_Wfact(f) && f.Wfact !== nothing
1108
1121
has_Wfact_t (f:: AbstractSciMLFunction ) = __has_Wfact_t (f) && f. Wfact_t != = nothing
1109
1122
has_paramjac (f:: AbstractSciMLFunction ) = __has_paramjac (f) && f. paramjac != = nothing
1110
1123
has_syms (f:: AbstractSciMLFunction ) = __has_syms (f) && f. syms != = nothing
1124
+ has_indepsym (f:: AbstractSciMLFunction ) = __has_indepsym (f) && f. indepsym != = nothing
1125
+ has_observed (f:: AbstractSciMLFunction ) = __has_observed (f) && f. observed != = DEFAULT_OBSERVED && f. observed != = nothing
1111
1126
has_colorvec (f:: AbstractSciMLFunction ) = __has_colorvec (f) && f. colorvec != = nothing
1112
1127
1113
1128
# TODO : find an appropriate way to check `has_*`
@@ -1203,13 +1218,27 @@ function Base.convert(::Type{ODEFunction}, f)
1203
1218
else
1204
1219
syms = nothing
1205
1220
end
1221
+
1222
+ if __has_indepsym (f)
1223
+ indepsym = f. indepsym
1224
+ else
1225
+ indepsym = nothing
1226
+ end
1227
+
1228
+ if __has_observed (f)
1229
+ observed = f. observed
1230
+ else
1231
+ observed = DEFAULT_OBSERVED
1232
+ end
1233
+
1206
1234
if __has_colorvec (f)
1207
1235
colorvec = f. colorvec
1208
1236
else
1209
1237
colorvec = nothing
1210
1238
end
1211
1239
ODEFunction (f;analytic= analytic,tgrad= tgrad,jac= jac,jvp= jvp,vjp= vjp,Wfact= Wfact,
1212
- Wfact_t= Wfact_t,paramjac= paramjac,syms= syms,colorvec= colorvec)
1240
+ Wfact_t= Wfact_t,paramjac= paramjac,syms= syms,indepsym= indepsym,
1241
+ observed= observed,colorvec= colorvec)
1213
1242
end
1214
1243
function Base. convert (:: Type{ODEFunction{iip}} ,f) where iip
1215
1244
if __has_analytic (f)
@@ -1257,13 +1286,27 @@ function Base.convert(::Type{ODEFunction{iip}},f) where iip
1257
1286
else
1258
1287
syms = nothing
1259
1288
end
1289
+
1290
+ if __has_indepsym (f)
1291
+ indepsym = f. indepsym
1292
+ else
1293
+ indepsym = nothing
1294
+ end
1295
+
1296
+ if __has_observed (f)
1297
+ observed = f. observed
1298
+ else
1299
+ observed = DEFAULT_OBSERVED
1300
+ end
1301
+
1260
1302
if __has_colorvec (f)
1261
1303
colorvec = f. colorvec
1262
1304
else
1263
1305
colorvec = nothing
1264
1306
end
1265
1307
ODEFunction {iip,RECOMPILE_BY_DEFAULT} (f;analytic= analytic,tgrad= tgrad,jac= jac,jvp= jvp,vjp= vjp,Wfact= Wfact,
1266
- Wfact_t= Wfact_t,paramjac= paramjac,syms= syms,colorvec= colorvec)
1308
+ Wfact_t= Wfact_t,paramjac= paramjac,syms= syms,indepsym= indepsym,
1309
+ observed= observed,colorvec= colorvec)
1267
1310
end
1268
1311
1269
1312
function Base. convert (:: Type{DiscreteFunction} ,f)
@@ -1899,3 +1942,20 @@ function Base.convert(::Type{IncrementingODEFunction}, f)
1899
1942
end
1900
1943
1901
1944
(f:: IncrementingODEFunction )(args... ;kwargs... ) = f. f (args... ;kwargs... )
1945
+
1946
+ for S in [
1947
+ :ODEFunction
1948
+ :DiscreteFunction
1949
+ :DAEFunction
1950
+ :DDEFunction
1951
+ :SDEFunction
1952
+ :RODEFunction
1953
+ :SDDEFunction
1954
+ :NonlinearFunction
1955
+ :IncrementingODEFunction
1956
+ ]
1957
+ @eval begin
1958
+ Base. convert (:: Type{$S} , x:: $S ) = x
1959
+ Base. convert (:: Type{$S{iip}} , x:: T ) where {T<: $S{iip} } where iip = x
1960
+ end
1961
+ end
0 commit comments