Description
I need to differentiate through DataInterpolations.jl interpolations
with respect to both the new evaluation points and the input data.
Here is an MWE:
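using DataInterpolations, BenchmarkTools
using Zygote  # assumption: `gradient` below is Zygote.gradient

n = 64
x = vcat([0.], sort(rand(n-2)), [1.])   # knots of the input data
x1 = vcat([0.], sort(rand(n-2)), [1.])  # new evaluation points
y = rand(n);                            # input data values

function di_spline(y, x, xn)
    spline = QuadraticSpline(y, x, extrapolate = true)
    return spline.(xn)
end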
Computing such an interpolation is very efficient:
@benchmark sum(di_spline($y,$x,$x1))
BenchmarkTools.Trial: 10000 samples with 10 evaluations.
Range (min … max): 1.917 μs … 336.010 μs ┊ GC (min … max): 0.00% … 98.41%
Time (median): 2.036 μs ┊ GC (median): 0.00%
Time (mean ± σ): 2.212 μs ± 4.597 μs ┊ GC (mean ± σ): 3.49% ± 2.40%
▂▆██▇▆▅▄▂▂▆▆▅▂▁ ▂▁▁ ▂
█████████████████▇▇████▆▇▇▆███████▆▅▅▆▅▅▅▅▅▆▄▆▄▄▄▆▆▆▇▆▆▆▆▅▆ █
1.92 μs Histogram: log(frequency) by time 3.46 μs <
Memory estimate: 3.42 KiB, allocs estimate: 7.
The gradient computation, by contrast, is roughly three orders of magnitude slower (median 3.727 ms versus 2.036 μs):
@benchmark gradient(y -> sum(di_spline(y, $x, $x1)), $y)
BenchmarkTools.Trial: 1288 samples with 1 evaluation.
Range (min … max): 3.580 ms … 7.025 ms ┊ GC (min … max): 0.00% … 44.98%
Time (median): 3.727 ms ┊ GC (median): 0.00%
Time (mean ± σ): 3.880 ms ± 565.918 μs ┊ GC (mean ± σ): 3.12% ± 8.32%
▆██▇▆▅▄▁▁▁
██████████▇█▆▅▅▄▁▅▅▄▄▁▅▁▁▄▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▇▅▆▇▄▅▅▄▅▆▄▅▇▅▇ █
3.58 ms Histogram: log(frequency) by time 6.45 ms <
Memory estimate: 2.31 MiB, allocs estimate: 38892.
As a solution, I think proper adjoints need to be added to the library.
I opened a Discourse thread here and, following @ChrisRackauckas's answer, opened this issue.
The main question is: assuming there is some interest in adding this feature, how should it be done? In the aforementioned thread, I did it by rewriting the whole function in DataInterpolations, dividing it into smaller functions and writing the adjoint for each of them. How would you like to proceed? Writing the constructor adjoint should not be an issue; I am more concerned about writing the adjoint for the interpolation:
function _interpolate(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess)
    idx = get_idx(A.t, t, iguess; lb = 2, ub_shift = 0, side = :first)
    Cᵢ = A.u[idx - 1]
    σ = 1 // 2 * (A.z[idx] - A.z[idx - 1]) / (A.t[idx] - A.t[idx - 1])
    return A.z[idx - 1] * (t - A.t[idx - 1]) + σ * (t - A.t[idx - 1])^2 + Cᵢ, idx
end
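For reference, the local partial derivatives an adjoint of this function would need can be read off the return expression. This is only a sketch under stated assumptions: the symbols $i$, $\Delta$, $h$ and $s$ are my own shorthand (with $i$ = `idx`, $t_i$ = `A.t[idx]`, $z_i$ = `A.z[idx]`, $u_i$ = `A.u[idx]`), and `idx` is treated as a constant because it is piecewise constant in `t`.

```math
\begin{aligned}
\Delta &= t - t_{i-1}, \qquad h = t_i - t_{i-1}, \qquad \sigma = \frac{z_i - z_{i-1}}{2h},\\
s(t) &= z_{i-1}\,\Delta + \sigma\,\Delta^2 + u_{i-1},\\
\frac{\partial s}{\partial t} &= z_{i-1} + 2\sigma\,\Delta, \qquad
\frac{\partial s}{\partial u_{i-1}} = 1,\\
\frac{\partial s}{\partial z_{i-1}} &= \Delta - \frac{\Delta^2}{2h}, \qquad
\frac{\partial s}{\partial z_i} = \frac{\Delta^2}{2h},\\
\frac{\partial s}{\partial t_{i-1}} &= -z_{i-1} - 2\sigma\,\Delta + \frac{\sigma\,\Delta^2}{h}, \qquad
\frac{\partial s}{\partial t_i} = -\frac{\sigma\,\Delta^2}{h}.
\end{aligned}
```

These are the partials with `A.z` held as an independent input. Since `A.z` is itself computed from `u` and `t` in the constructor, the full gradient with respect to the input data would chain these terms with the constructor adjoint.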
If anyone is willing to help me with at least one case (this one with QuadraticSpline, for instance), I will implement the same for the other cases.
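To make the question concrete, here is a rough sketch of the kind of hand-written pullback I have in mind for a single segment. It is not library code: `quad_segment_pullback` and its argument names are hypothetical, it takes plain vectors `u`, `z`, `t` instead of a `QuadraticSpline`, it assumes `idx` has already been found, and it leaves out the derivatives with respect to the knots `t` and the dependence of `z` on `u` (that part belongs to the constructor adjoint).

```julia
# Hypothetical standalone helper, NOT part of DataInterpolations.jl.
# Evaluates one quadratic-spline segment at `tq` and returns the value
# together with a hand-written pullback.
# u: data values, z: spline coefficients, t: knots, idx: segment index (>= 2).
function quad_segment_pullback(u, z, t, idx, tq)
    Δ = tq - t[idx - 1]
    h = t[idx] - t[idx - 1]
    σ = (z[idx] - z[idx - 1]) / (2h)
    val = z[idx - 1] * Δ + σ * Δ^2 + u[idx - 1]

    # dy is the incoming cotangent of `val`
    function pullback(dy)
        du = zero(u)
        dz = zero(z)
        du[idx - 1] = dy                          # ∂val/∂u[idx-1] = 1
        dz[idx - 1] = dy * (Δ - Δ^2 / (2h))       # ∂val/∂z[idx-1]
        dz[idx] = dy * Δ^2 / (2h)                 # ∂val/∂z[idx]
        dtq = dy * (z[idx - 1] + 2σ * Δ)          # ∂val/∂tq
        return du, dz, dtq
    end
    return val, pullback
end
```

In an actual implementation this would presumably be wrapped in whatever rule mechanism the maintainers prefer, with `idx` coming from `get_idx` and treated as non-differentiable.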
Thank you in advance!