Status: Closed
Description
Example from SciML/SciMLSensitivity.jl#1067:
# Reproducer: take the Zygote gradient of the final state of a one-variable ODE
# that has a discrete "injection" callback, using BacksolveAdjoint with an
# EnzymeVJP vector-Jacobian product. Per the trace below, this fails inside
# Enzyme's `make_zero!` when SciMLSensitivity handles the callback during the
# adjoint solve.
using OrdinaryDiffEq, Zygote, SciMLSensitivity
N0 = [0.0] # initial population
p = [100.0, 50.0] # p[1]: steady-state population; p[2]: cells injected at tinject
tspan = (0.0, 10.0) # integration time span
# In-place RHS: dN/dt = p[1] - N, so N relaxes towards p[1]
f(D, u, p, t) = (D[1] = p[1] - u[1]) # system
prob = ODEProblem(f, N0, tspan, p)
# at time tinject we inject p[2] cells
tinject = 8.0
# Exact float equality: relies on `tstops = [tinject]` in `solve` below to force
# a step at exactly t == tinject; without it the callback would likely never fire
condition(u, t, integrator) = t == tinject
# Add p[2] to the state (read from the integrator so it tracks the remade parameters)
affect(integrator) = integrator.u[1] += integrator.p[2]
cb = DiscreteCallback(condition, affect)
# Loss: population at the final saved time, as a function of the parameters `p`
function loss(p)
_prob = remake(prob, p = p)
_sol = solve(_prob, Tsit5(); callback = cb,
abstol = 1e-14, reltol = 1e-14, tstops = [tinject],
sensealg = BacksolveAdjoint(autojacvec = EnzymeVJP()))
_sol.u[end][1]
end
# Throws the `setfield!: immutable struct ... cannot be changed` error shown below
gZy = Zygote.gradient(loss, p)[1]
Running this throws:
ERROR: setfield!: immutable struct of type #136#140 cannot be changed
Stacktrace:
[1] make_zero!
@ ~/.julia/packages/Enzyme/SiyIj/src/compiler.jl:1601 [inlined]
[2] make_zero!
@ ~/.julia/packages/Enzyme/SiyIj/src/compiler.jl:1576 [inlined]
[3] _vecjacobian!(dλ::SubArray{…}, y::Vector{…}, λ::SubArray{…}, p::Vector{…}, t::Float64, S::SciMLSensitivity.CallbackSensitivityFunction{…}, isautojacvec::EnzymeVJP, dgrad::SubArray{…}, dy::SubArray{…}, W::Nothing)
@ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/derivative_wrappers.jl:710
[4] #vecjacobian!#18
@ ~/.julia/dev/SciMLSensitivity/src/derivative_wrappers.jl:232 [inlined]
[5] vecjacobian!
@ ~/.julia/dev/SciMLSensitivity/src/derivative_wrappers.jl:229 [inlined]
[6] (::SciMLSensitivity.var"#affect!#272"{…})(integrator::OrdinaryDiffEq.ODEIntegrator{…})
@ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/callback_tracking.jl:339
[7] #111
@ ~/.julia/packages/DiffEqCallbacks/9fKPq/src/preset_time.jl:58 [inlined]
[8] apply_discrete_callback!
@ ~/.julia/packages/DiffEqBase/c8MAQ/src/callbacks.jl:613 [inlined]
[9] apply_discrete_callback! (repeats 2 times)
@ ~/.julia/packages/DiffEqBase/c8MAQ/src/callbacks.jl:635 [inlined]
[10] apply_discrete_callback!
@ ~/.julia/packages/DiffEqBase/c8MAQ/src/callbacks.jl:628 [inlined]
[11] handle_callbacks!(integrator::OrdinaryDiffEq.ODEIntegrator{…})
@ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/HQ92J/src/integrators/integrator_utils.jl:349
[12] _loopfooter!(integrator::OrdinaryDiffEq.ODEIntegrator{…})
@ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/HQ92J/src/integrators/integrator_utils.jl:254
[13] loopfooter!
@ ~/.julia/packages/OrdinaryDiffEq/HQ92J/src/integrators/integrator_utils.jl:207 [inlined]
[14] solve!(integrator::OrdinaryDiffEq.ODEIntegrator{…})
@ OrdinaryDiffEq ~/.julia/packages/OrdinaryDiffEq/HQ92J/src/solve.jl:558
[15] #__solve#560
@ ~/.julia/packages/OrdinaryDiffEq/HQ92J/src/solve.jl:7 [inlined]
[16] __solve
@ ~/.julia/packages/OrdinaryDiffEq/HQ92J/src/solve.jl:1 [inlined]
[17] solve_call(_prob::ODEProblem{…}, args::Tsit5{…}; merge_callbacks::Bool, kwargshandle::Nothing, kwargs::@Kwargs{…})
@ DiffEqBase ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:612
[18] solve_call
@ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:569 [inlined]
[19] #solve_up#53
@ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1080 [inlined]
[20] solve_up
@ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1066 [inlined]
[21] #solve#51
@ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1003 [inlined]
[22] _adjoint_sensitivities(sol::ODESolution{…}, sensealg::BacksolveAdjoint{…}, alg::Tsit5{…}; t::Vector{…}, dgdu_discrete::Function, dgdp_discrete::Nothing, dgdu_continuous::Nothing, dgdp_continuous::Nothing, g::Nothing, abstol::Float64, reltol::Float64, checkpoints::Vector{…}, corfunc_analytical::Nothing, callback::CallbackSet{…}, kwargs::@Kwargs{…})
@ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/sensitivity_interface.jl:448
[23] _adjoint_sensitivities
@ ~/.julia/dev/SciMLSensitivity/src/sensitivity_interface.jl:405 [inlined]
[24] #adjoint_sensitivities#63
@ ~/.julia/dev/SciMLSensitivity/src/sensitivity_interface.jl:401 [inlined]
[25] (::SciMLSensitivity.var"#adjoint_sensitivity_backpass#310"{…})(Δ::ODESolution{…})
@ SciMLSensitivity ~/.julia/dev/SciMLSensitivity/src/concrete_solve.jl:619
[26] ZBack
@ ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:211 [inlined]
[27] (::Zygote.var"#kw_zpullback#53"{…})(dy::ODESolution{…})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/chainrules.jl:237
[28] #291
@ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
[29] (::Zygote.var"#2169#back#293"{…})(Δ::ODESolution{…})
@ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
[30] #solve#51
@ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:1003 [inlined]
[31] (::Zygote.Pullback{…})(Δ::ODESolution{…})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[32] #291
@ ~/.julia/packages/Zygote/nsBv0/src/lib/lib.jl:206 [inlined]
[33] (::Zygote.var"#2169#back#293"{…})(Δ::ODESolution{…})
@ Zygote ~/.julia/packages/ZygoteRules/M4xmc/src/adjoint.jl:72
[34] solve
@ ~/.julia/packages/DiffEqBase/c8MAQ/src/solve.jl:993 [inlined]
[35] (::Zygote.Pullback{…})(Δ::ODESolution{…})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[36] loss
@ ~/Desktop/test.jl:84 [inlined]
[37] (::Zygote.Pullback{Tuple{…}, Tuple{…}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface2.jl:0
[38] (::Zygote.var"#75#76"{Zygote.Pullback{Tuple{…}, Tuple{…}}})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:91
[39] gradient(f::Function, args::Vector{Float64})
@ Zygote ~/.julia/packages/Zygote/nsBv0/src/compiler/interface.jl:148
[40] top-level scope
@ ~/Desktop/test.jl:91
Some type information was truncated. Use `show(err)` to see complete types.
I haven't been able to reduce it to a smaller reproducer.
Metadata
Assignees
Labels
No labels