1
1
mutable struct GaussIntegrand{pType, uType, lType, rateType, S, PF, PJC, PJT, DGP,
2
- G}
2
+ G, SAlg <: GaussAdjoint }
3
3
sol:: S
4
4
p:: pType
5
5
y:: uType
@@ -8,15 +8,17 @@ mutable struct GaussIntegrand{pType, uType, lType, rateType, S, PF, PJC, PJT, DG
8
8
f_cache:: rateType
9
9
pJ:: PJT
10
10
paramjac_config:: PJC
11
- sensealg:: GaussAdjoint
11
+ sensealg:: SAlg
12
12
dgdp_cache:: DGP
13
13
dgdp:: G
14
14
end
15
15
16
16
struct ODEGaussAdjointSensitivityFunction{C <: AdjointDiffCache ,
17
17
Alg <: GaussAdjoint ,
18
18
uType, SType, CPS, pType,
19
- fType <: AbstractDiffEqFunction } <: SensitivityFunction
19
+ fType <: DiffEqBase.AbstractDiffEqFunction ,
20
+ GI <: GaussIntegrand ,
21
+ ICB} <: SensitivityFunction
20
22
diffcache:: C
21
23
sensealg:: Alg
22
24
discrete:: Bool
@@ -25,7 +27,8 @@ struct ODEGaussAdjointSensitivityFunction{C <: AdjointDiffCache,
25
27
checkpoint_sol:: CPS
26
28
prob:: pType
27
29
f:: fType
28
- GaussInt:: GaussIntegrand
30
+ GaussInt:: GI
31
+ integrating_cb:: ICB
29
32
end
30
33
31
34
mutable struct GaussCheckpointSolution{S, I, T, T2}
39
42
function ODEGaussAdjointSensitivityFunction (
40
43
g, sensealg, gaussint, discrete, sol, dgdu, dgdp,
41
44
f, alg,
42
- checkpoints, tols, tstops = nothing ;
45
+ checkpoints, integrating_cb, tols, tstops = nothing ;
43
46
tspan = reverse (sol. prob. tspan))
44
47
checkpointing = ischeckpointing (sensealg, sol)
45
48
(checkpointing && checkpoints === nothing ) &&
@@ -82,7 +85,7 @@ function ODEGaussAdjointSensitivityFunction(
82
85
g, sensealg, discrete, sol, dgdu, dgdp, sol. prob. f, alg;
83
86
quad = true )
84
87
return ODEGaussAdjointSensitivityFunction (diffcache, sensealg, discrete,
85
- y, sol, checkpoint_sol, sol. prob, f, gaussint)
88
+ y, sol, checkpoint_sol, sol. prob, f, gaussint, integrating_cb )
86
89
end
87
90
88
91
function Gaussfindcursor (intervals, t)
@@ -200,7 +203,7 @@ function split_states(u, t, S::ODEGaussAdjointSensitivityFunction; update = true
200
203
end
201
204
202
205
@noinline function ODEAdjointProblem (sol, sensealg:: GaussAdjoint , alg,
203
- GaussInt:: GaussIntegrand ,
206
+ GaussInt:: GaussIntegrand , integrating_cb,
204
207
t = nothing ,
205
208
dgdu_discrete:: DG1 = nothing ,
206
209
dgdp_discrete:: DG2 = nothing ,
273
276
λ = zero (u0)
274
277
end
275
278
sense = ODEGaussAdjointSensitivityFunction (g, sensealg, GaussInt, discrete, sol,
276
- dgdu_continuous, dgdp_continuous, f, alg, checkpoints,
279
+ dgdu_continuous, dgdp_continuous, f, alg, checkpoints, integrating_cb,
277
280
(reltol = reltol, abstol = abstol), tstops, tspan = tspan)
278
281
279
282
init_cb = (discrete || dgdu_discrete != = nothing ) # && tspan[1] == t[end]
@@ -577,7 +580,8 @@ function _adjoint_sensitivities(sol, sensealg::GaussAdjoint, alg; t = nothing,
577
580
578
581
if sol. prob isa ODEProblem
579
582
adj_prob, cb2, rcb = ODEAdjointProblem (
580
- sol, sensealg, alg, integrand, t, dgdu_discrete,
583
+ sol, sensealg, alg, integrand, integrating_cb,
584
+ t, dgdu_discrete,
581
585
dgdp_discrete,
582
586
dgdu_continuous, dgdp_continuous, g, Val (true );
583
587
checkpoints = checkpoints,
0 commit comments