Skip to content

Commit 3460ec6

Browse files
Add SparseConnectivityTracer.jl v1 Support (#444)
1 parent cc2861e commit 3460ec6

File tree

4 files changed

+407
-0
lines changed

4 files changed

+407
-0
lines changed

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,15 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1515
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
1616
Optim = "429524aa-4258-5aef-a3af-852621145aeb"
1717
RegularizationTools = "29dad682-9a27-4bc3-9c72-016788665182"
18+
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
1819
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
1920
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2021

2122
[extensions]
2223
DataInterpolationsChainRulesCoreExt = "ChainRulesCore"
2324
DataInterpolationsOptimExt = "Optim"
2425
DataInterpolationsRegularizationToolsExt = "RegularizationTools"
26+
DataInterpolationsSparseConnectivityTracerExt = "SparseConnectivityTracer"
2527
DataInterpolationsSymbolicsExt = "Symbolics"
2628

2729
[compat]
@@ -40,6 +42,7 @@ RecipesBase = "1.3"
4042
Reexport = "1"
4143
RegularizationTools = "0.6"
4244
SafeTestsets = "0.1"
45+
SparseConnectivityTracer = "1"
4346
StableRNGs = "1"
4447
Symbolics = "5.29, 6"
4548
Test = "1.10"
@@ -57,6 +60,7 @@ Optim = "429524aa-4258-5aef-a3af-852621145aeb"
5760
QuadGK = "1fd47b50-473d-5c70-9696-f719f8f3bcdc"
5861
RegularizationTools = "29dad682-9a27-4bc3-9c72-016788665182"
5962
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
63+
SparseConnectivityTracer = "9f842d2f-2579-4b1d-911e-f412cf18a3f5"
6064
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
6165
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
6266
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
module DataInterpolationsSparseConnectivityTracerExt
2+
3+
using SparseConnectivityTracer: AbstractTracer, Dual, primal, tracer
4+
using SparseConnectivityTracer: GradientTracer, gradient_tracer_1_to_1
5+
using SparseConnectivityTracer: HessianTracer, hessian_tracer_1_to_1
6+
using FillArrays: Fill # from FillArrays.jl
7+
using DataInterpolations:
8+
AbstractInterpolation,
9+
LinearInterpolation,
10+
QuadraticInterpolation,
11+
LagrangeInterpolation,
12+
AkimaInterpolation,
13+
ConstantInterpolation,
14+
QuadraticSpline,
15+
CubicSpline,
16+
BSplineInterpolation,
17+
BSplineApprox,
18+
CubicHermiteSpline,
19+
# PCHIPInterpolation,
20+
QuinticHermiteSpline,
21+
output_size
22+
23+
#===========#
24+
# Utilities #
25+
#===========#
26+
27+
# Limit support to `u` begin an AbstractVector{<:Number} or AbstractMatrix{<:Number},
28+
# to avoid any cases where the output size is dependent on the input value.
29+
# https://github.com/adrhill/SparseConnectivityTracer.jl/pull/234#discussion_r2031038566
30+
31+
function _sct_interpolate(
32+
::AbstractInterpolation,
33+
uType::Type{<:AbstractVector{<:Number}},
34+
t::GradientTracer,
35+
is_der_1_zero,
36+
is_der_2_zero,
37+
)
38+
return gradient_tracer_1_to_1(t, is_der_1_zero)
39+
end
40+
function _sct_interpolate(
41+
::AbstractInterpolation,
42+
uType::Type{<:AbstractVector{<:Number}},
43+
t::HessianTracer,
44+
is_der_1_zero,
45+
is_der_2_zero,
46+
)
47+
return hessian_tracer_1_to_1(t, is_der_1_zero, is_der_2_zero)
48+
end
49+
function _sct_interpolate(
50+
interp::AbstractInterpolation,
51+
uType::Type{<:AbstractMatrix{<:Number}},
52+
t::GradientTracer,
53+
is_der_1_zero,
54+
is_der_2_zero,
55+
)
56+
t = gradient_tracer_1_to_1(t, is_der_1_zero)
57+
N = only(output_size(interp))
58+
return Fill(t, N)
59+
end
60+
function _sct_interpolate(
61+
interp::AbstractInterpolation,
62+
uType::Type{<:AbstractMatrix{<:Number}},
63+
t::HessianTracer,
64+
is_der_1_zero,
65+
is_der_2_zero,
66+
)
67+
t = hessian_tracer_1_to_1(t, is_der_1_zero, is_der_2_zero)
68+
N = only(output_size(interp))
69+
return Fill(t, N)
70+
end
71+
72+
#===========#
73+
# Overloads #
74+
#===========#
75+
76+
# We assume that with the exception of ConstantInterpolation and LinearInterpolation,
77+
# all interpolations have a non-zero second derivative at some point in the input domain.
78+
79+
for (I, is_der1_zero, is_der2_zero) in (
80+
(:ConstantInterpolation, true, true),
81+
(:LinearInterpolation, false, true),
82+
(:QuadraticInterpolation, false, false),
83+
(:LagrangeInterpolation, false, false),
84+
(:AkimaInterpolation, false, false),
85+
(:QuadraticSpline, false, false),
86+
(:CubicSpline, false, false),
87+
(:BSplineInterpolation, false, false),
88+
(:BSplineApprox, false, false),
89+
(:CubicHermiteSpline, false, false),
90+
(:QuinticHermiteSpline, false, false),
91+
)
92+
@eval function (interp::$(I){uType})(
93+
t::AbstractTracer
94+
) where {uType <: AbstractArray{<:Number}}
95+
return _sct_interpolate(interp, uType, t, $is_der1_zero, $is_der2_zero)
96+
end
97+
end
98+
99+
# Some Interpolations require custom overloads on `Dual` due to mutation of caches.
100+
for I in (
101+
:LagrangeInterpolation,
102+
:BSplineInterpolation,
103+
:BSplineApprox,
104+
:CubicHermiteSpline,
105+
:QuinticHermiteSpline,
106+
)
107+
@eval function (interp::$(I){uType})(d::Dual) where {uType <: AbstractVector}
108+
p = interp(primal(d))
109+
t = interp(tracer(d))
110+
return Dual(p, t)
111+
end
112+
113+
@eval function (interp::$(I){uType})(d::Dual) where {uType <: AbstractMatrix}
114+
p = interp(primal(d))
115+
t = interp(tracer(d))
116+
return Dual.(p, t)
117+
end
118+
end
119+
120+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using SafeTestsets
88
@safetestset "Integral Tests" include("integral_tests.jl")
99
@safetestset "Integral Inverse Tests" include("integral_inverse_tests.jl")
1010
@safetestset "Extrapolation Tests" include("extrapolation_tests.jl")
11+
@safetestset "SparseConnectivityTracer Tests" include("online_tests.jl")
1112
@safetestset "Online Tests" include("online_tests.jl")
1213
@safetestset "Regularization Smoothing Tests" include("regularization.jl")
1314
@safetestset "Show methods Tests" include("show.jl")

0 commit comments

Comments
 (0)