Improving Zygote performance #289

Open
@marcobonici


I need to differentiate through DataInterpolations.jl with respect to both the new evaluation points and the input data.
Here is an MWE:

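using DataInterpolations, BenchmarkTools, Zygote

n = 64
# Sorted knots on [0, 1] and a second, independent set of evaluation points
x = vcat([0.], sort(rand(n-2)), [1.])
x1 = vcat([0.], sort(rand(n-2)), [1.])
y = rand(n);

# Build a quadratic spline through (x, y) and evaluate it at the points xn
function di_spline(y,x,xn)
    spline = QuadraticSpline(y,x, extrapolate = true)
    return spline.(xn)
end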

Evaluating the interpolation itself is very efficient:

@benchmark sum(di_spline($y,$x,$x1))
BenchmarkTools.Trial: 10000 samples with 10 evaluations.
 Range (min … max):  1.917 μs … 336.010 μs  ┊ GC (min … max): 0.00% … 98.41%
 Time  (median):     2.036 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   2.212 μs ±   4.597 μs  ┊ GC (mean ± σ):  3.49% ±  2.40%

  ▂▆██▇▆▅▄▂▂▆▆▅▂▁              ▂▁▁                            ▂
  █████████████████▇▇████▆▇▇▆███████▆▅▅▆▅▅▅▅▅▆▄▆▄▄▄▆▆▆▇▆▆▆▆▅▆ █
  1.92 μs      Histogram: log(frequency) by time      3.46 μs <

 Memory estimate: 3.42 KiB, allocs estimate: 7.

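However, computing the gradient with Zygote is roughly three orders of magnitude slower (median 3.727 ms vs 2.036 μs) and allocates far more (2.31 MiB in 38892 allocations vs 3.42 KiB in 7):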

@benchmark gradient(yy -> sum(di_spline(yy, $x, $x1)), $y)
BenchmarkTools.Trial: 1288 samples with 1 evaluation.
 Range (min … max):  3.580 ms …   7.025 ms  ┊ GC (min … max): 0.00% … 44.98%
 Time  (median):     3.727 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   3.880 ms ± 565.918 μs  ┊ GC (mean ± σ):  3.12% ±  8.32%

  ▆██▇▆▅▄▁▁▁                                                   
  ██████████▇█▆▅▅▄▁▅▅▄▄▁▅▁▁▄▁▁▁▁▄▁▁▁▁▁▁▁▁▁▁▁▁▁▇▅▆▇▄▅▅▄▅▆▄▅▇▅▇ █
  3.58 ms      Histogram: log(frequency) by time      6.45 ms <

 Memory estimate: 2.31 MiB, allocs estimate: 38892.

I think the solution is to add proper adjoints (custom reverse-mode rules) to the library.
I opened a Discourse thread here, and given @ChrisRackauckas's answer, I opened this issue.
The main question is: assuming there is interest in this feature, how should it be implemented? In the aforementioned thread, I did it by rewriting the whole function from DataInterpolations, splitting it into smaller functions, and writing an adjoint for each of them. How would you like to proceed? Writing the adjoint for the constructor should not be an issue; I am more concerned about writing the adjoint for the interpolation itself:

function _interpolate(A::QuadraticSpline{<:AbstractVector}, t::Number, iguess)
    idx = get_idx(A.t, t, iguess; lb = 2, ub_shift = 0, side = :first)
    Cᵢ = A.u[idx - 1]
    σ = 1 // 2 * (A.z[idx] - A.z[idx - 1]) / (A.t[idx] - A.t[idx - 1])
    return A.z[idx - 1] * (t - A.t[idx - 1]) + σ * (t - A.t[idx - 1])^2 + Cᵢ, idx
end
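For reference, the local partial derivatives of the per-segment expression in `_interpolate` can be worked out by hand (this derivation is not taken from the library). Writing i = idx, Δ = t − t_{i−1}, h = t_i − t_{i−1}, and σ = (z_i − z_{i−1})/(2h), so that the returned value is f = z_{i−1}Δ + σΔ² + u_{i−1}:

```latex
\begin{aligned}
\frac{\partial f}{\partial t} &= z_{i-1} + 2\sigma\Delta, &
\frac{\partial f}{\partial u_{i-1}} &= 1, \\
\frac{\partial f}{\partial z_{i-1}} &= \Delta - \frac{\Delta^2}{2h}, &
\frac{\partial f}{\partial z_{i}} &= \frac{\Delta^2}{2h}, \\
\frac{\partial f}{\partial t_{i-1}} &= -z_{i-1} - 2\sigma\Delta + \frac{\sigma\Delta^2}{h}, &
\frac{\partial f}{\partial t_{i}} &= -\frac{\sigma\Delta^2}{h}.
\end{aligned}
```

As a sanity check, the three derivatives with respect to t, t_{i−1}, and t_i sum to zero, as expected since shifting all of them by the same amount leaves f unchanged. The derivatives with respect to z are only local: assuming `A.z` is built by the constructor from `u` and `t`, they still have to be chained through the constructor's adjoint.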

If anyone is willing to help me with at least one case (this QuadraticSpline one, for instance), I will implement the same for the other cases.
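As a possible starting point for the QuadraticSpline case, here is a minimal sketch in plain Julia of a hand-written pullback for the segment formula in `_interpolate`, checked against central finite differences. The names `quad_segment`, `quad_segment_with_pullback`, and `fd_grad` are hypothetical helpers, not part of DataInterpolations.jl; the six inputs are assumed to stand for `A.u[idx-1]`, `A.z[idx-1]`, `A.z[idx]`, `A.t[idx-1]`, `A.t[idx]`, and `t`.

```julia
# Hypothetical standalone version of the per-segment QuadraticSpline formula:
# f = z₀*(t - t₀) + σ*(t - t₀)^2 + u₀, with σ = (z₁ - z₀) / (2*(t₁ - t₀)).
# Assumed mapping: u₀ = A.u[idx-1], z₀ = A.z[idx-1], z₁ = A.z[idx],
#                  t₀ = A.t[idx-1], t₁ = A.t[idx].
function quad_segment(u₀, z₀, z₁, t₀, t₁, t)
    Δ = t - t₀
    h = t₁ - t₀
    σ = (z₁ - z₀) / (2h)
    return z₀ * Δ + σ * Δ^2 + u₀
end

# Forward value plus a closure that maps the output cotangent ȳ to the
# cotangents of (u₀, z₀, z₁, t₀, t₁, t), using the hand-derived partials.
function quad_segment_with_pullback(u₀, z₀, z₁, t₀, t₁, t)
    Δ = t - t₀
    h = t₁ - t₀
    σ = (z₁ - z₀) / (2h)
    val = z₀ * Δ + σ * Δ^2 + u₀
    function pullback(ȳ)
        ∂u₀ = ȳ
        ∂z₀ = ȳ * (Δ - Δ^2 / (2h))
        ∂z₁ = ȳ * (Δ^2 / (2h))
        ∂t₀ = ȳ * (-z₀ - 2σ * Δ + σ * Δ^2 / h)
        ∂t₁ = -ȳ * σ * Δ^2 / h
        ∂t = ȳ * (z₀ + 2σ * Δ)
        return (∂u₀, ∂z₀, ∂z₁, ∂t₀, ∂t₁, ∂t)
    end
    return val, pullback
end

# Central finite differences of f with respect to each argument, for checking.
function fd_grad(f, args; ε = 1e-6)
    g = zeros(length(args))
    for k in eachindex(args)
        ap = collect(args); ap[k] += ε
        am = collect(args); am[k] -= ε
        g[k] = (f(ap...) - f(am...)) / (2ε)
    end
    return g
end

# Arbitrary test point: u₀, z₀, z₁, t₀, t₁, t (with t₀ ≤ t ≤ t₁)
args = (0.3, 1.2, -0.7, 0.1, 0.4, 0.25)
val, back = quad_segment_with_pullback(args...)
analytic = collect(back(1.0))
numeric = fd_grad(quad_segment, args)
println("max abs difference: ", maximum(abs.(analytic .- numeric)))
```

Since `get_idx` only selects a segment, it is piecewise constant in `t` and contributes no derivative away from the knots. To have Zygote pick this up, the same pullback could be returned from a `ChainRulesCore.rrule` method for `quad_segment`, with `NoTangent()` prepended for the function itself; the cotangents for `z₀` and `z₁` would then still need to flow through the constructor's adjoint.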
Thank you in advance!
