
Improving Zygote performance #289




I need to differentiate through DataInterpolations.jl with respect to both the new evaluation points and the input data.
Here is an MWE:

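using DataInterpolations, Zygote, BenchmarkTools

n = 64
x = vcat([0.], sort(rand(n-2)), [1.])   # knots, including the endpoints 0 and 1
x1 = vcat([0.], sort(rand(n-2)), [1.])  # new evaluation points
y = rand(n);                            # data values at the knots

function di_spline(y, x, xn)
    spline = QuadraticSpline(y, x, extrapolate = true)
    return spline.(xn)
end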

Evaluating the interpolation itself is very efficient:

@benchmark sum(di_spline($y,$x,$x1))
BenchmarkTools.Trial: 10000 samples with 10 evaluations.
 Range (min … max):  1.917 μs … 336.010 μs  ┊ GC (min … max): 0.00% … 98.41%
 Time  (median):     2.036 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   2.212 μs ±   4.597 μs  ┊ GC (mean ± σ):  3.49% ±  2.40%

  ▂▆██▇▆▅▄▂▂▆▆▅▂▁              ▂▁▁                            ▂
  █████████████████▇▇████▆▇▇▆███████▆▅▅▆▅▅▅▅▅▆▄▆▄▄▄▆▆▆▇▆▆▆▆▅▆ █
  1.92 μs      Histogram: log(frequency) by time      3.46 μs <

 Memory estimate: 3.42 KiB, allocs estimate: 7.

Computing the gradient with Zygote, however, is about three orders of magnitude slower (3.73 ms vs. 2.04 μs median) and needs 38892 allocations instead of 7:

@benchmark gradient(yy -> sum(di_spline(yy, $x, $x1)), $y)
BenchmarkTools.Trial: 1288 samples with 1 evaluation.
 Range (min … max):  3.580 ms …   7.025 ms  ┊ GC (min … max): 0.00% … 44.98%
 Time  (median):     3.727 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   3.880 ms ± 565.918 μs  ┊ GC (mean ± σ):  3.12% ±  8.32%

  ██████████▇█▆▅▅▄▁▅▅▄▄▁▅▁▁▄▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▇▅▆▇▄▅▅▄▅▆▄▅▇▅▇ █
  3.58 ms      Histogram: log(frequency) by time      6.45 ms <

 Memory estimate: 2.31 MiB, allocs estimate: 38892.

I think the solution is to add proper adjoints (custom reverse-mode rules) to the library.
I opened a Discourse thread here, and following @ChrisRackauckas's answer there, I am opening this issue.
The main question is: assuming there is interest in this feature, how should it be done? In that thread, I rewrote the whole function from DataInterpolations, split it into smaller functions, and wrote an adjoint for each of them. How would you like to proceed? Writing the adjoint for the constructor should not be a problem; I am more concerned about the adjoint for the interpolation itself:

function _interpolate(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess)
    # find the knot interval containing t; idx is the index of its right knot
    idx = get_idx(A.t, t, iguess; lb = 2, ub_shift = 0, side = :first)
    Cᵢ = A.u[idx - 1]
    σ = 1 // 2 * (A.z[idx] - A.z[idx - 1]) / (A.t[idx] - A.t[idx - 1])
    # quadratic in (t - A.t[idx - 1]) on this interval
    return A.z[idx - 1] * (t - A.t[idx - 1]) + σ * (t - A.t[idx - 1])^2 + Cᵢ, idx
end
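One possible starting point for the interpolation adjoint is to hand-derive the pullback of the single-segment formula above. With Δ = t - A.t[idx-1] and h = A.t[idx] - A.t[idx-1], the value is f = z[idx-1]·Δ + σ·Δ² + u[idx-1] with σ = (z[idx] - z[idx-1])/(2h), so for example ∂f/∂t = z[idx-1] + 2σΔ. The sketch below is a minimal, dependency-free illustration under stated assumptions: `segment_value_and_pullback` and `fd_check` are hypothetical helpers (not part of DataInterpolations.jl), `idx` is treated as a constant (it is piecewise constant in `t`, so it contributes no gradient), and `A.z` is treated as an independent input. In the real `QuadraticSpline`, `z` is computed from `u` and `t` in the constructor, so the full gradient with respect to the input data would also need the constructor adjoint. Wrapping this in a `ChainRulesCore.rrule` (which Zygote uses) is not shown.

```julia
# Hypothetical helper (not a DataInterpolations.jl API): value of one
# QuadraticSpline segment plus a hand-written pullback.
# Arguments map to `_interpolate` as:
#   u_l = A.u[idx-1], z_l = A.z[idx-1], z_r = A.z[idx],
#   t_l = A.t[idx-1], t_r = A.t[idx],  t = evaluation point.
function segment_value_and_pullback(u_l, z_l, z_r, t_l, t_r, t)
    Δ = t - t_l
    h = t_r - t_l
    σ = (z_r - z_l) / (2h)
    f = z_l * Δ + σ * Δ^2 + u_l

    function pullback(f̄)
        # f depends on Δ directly, and on h only through σ (dσ/dh = -σ/h)
        df_dΔ = z_l + 2σ * Δ
        df_dh = Δ^2 * (-σ / h)
        du_l = f̄
        dz_l = f̄ * (Δ - Δ^2 / (2h))
        dz_r = f̄ * Δ^2 / (2h)
        dt   = f̄ * df_dΔ            # Δ = t - t_l
        dt_r = f̄ * df_dh            # h = t_r - t_l
        dt_l = -dt - dt_r           # t_l enters both Δ and h with a minus sign
        return (du_l, dz_l, dz_r, dt_l, dt_r, dt)
    end
    return f, pullback
end

# Compare the pullback against central finite differences (for checking only).
function fd_check(args; δ = 1e-6)
    _, back = segment_value_and_pullback(args...)
    analytic = collect(back(1.0))
    numeric = map(1:length(args)) do k
        up = collect(args); up[k] += δ
        dn = collect(args); dn[k] -= δ
        (segment_value_and_pullback(up...)[1] - segment_value_and_pullback(dn...)[1]) / (2δ)
    end
    return analytic, numeric
end

analytic, numeric = fd_check((0.3, 1.2, -0.7, 0.25, 0.5, 0.4))
println("max abs difference: ", maximum(abs.(analytic .- numeric)))
```

For many evaluation points, accumulating these per-segment cotangents into arrays at positions `idx - 1` and `idx` (summing over all points that fall in the same interval) would give the gradients with respect to `A.u`, `A.z`, and `A.t`, while the last entry gives the derivative with respect to each evaluation point.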

If anyone is willing to help me with at least one case (QuadraticSpline, for instance), I will implement the same for the other interpolation types.
Thank you in advance!



