Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,15 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
RegularizationTools = "29dad682-9a27-4bc3-9c72-016788665182"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[extensions]
DataInterpolationsChainRulesCoreExt = "ChainRulesCore"
DataInterpolationsOptimExt = "Optim"
DataInterpolationsRegularizationToolsExt = "RegularizationTools"
DataInterpolationsSparseConnectivityTracerExt = "SparseConnectivityTracer"
DataInterpolationsSymbolicsExt = "Symbolics"

[compat]
Expand All @@ -40,6 +42,7 @@ RecipesBase = "1.3"
Reexport = "1"
RegularizationTools = "0.6"
SafeTestsets = "0.1"
SparseConnectivityTracer = "1"
StableRNGs = "1"
Symbolics = "5.29, 6"
Test = "1.10"
Expand All @@ -57,6 +60,7 @@ Optim = "429524aa-4258-5aef-a3af-852621145aeb"
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
RegularizationTools = "29dad682-9a27-4bc3-9c72-016788665182"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
120 changes: 120 additions & 0 deletions ext/DataInterpolationsSparseConnectivityTracerExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
module DataInterpolationsSparseConnectivityTracerExt

using SparseConnectivityTracer: AbstractTracer, Dual, primal, tracer
using SparseConnectivityTracer: GradientTracer, gradient_tracer_1_to_1
using SparseConnectivityTracer: HessianTracer, hessian_tracer_1_to_1
using FillArrays: Fill # from FillArrays.jl
using DataInterpolations:
AbstractInterpolation,
LinearInterpolation,
QuadraticInterpolation,
LagrangeInterpolation,
AkimaInterpolation,
ConstantInterpolation,
QuadraticSpline,
CubicSpline,
BSplineInterpolation,
BSplineApprox,
CubicHermiteSpline,
# PCHIPInterpolation,
QuinticHermiteSpline,
output_size

#===========#
# Utilities #
#===========#

# Limit support to `u` begin an AbstractVector{<:Number} or AbstractMatrix{<:Number},
# to avoid any cases where the output size is dependent on the input value.
# https://github.com/adrhill/SparseConnectivityTracer.jl/pull/234#discussion_r2031038566

function _sct_interpolate(
::AbstractInterpolation,
uType::Type{<:AbstractVector{<:Number}},
t::GradientTracer,
is_der_1_zero,
is_der_2_zero,
)
return gradient_tracer_1_to_1(t, is_der_1_zero)
end
function _sct_interpolate(
::AbstractInterpolation,
uType::Type{<:AbstractVector{<:Number}},
t::HessianTracer,
is_der_1_zero,
is_der_2_zero,
)
return hessian_tracer_1_to_1(t, is_der_1_zero, is_der_2_zero)
end
function _sct_interpolate(
interp::AbstractInterpolation,
uType::Type{<:AbstractMatrix{<:Number}},
t::GradientTracer,
is_der_1_zero,
is_der_2_zero,
)
t = gradient_tracer_1_to_1(t, is_der_1_zero)
N = only(output_size(interp))
return Fill(t, N)
end
function _sct_interpolate(
interp::AbstractInterpolation,
uType::Type{<:AbstractMatrix{<:Number}},
t::HessianTracer,
is_der_1_zero,
is_der_2_zero,
)
t = hessian_tracer_1_to_1(t, is_der_1_zero, is_der_2_zero)
N = only(output_size(interp))
return Fill(t, N)
end

#===========#
# Overloads #
#===========#

# We assume that with the exception of ConstantInterpolation and LinearInterpolation,
# all interpolations have a non-zero second derivative at some point in the input domain.

for (I, is_der1_zero, is_der2_zero) in (
(:ConstantInterpolation, true, true),
(:LinearInterpolation, false, true),
(:QuadraticInterpolation, false, false),
(:LagrangeInterpolation, false, false),
(:AkimaInterpolation, false, false),
(:QuadraticSpline, false, false),
(:CubicSpline, false, false),
(:BSplineInterpolation, false, false),
(:BSplineApprox, false, false),
(:CubicHermiteSpline, false, false),
(:QuinticHermiteSpline, false, false),
)
@eval function (interp::$(I){uType})(
t::AbstractTracer
) where {uType <: AbstractArray{<:Number}}
return _sct_interpolate(interp, uType, t, $is_der1_zero, $is_der2_zero)
end
end

# Some Interpolations require custom overloads on `Dual` due to mutation of caches.
for I in (
:LagrangeInterpolation,
:BSplineInterpolation,
:BSplineApprox,
:CubicHermiteSpline,
:QuinticHermiteSpline,
)
@eval function (interp::$(I){uType})(d::Dual) where {uType <: AbstractVector}
p = interp(primal(d))
t = interp(tracer(d))
return Dual(p, t)
end

@eval function (interp::$(I){uType})(d::Dual) where {uType <: AbstractMatrix}
p = interp(primal(d))
t = interp(tracer(d))
return Dual.(p, t)
end
end

end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ using SafeTestsets
@safetestset "Integral Tests" include("integral_tests.jl")
@safetestset "Integral Inverse Tests" include("integral_inverse_tests.jl")
@safetestset "Extrapolation Tests" include("extrapolation_tests.jl")
@safetestset "SparseConnectivityTracer Tests" include("online_tests.jl")
@safetestset "Online Tests" include("online_tests.jl")
@safetestset "Regularization Smoothing Tests" include("regularization.jl")
@safetestset "Show methods Tests" include("show.jl")
Expand Down
Loading
Loading