Skip to content

Commit

Permalink
Fix gradient for NoInterp
Browse files Browse the repository at this point in the history
  • Loading branch information
timholy committed Aug 19, 2018
1 parent b160f08 commit cc4c6de
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 9 deletions.
14 changes: 10 additions & 4 deletions src/b-splines/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,9 @@ function expand(coefs::AbstractArray{T,N}, vweights::Tuple{}, ixs::Tuple{}, iexp
@inbounds coefs[iexpanded...] # @inbounds is safe because we checked in the original call
end

const HasNoInterp{N} = NTuple{N,Tuple{Vararg{<:Union{Number,NoInterp}}}}
expand(coefs::AbstractArray, vweights::HasNoInterp, ixs::Indexes, iexpanded::Vararg{Integer,M}) where {M} = NoInterp()

# _expand1 handles the expansion of a single dimension weight list (of length L)
@inline _expand1(coefs, w1, ix1, wrest, ixrest, iexpanded) =
w1[1] * expand(coefs, wrest, ixrest, iexpanded..., ix1[1]) +
Expand All @@ -182,14 +185,17 @@ end
w1[1] * expand(coefs, wrest, ixrest, iexpanded..., ix1[1])

# Expansion of the gradient
function expand(coefs, (vweights, gweights)::Tuple{Weights{N},Weights{N}}, ixs::Indexes{N}) where N
function expand(coefs, (vweights, gweights)::Tuple{HasNoInterp{N},HasNoInterp{N}}, ixs::Indexes{N}) where N
# We swap in one gradient dimension per call to expand
SVector(ntuple(d->expand(coefs, substitute(vweights, d, gweights), ixs), Val(N)))
SVector(skip_nointerp(ntuple(d->expand(coefs, substitute(vweights, d, gweights), ixs), Val(N))...))
end
function expand!(dest, coefs, (vweights, gweights)::Tuple{Weights{N},Weights{N}}, ixs::Indexes{N}) where N
function expand!(dest, coefs, (vweights, gweights)::Tuple{HasNoInterp{N},HasNoInterp{N}}, ixs::Indexes{N}) where N
# We swap in one gradient dimension per call to expand
i = 0
for d = 1:N
dest[d] = expand(coefs, substitute(vweights, d, gweights), ixs)
w = substitute(vweights, d, gweights)
w isa Weights || continue # must have a NoInterp in it
dest[i+=1] = expand(coefs, w, ixs)
end
dest
end
Expand Down
4 changes: 2 additions & 2 deletions src/nointerp/nointerp.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ base_rem(::NoInterp, bounds, x::Number) = Int(x), 0
expand_index(::NoInterp, xi::Number, ax::AbstractUnitRange, δx) = (xi,)

value_weights(::NoInterp, δx) = (oneunit(δx),)
gradient_weights(::NoInterp, δx) = (zero(δx),)
hessian_weights(::NoInterp, δx) = (zero(δx),)
gradient_weights(::NoInterp, δx) = (NoInterp(),)
hessian_weights(::NoInterp, δx) = (NoInterp(),)

padded_axis(ax::AbstractUnitRange, ::NoInterp) = ax
4 changes: 4 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,7 @@ end
function substitute(default::NTuple{N,Any}, d::Integer, val) where N
ntuple(i->ifelse(i==d, val, default[i]), Val(N))
end

@inline skip_nointerp(x, rest...) = (x, skip_nointerp(rest...)...)
@inline skip_nointerp(::NoInterp, rest...) = skip_nointerp(rest...)
skip_nointerp() = ()
13 changes: 10 additions & 3 deletions test/gradient.jl
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ using Test, Interpolations, DualNumbers, LinearAlgebra
@test g[2] 2 * (4 - 1.75) ^ 2 * (3 - 1.75)

A2 = rand(Float64, nx, nx) * 100
gni = [1.0]
for BC in (Flat,Line,Free,Periodic,Reflect,Natural), GT in (OnGrid, OnCell)
itp_a = interpolate(A2, (BSpline(Linear()), BSpline(Quadratic(BC()))), GT())
itp_b = interpolate(A2, (BSpline(Quadratic(BC())), BSpline(Linear())), GT())
Expand All @@ -127,11 +128,17 @@ using Test, Interpolations, DualNumbers, LinearAlgebra
@test epsilon(itp_b(x,yd)) gtmp[2]
ix, iy = round(Int, x), round(Int, y)
gtmp = Interpolations.gradient(itp_c, ix, y)
@test_broken length(gtmp) == 1
@test_broken epsilon(itp_c(ix,yd)) gtmp[1]
@test length(gtmp) == 1
@test epsilon(itp_c(ix,yd)) gtmp[1]
gni[1] = NaN
Interpolations.gradient!(gni, itp_c, ix, y)
@test gni[1] gtmp[1]
gtmp = Interpolations.gradient(itp_d, x, iy)
@test_broken length(gtmp) == 1
@test length(gtmp) == 1
@test epsilon(itp_d(xd,iy)) gtmp[1]
gni[1] = NaN
Interpolations.gradient!(gni, itp_d, x, iy)
@test gni[1] gtmp[1]
end
end

Expand Down

0 comments on commit cc4c6de

Please sign in to comment.