Skip to content

Update to use ADTypes for specifying AD backend #338

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ authors = ["Nathanael Bosch"]
version = "0.16.2"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
ArrayAllocators = "c9d4266f-a5cb-439d-837c-c97b191379f5"
DiffEqBase = "2b5f629d-d688-5b77-993f-72d75c75574e"
DiffEqCallbacks = "459566f4-90b8-5000-8ac3-15dfb0a30def"
Expand Down Expand Up @@ -48,6 +49,7 @@ DiffEqDevToolsExt = "DiffEqDevTools"
RecipesBaseExt = "RecipesBase"

[compat]
ADTypes = "1.14.0"
ArrayAllocators = "0.3"
BlockArrays = "1"
DiffEqBase = "6.122"
Expand Down
1 change: 1 addition & 0 deletions src/ProbNumDiffEq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ using FiniteHorizonGramians
using FillArrays
using MatrixEquations
using DiffEqCallbacks
using ADTypes

# @reexport using GaussianDistributions

Expand Down
4 changes: 2 additions & 2 deletions src/alg_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ OrdinaryDiffEqDifferentiation.concrete_jac(::AbstractEK) = nothing
OrdinaryDiffEqCore.isfsal(::AbstractEK) = false

for ALG in [:EK1, :DiagonalEK1]
@eval OrdinaryDiffEqDifferentiation._alg_autodiff(::$ALG{CS,AD}) where {CS,AD} =
Val{AD}()
@eval OrdinaryDiffEqDifferentiation._alg_autodiff(alg::$ALG{CS,AD}) where {CS,AD} =
alg.autodiff
@eval OrdinaryDiffEqDifferentiation.alg_difftype(
::$ALG{CS,AD,DiffType},
) where {CS,AD,DiffType} =
Expand Down
63 changes: 38 additions & 25 deletions src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,24 +182,27 @@ struct EK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT,CF} <: AbstractEK
initialization::IT
pn_observation_noise::RT
covariance_factorization::CF
autodiff::AD
EK1(;
order=3,
prior::PT=IWP(order),
diffusionmodel::DT=DynamicDiffusion(),
smooth=true,
initialization::IT=TaylorModeInit(num_derivatives(prior)),
chunk_size=Val{0}(),
autodiff=Val{true}(),
diff_type=Val{:forward},
autodiff=AutoForwardDiff(),
diff_type=Val{:forward}(),
standardtag=Val{true}(),
concrete_jac=nothing,
pn_observation_noise::RT=nothing,
covariance_factorization::CF=covariance_structure(EK1, prior, diffusionmodel),
) where {PT,DT,IT,RT,CF} = begin
ekargcheck(EK1; diffusionmodel, pn_observation_noise, covariance_factorization)
AD_choice, chunk_size, diff_type =
OrdinaryDiffEqCore._process_AD_choice(autodiff, chunk_size, diff_type)
new{
_unwrap_val(chunk_size),
_unwrap_val(autodiff),
typeof(AD_choice),
diff_type,
_unwrap_val(standardtag),
_unwrap_val(concrete_jac),
Expand All @@ -215,6 +218,7 @@ struct EK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT,CF} <: AbstractEK
initialization,
pn_observation_noise,
covariance_factorization,
AD_choice
)
end
end
Expand All @@ -226,15 +230,16 @@ struct DiagonalEK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT,CF} <: AbstractEK
initialization::IT
pn_observation_noise::RT
covariance_factorization::CF
autodiff::AD
DiagonalEK1(;
order=3,
prior::PT=IWP(order),
diffusionmodel::DT=DynamicDiffusion(),
smooth=true,
initialization::IT=TaylorModeInit(num_derivatives(prior)),
chunk_size=Val{0}(),
autodiff=Val{true}(),
diff_type=Val{:forward},
autodiff=AutoForwardDiff(),
diff_type=Val{:forward}(),
standardtag=Val{true}(),
concrete_jac=nothing,
pn_observation_noise::RT=nothing,
Expand All @@ -245,9 +250,11 @@ struct DiagonalEK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT,CF} <: AbstractEK
),
) where {PT,DT,IT,RT,CF} = begin
ekargcheck(DiagonalEK1; diffusionmodel, pn_observation_noise, covariance_factorization)
AD_choice, chunk_size, diff_type =
OrdinaryDiffEqCore._process_AD_choice(autodiff, chunk_size, diff_type)
new{
_unwrap_val(chunk_size),
_unwrap_val(autodiff),
typeof(AD_choice),
diff_type,
_unwrap_val(standardtag),
_unwrap_val(concrete_jac),
Expand All @@ -263,6 +270,7 @@ struct DiagonalEK1{CS,AD,DiffType,ST,CJ,PT,DT,IT,RT,CF} <: AbstractEK
initialization,
pn_observation_noise,
covariance_factorization,
AD_choice
)
end
end
Expand Down Expand Up @@ -334,16 +342,17 @@ RosenbrockExpEK(; order=3, kwargs...) =
EK1(; prior=IOUP(order, update_rate_parameter=true), kwargs...)

function DiffEqBase.remake(thing::EK1{CS,AD,DT,ST,CJ}; kwargs...) where {CS,AD,DT,ST,CJ}
if haskey(kwargs, :autodiff) && kwargs[:autodiff] isa AutoForwardDiff
chunk_size = OrdinaryDiffEqCore._get_fwd_chunksize(kwargs[:autodiff])
else
chunk_size = Val{CS}()
end

