Skip to content

Commit 11e61fb

Browse files
Almost there
1 parent f29d937 commit 11e61fb

File tree

2 files changed

+29
-12
lines changed

2 files changed

+29
-12
lines changed

src/callback_tracking.jl

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,11 @@ function _setup_reverse_callbacks(
272272
du = first(get_tmp_cache(integrator))
273273
λ, grad, y, dλ, dgrad, dy = split_states(du, integrator.u, integrator.t, S)
274274

275+
if sensealg isa GaussAdjoint
276+
dgrad = integrator.f.f.integrating_cb.affect!.accumulation_cache
277+
recursive_copyto!(dgrad, 0)
278+
end
279+
275280
# if save_positions[2] = false, then the right limit is not saved. Thus, for
276281
# the QuadratureAdjoint we would need to lift y from the left to the right limit.
277282
# However, one also needs to update dgrad later on.
@@ -339,7 +344,10 @@ function _setup_reverse_callbacks(
339344
vecjacobian!(dλ, y, λ, integrator.p, integrator.t, fakeS;
340345
dgrad = dgrad, dy = dy)
341346

342-
dgrad !== nothing && (dgrad .*= -1)
347+
if dgrad !== nothing && !(sensealg isa QuadratureAdjoint)
348+
dgrad .*= -1
349+
end
350+
343351
if cb isa Union{ContinuousCallback, VectorContinuousCallback}
344352
# second correction to correct for left limit
345353
(; Lu_left) = correction
@@ -358,8 +366,13 @@ function _setup_reverse_callbacks(
358366

359367
λ .=
360368

361-
if !(sensealg isa QuadratureAdjoint) && !(sensealg isa GaussAdjoint)
362-
grad .-= dgrad
369+
if sensealg isa GaussAdjoint
370+
@assert integrator.f.f isa ODEGaussAdjointSensitivityFunction
371+
integrator.f.f.integrating_cb.affect!.integrand_values.integrand .= dgrad
372+
373+
#recursive_add!(integrator.f.f.integrating_cb.affect!.integrand_values.integrand,dgrad)
374+
elseif !(sensealg isa QuadratureAdjoint)
375+
grad .= dgrad
363376
end
364377
end
365378

src/gauss_adjoint.jl

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
mutable struct GaussIntegrand{pType, uType, lType, rateType, S, PF, PJC, PJT, DGP,
2-
G}
2+
G, SAlg <: GaussAdjoint}
33
sol::S
44
p::pType
55
y::uType
@@ -8,15 +8,17 @@ mutable struct GaussIntegrand{pType, uType, lType, rateType, S, PF, PJC, PJT, DG
88
f_cache::rateType
99
pJ::PJT
1010
paramjac_config::PJC
11-
sensealg::GaussAdjoint
11+
sensealg::SAlg
1212
dgdp_cache::DGP
1313
dgdp::G
1414
end
1515

1616
struct ODEGaussAdjointSensitivityFunction{C <: AdjointDiffCache,
1717
Alg <: GaussAdjoint,
1818
uType, SType, CPS, pType,
19-
fType <: AbstractDiffEqFunction} <: SensitivityFunction
19+
fType <: DiffEqBase.AbstractDiffEqFunction,
20+
GI <: GaussIntegrand,
21+
ICB} <: SensitivityFunction
2022
diffcache::C
2123
sensealg::Alg
2224
discrete::Bool
@@ -25,7 +27,8 @@ struct ODEGaussAdjointSensitivityFunction{C <: AdjointDiffCache,
2527
checkpoint_sol::CPS
2628
prob::pType
2729
f::fType
28-
GaussInt::GaussIntegrand
30+
GaussInt::GI
31+
integrating_cb::ICB
2932
end
3033

3134
mutable struct GaussCheckpointSolution{S, I, T, T2}
@@ -39,7 +42,7 @@ end
3942
function ODEGaussAdjointSensitivityFunction(
4043
g, sensealg, gaussint, discrete, sol, dgdu, dgdp,
4144
f, alg,
42-
checkpoints, tols, tstops = nothing;
45+
checkpoints, integrating_cb, tols, tstops = nothing;
4346
tspan = reverse(sol.prob.tspan))
4447
checkpointing = ischeckpointing(sensealg, sol)
4548
(checkpointing && checkpoints === nothing) &&
@@ -82,7 +85,7 @@ function ODEGaussAdjointSensitivityFunction(
8285
g, sensealg, discrete, sol, dgdu, dgdp, sol.prob.f, alg;
8386
quad = true)
8487
return ODEGaussAdjointSensitivityFunction(diffcache, sensealg, discrete,
85-
y, sol, checkpoint_sol, sol.prob, f, gaussint)
88+
y, sol, checkpoint_sol, sol.prob, f, gaussint, integrating_cb)
8689
end
8790

8891
function Gaussfindcursor(intervals, t)
@@ -200,7 +203,7 @@ function split_states(u, t, S::ODEGaussAdjointSensitivityFunction; update = true
200203
end
201204

202205
@noinline function ODEAdjointProblem(sol, sensealg::GaussAdjoint, alg,
203-
GaussInt::GaussIntegrand,
206+
GaussInt::GaussIntegrand, integrating_cb,
204207
t = nothing,
205208
dgdu_discrete::DG1 = nothing,
206209
dgdp_discrete::DG2 = nothing,
@@ -273,7 +276,7 @@ end
273276
λ = zero(u0)
274277
end
275278
sense = ODEGaussAdjointSensitivityFunction(g, sensealg, GaussInt, discrete, sol,
276-
dgdu_continuous, dgdp_continuous, f, alg, checkpoints,
279+
dgdu_continuous, dgdp_continuous, f, alg, checkpoints, integrating_cb,
277280
(reltol = reltol, abstol = abstol), tstops, tspan = tspan)
278281

279282
init_cb = (discrete || dgdu_discrete !== nothing) # && tspan[1] == t[end]
@@ -577,7 +580,8 @@ function _adjoint_sensitivities(sol, sensealg::GaussAdjoint, alg; t = nothing,
577580

578581
if sol.prob isa ODEProblem
579582
adj_prob, cb2, rcb = ODEAdjointProblem(
580-
sol, sensealg, alg, integrand, t, dgdu_discrete,
583+
sol, sensealg, alg, integrand, integrating_cb,
584+
t, dgdu_discrete,
581585
dgdp_discrete,
582586
dgdu_continuous, dgdp_continuous, g, Val(true);
583587
checkpoints = checkpoints,

0 commit comments

Comments
 (0)