From cc4c6de05ce614db651d8c1d76dfca89177e6a9c Mon Sep 17 00:00:00 2001 From: Tim Holy Date: Sun, 19 Aug 2018 06:58:27 -0500 Subject: [PATCH] Fix gradient for NoInterp --- src/b-splines/indexing.jl | 14 ++++++++++---- src/nointerp/nointerp.jl | 4 ++-- src/utils.jl | 4 ++++ test/gradient.jl | 13 ++++++++++--- 4 files changed, 26 insertions(+), 9 deletions(-) diff --git a/src/b-splines/indexing.jl b/src/b-splines/indexing.jl index 60805914..2a5148a7 100644 --- a/src/b-splines/indexing.jl +++ b/src/b-splines/indexing.jl @@ -174,6 +174,9 @@ function expand(coefs::AbstractArray{T,N}, vweights::Tuple{}, ixs::Tuple{}, iexp @inbounds coefs[iexpanded...] # @inbounds is safe because we checked in the original call end +const HasNoInterp{N} = NTuple{N,Tuple{Vararg{<:Union{Number,NoInterp}}}} +expand(coefs::AbstractArray, vweights::HasNoInterp, ixs::Indexes, iexpanded::Vararg{Integer,M}) where {M} = NoInterp() + # _expand1 handles the expansion of a single dimension weight list (of length L) @inline _expand1(coefs, w1, ix1, wrest, ixrest, iexpanded) = w1[1] * expand(coefs, wrest, ixrest, iexpanded..., ix1[1]) + @@ -182,14 +185,17 @@ end w1[1] * expand(coefs, wrest, ixrest, iexpanded..., ix1[1]) # Expansion of the gradient -function expand(coefs, (vweights, gweights)::Tuple{Weights{N},Weights{N}}, ixs::Indexes{N}) where N +function expand(coefs, (vweights, gweights)::Tuple{HasNoInterp{N},HasNoInterp{N}}, ixs::Indexes{N}) where N # We swap in one gradient dimension per call to expand - SVector(ntuple(d->expand(coefs, substitute(vweights, d, gweights), ixs), Val(N))) + SVector(skip_nointerp(ntuple(d->expand(coefs, substitute(vweights, d, gweights), ixs), Val(N))...)) end -function expand!(dest, coefs, (vweights, gweights)::Tuple{Weights{N},Weights{N}}, ixs::Indexes{N}) where N +function expand!(dest, coefs, (vweights, gweights)::Tuple{HasNoInterp{N},HasNoInterp{N}}, ixs::Indexes{N}) where N # We swap in one gradient dimension per call to expand + i = 0 for d = 1:N - dest[d] = expand(coefs, substitute(vweights, d, gweights), ixs) + w = substitute(vweights, d, gweights) + w isa Weights || continue # must have a NoInterp in it + dest[i+=1] = expand(coefs, w, ixs) end dest end diff --git a/src/nointerp/nointerp.jl b/src/nointerp/nointerp.jl index b9ca09f3..085f96b1 100644 --- a/src/nointerp/nointerp.jl +++ b/src/nointerp/nointerp.jl @@ -17,7 +17,7 @@ base_rem(::NoInterp, bounds, x::Number) = Int(x), 0 expand_index(::NoInterp, xi::Number, ax::AbstractUnitRange, δx) = (xi,) value_weights(::NoInterp, δx) = (oneunit(δx),) -gradient_weights(::NoInterp, δx) = (zero(δx),) -hessian_weights(::NoInterp, δx) = (zero(δx),) +gradient_weights(::NoInterp, δx) = (NoInterp(),) +hessian_weights(::NoInterp, δx) = (NoInterp(),) padded_axis(ax::AbstractUnitRange, ::NoInterp) = ax diff --git a/src/utils.jl b/src/utils.jl index 13c025aa..d7263944 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -20,3 +20,7 @@ end function substitute(default::NTuple{N,Any}, d::Integer, val) where N ntuple(i->ifelse(i==d, val, default[i]), Val(N)) end + +@inline skip_nointerp(x, rest...) = (x, skip_nointerp(rest...)...) +@inline skip_nointerp(::NoInterp, rest...) = skip_nointerp(rest...) +skip_nointerp() = () diff --git a/test/gradient.jl b/test/gradient.jl index 7f5fce9e..816929d4 100644 --- a/test/gradient.jl +++ b/test/gradient.jl @@ -106,6 +106,7 @@ using Test, Interpolations, DualNumbers, LinearAlgebra @test g[2] ≈ 2 * (4 - 1.75) ^ 2 * (3 - 1.75) A2 = rand(Float64, nx, nx) * 100 + gni = [1.0] for BC in (Flat,Line,Free,Periodic,Reflect,Natural), GT in (OnGrid, OnCell) itp_a = interpolate(A2, (BSpline(Linear()), BSpline(Quadratic(BC()))), GT()) itp_b = interpolate(A2, (BSpline(Quadratic(BC())), BSpline(Linear())), GT()) @@ -127,11 +128,17 @@ using Test, Interpolations, DualNumbers, LinearAlgebra @test epsilon(itp_b(x,yd)) ≈ gtmp[2] ix, iy = round(Int, x), round(Int, y) gtmp = Interpolations.gradient(itp_c, ix, y) - @test_broken length(gtmp) == 1 - @test_broken epsilon(itp_c(ix,yd)) ≈ gtmp[1] + @test length(gtmp) == 1 + @test epsilon(itp_c(ix,yd)) ≈ gtmp[1] + gni[1] = NaN + Interpolations.gradient!(gni, itp_c, ix, y) + @test gni[1] ≈ gtmp[1] gtmp = Interpolations.gradient(itp_d, x, iy) - @test_broken length(gtmp) == 1 + @test length(gtmp) == 1 @test epsilon(itp_d(xd,iy)) ≈ gtmp[1] + gni[1] = NaN + Interpolations.gradient!(gni, itp_d, x, iy) + @test gni[1] ≈ gtmp[1] end end