T = SciMLBase.remaker_of(thing)
T(;
SciMLBase.struct_as_namedtuple(thing)...,
chunk_size=Val{CS}(),
autodiff=Val{AD}(),
standardtag=Val{ST}(),
T(; SciMLBase.struct_as_namedtuple(thing)...,
chunk_size=chunk_size, autodiff=thing.autodiff, standardtag=Val{ST}(),
concrete_jac=CJ === nothing ? CJ : Val{CJ}(),
diff_type=DT,
kwargs...,
)
kwargs...)
end

function DiffEqBase.prepare_alg(
Expand All @@ -357,21 +366,25 @@ function DiffEqBase.prepare_alg(
# use the prepare_alg from OrdinaryDiffEqCore; but right now, we do not use `linsolve` which
# is a requirement.

if (isbitstype(T) && sizeof(T) > 24) || (
prob.f isa ODEFunction &&
prob.f.f isa FunctionWrappersWrappers.FunctionWrappersWrapper
)
return remake(alg, chunk_size=Val{1}())
end
prepped_AD = OrdinaryDiffEqDifferentiation.prepare_ADType(OrdinaryDiffEqDifferentiation.alg_autodiff(alg), prob, u0, p, OrdinaryDiffEqDifferentiation.standardtag(alg))

sparse_prepped_AD = OrdinaryDiffEqDifferentiation.prepare_user_sparsity(prepped_AD, prob)

L = StaticArrayInterface.known_length(typeof(u0))
@assert L === nothing "ProbNumDiffEq.jl does not support StaticArrays yet."

x = if prob.f.colorvec === nothing
length(u0)
if (
(
(eltype(u0) <: Complex) ||
(!(prob.f isa DAEFunction) && prob.f.mass_matrix isa MatrixOperator)
) && sparse_prepped_AD isa AutoSparse
)
@warn "Input type or problem definition is incompatible with sparse automatic differentiation. Switching to using dense automatic differentiation."
autodiff = ADTypes.dense_ad(sparse_prepped_AD)
else
maximum(prob.f.colorvec)
autodiff = sparse_prepped_AD
end
cs = ForwardDiff.pickchunksize(x)
return remake(alg, chunk_size=Val{cs}())


return remake(alg, autodiff = autodiff)
end
13 changes: 13 additions & 0 deletions src/solution.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,19 @@ ProbODESolution{T,N}(
pnstats, prob, alg, interp, cache, dense, tslocation, stats, retcode,
)


# Tell the SciMLBase/ConstructionBase `remake` machinery how to rebuild a
# `ProbODESolution` from its fields: instead of the fully-parameterized type,
# hand back the `{T,N}` constructor, which derives the remaining type
# parameters from the field values it is called with.
function SciMLBase.constructorof(
    ::Type{
        ProbNumDiffEq.ProbODESolution{
            T,N,uType,puType,uType2,DType,tType,rateType,xType,
            diffType,bkType,PN,P,A,IType,CType,DE,
        },
    },
) where {
    T,N,uType,puType,uType2,DType,tType,rateType,xType,
    diffType,bkType,PN,P,A,IType,CType,DE,
}
    return ProbODESolution{T,N}
end


function DiffEqBase.solution_new_retcode(sol::ProbODESolution{T,N}, retcode) where {T,N}
return ProbODESolution{T,N}(
sol.u, sol.pu, sol.u_analytic, sol.errors, sol.t, sol.k, sol.x_filt, sol.x_smooth,
Expand Down
19 changes: 13 additions & 6 deletions test/autodiff.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,16 @@ import ODEProblemLibrary: prob_ode_fitzhughnagumo
_prob.tspan,
jac=true,
)
prob = remake(prob, p=collect(_prob.p))
#prob = remake(prob, p=collect(_prob.p))
ps = ModelingToolkit.parameter_values(prob)
ps = SciMLStructures.replace(SciMLStructures.Tunable(), ps, [1.0, 2.0, 3.0, 4.0])
prob = remake(prob, p=ps)

function param_to_loss(p)
ps = ModelingToolkit.parameter_values(prob)
ps = SciMLStructures.replace(SciMLStructures.Tunable(), ps, p)
sol = solve(
remake(prob, p=p),
remake(prob, p=ps),
ALG(order=3, smooth=false),
sensealg=SensitivityADPassThrough(),
abstol=1e-3,
Expand All @@ -44,14 +49,16 @@ import ODEProblemLibrary: prob_ode_fitzhughnagumo
return norm(sol.u[end]) # Dummy loss
end

# dldp = FiniteDiff.finite_difference_gradient(param_to_loss, prob.p)
# dldu0 = FiniteDiff.finite_difference_gradient(startval_to_loss, prob.u0)
p, _, _ = SciMLStructures.canonicalize(SciMLStructures.Tunable(), prob.p)

#dldp = FiniteDiff.finite_difference_gradient(param_to_loss, p)
#dldu0 = FiniteDiff.finite_difference_gradient(startval_to_loss, prob.u0)
# For some reason FiniteDiff.jl is not working anymore so we use FiniteDifferences.jl:
dldp = grad(central_fdm(5, 1), param_to_loss, prob.p)[1]
dldp = grad(central_fdm(5, 1), param_to_loss, p)[1]
dldu0 = grad(central_fdm(5, 1), startval_to_loss, prob.u0)[1]

@testset "ForwardDiff.jl" begin
@test ForwardDiff.gradient(param_to_loss, prob.p) ≈ dldp rtol = 1e-2
@test ForwardDiff.gradient(param_to_loss, p) ≈ dldp rtol = 1e-2
@test ForwardDiff.gradient(startval_to_loss, prob.u0) ≈ dldu0 rtol = 5e-2
end

Expand Down
Loading