Skip to content
47 changes: 28 additions & 19 deletions src/DataInterpolationsND.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,29 +22,38 @@ the size of `u` along that dimension must match the length of `t` of the corresp
- `u`: The array to be interpolated.
"""
struct NDInterpolation{
N_in, N_out,
ID <: AbstractInterpolationDimension,
N,
N_in,
N_out,
gType <: AbstractInterpolationCache,
D,
uType <: AbstractArray
}
u::uType
interp_dims::NTuple{N_in, ID}
interp_dims::D
cache::gType
function NDInterpolation(u, interp_dims, cache)
if interp_dims isa AbstractInterpolationDimension
interp_dims = (interp_dims,)
end
N_in = length(interp_dims)
N_out = ndims(u) - N_in
function NDInterpolation(u::AbstractArray{<:Any,N}, interp_dims, cache) where N
interp_dims = _add_trailing_interp_dims(interp_dims, Val{N}())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With the new implementation @assert N_out≥0 must be checked in _add_trailing_interp_dims (not more interp dims than than the dimensionality of u).

N_in = _count_interpolating_dims(interp_dims)
N_out = _count_noninterpolating_dims(interp_dims)
@assert N_out≥0 "The number of dimensions of u must be at least the number of interpolation dimensions."
validate_size_u(interp_dims, u)
validate_cache(cache, interp_dims, u)
new{N_in, N_out, eltype(interp_dims), typeof(cache), typeof(u)}(
new{N, N_in, N_out, typeof(cache), typeof(interp_dims), typeof(u)}(
u, interp_dims, cache
)
end
end

# TODO probably not type-stable (this needs to compile away completely)
_count_interpolating_dims(interp_dims) = count(map(d -> !(d isa NoInterpolationDimension), interp_dims))
_count_noninterpolating_dims(interp_dims) = count(map(d -> d isa NoInterpolationDimension, interp_dims))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
_count_noninterpolating_dims(interp_dims) = count(map(d -> d isa NoInterpolationDimension, interp_dims))
_count_noninterpolating_dims(interp_dims::NTuple{N}} where {N} = N - _count_interpolating_dims(interp_dims)


_add_trailing_interp_dims(dim::AbstractInterpolationDimension, n) =
_add_trailing_interp_dims((dim,), n)
_add_trailing_interp_dims(dims::Tuple, ::Val{N}) where N =
(dims..., ntuple(_ -> NoInterpolationDimension(), Val{N-length(dims)}())...)

# Constructor with optional global cache
function NDInterpolation(u, interp_dims; cache = EmptyCache())
NDInterpolation(u, interp_dims, cache)
Expand All @@ -70,15 +79,15 @@ function (interp::NDInterpolation)(
end

# In place single input evaluation
function (interp::NDInterpolation{N_in})(
out::Union{Number, AbstractArray{<:Number}},
t::Tuple{Vararg{Number, N_in}};
derivative_orders::NTuple{N_in, <:Integer} = ntuple(_ -> 0, N_in)
) where {N_in}
validate_derivative_orders(derivative_orders, interp)
function (interp::NDInterpolation{N,N_in,N_out})(
out::Union{Number, AbstractArray{<:Number, N_out}},
t::Tuple{Vararg{Number, N}};
derivative_orders::NTuple{N, <:Integer} = ntuple(_ -> 0, N)
) where {N,N_in,N_out}
validate_size_u(interp, out)
validate_derivative_order(derivative_orders, interp)
idx = get_idx(interp.interp_dims, t)
@assert size(out)==size(interp.u)[(N_in + 1):end] "The size of out must match the size of the last N_out dimensions of u."
_interpolate!(out, interp, t, idx, derivative_orders, nothing)
return _interpolate!(out, interp, t, idx, derivative_orders, nothing)
end

# Out of place single input evaluation
Expand All @@ -88,7 +97,7 @@ function (interp::NDInterpolation)(t::Tuple{Vararg{Number}}; kwargs...)
end

export NDInterpolation, LinearInterpolationDimension, ConstantInterpolationDimension,
BSplineInterpolationDimension, NURBSWeights,
BSplineInterpolationDimension, NURBSWeights, NoInterpolationDimension,
eval_unstructured, eval_unstructured!, eval_grid, eval_grid!

end # module DataInterpolationsND
19 changes: 10 additions & 9 deletions src/interpolation_dimensions.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
"""
NoInterpolationDimensio

A dimension that does not perform interpolation.
"""
struct NoInterpolationDimension <: AbstractInterpolationDimension end

"""
LinearInterpolationDimension(t; t_eval = similar(t, 0))

Expand Down Expand Up @@ -157,15 +164,9 @@ function BSplineInterpolationDimension(
synchronize(backend)

idx_eval = similar(t_eval, Int)
basis_function_eval = similar(
t_eval,
typeof(inv(one(eltype(t))) * inv(one(eltype(t_eval)))),
(
length(t_eval),
degree + 1,
max_derivative_order_eval + 1
)
)
T = typeof(inv(one(eltype(t))) * inv(one(eltype(t_eval))))
s = (length(t_eval), degree + 1, max_derivative_order_eval + 1)
basis_function_eval = similar(t_eval, T, s)
itp_dim = BSplineInterpolationDimension(
t, knots_all, t_eval, idx_eval, degree, max_derivative_order_eval,
basis_function_eval, multiplicities)
Expand Down
192 changes: 82 additions & 110 deletions src/interpolation_methods.jl
Original file line number Diff line number Diff line change
@@ -1,137 +1,109 @@
function _interpolate!(
out,
A::NDInterpolation{N_in, N_out, ID},
t::Tuple{Vararg{Number, N_in}},
idx::NTuple{N_in, <:Integer},
derivative_orders::NTuple{N_in, <:Integer},
out::Union{Number, AbstractArray{<:Any, N_out}},
A::NDInterpolation{N, N_in, N_out},
ts::Tuple{Vararg{Any, N}},
idx::Tuple{Vararg{Any ,N}},
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why allow the index to be any type?

derivative_orders::Tuple{Vararg{Any, N}},
multi_point_index
) where {N_in, N_out, ID <: LinearInterpolationDimension}
out = make_zero!!(out)
any(>(1), derivative_orders) && return out

tᵢ = ntuple(i -> A.interp_dims[i].t[idx[i]], N_in)
tᵢ₊₁ = ntuple(i -> A.interp_dims[i].t[idx[i] + 1], N_in)
) where {N,N_in,N_out}
(; interp_dims, cache, u) = A

# Size of the (hyper)rectangle `t` is in
t_vol = one(eltype(tᵢ))
for (t₁, t₂) in zip(tᵢ, tᵢ₊₁)
t_vol *= t₂ - t₁
out, valid_derivative_orders = check_derivative_order(interp_dims, derivative_orders, ts, out)
valid_derivative_orders || return out
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It should be more clear that !valid_derivative_orders does not mean that the input is invalid but that it is certain that the output is 0.

if isnothing(multi_point_index)
multi_point_index = map(_ -> nothing, interp_dims)
end
# Setup
out = make_zero!!(out)
denom = zero(eltype(u))
stencils = map(stencil, interp_dims)
preparations = map(prepare, interp_dims, derivative_orders, multi_point_index, ts, idx)

# Loop over the corners of the (hyper)rectangle `t` is in
for I in Iterators.product(ntuple(i -> (false, true), N_in)...)
c = eltype(out)(inv(t_vol))
for (t_, right_point, d, t₁, t₂) in zip(t, I, derivative_orders, tᵢ, tᵢ₊₁)
c *= if right_point
iszero(d) ? t_ - t₁ : one(t_)
else
iszero(d) ? t₂ - t_ : -one(t_)
end
for I in Iterators.product(stencils...)
J = map(index, interp_dims, ts, idx, I)
weights = map(weight, interp_dims, preparations, I)
if cache isa EmptyCache
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd prefer to make the condition on the cache to be cache isa NURBSWeights for explainability; although now that I think about it given the generalizations NURBSWeights could be renamed to something like NodeWeights, PointWeights or just Weights, also renaming the NDInterpolation field from cache to weights.

product = prod(weights)
else
K = removeat(NoInterpolationDimension, J, interp_dims)
product = cache.weights[K...] * prod(weights)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The naming here is confusing. I propose to keep cache.weights and to rename the other weights to coefficients.

denom += product
end
J = (ntuple(i -> idx[i] + I[i], N_in)..., ..)
if iszero(N_out)
out += c * A.u[J...]
out += product * u[J...]
else
@. out += c * A.u[J...]
out .+= product .* view(u, J...)
end
end
return out
end

function _interpolate!(
out,
A::NDInterpolation{N_in, N_out, ID},
t::Tuple{Vararg{Number, N_in}},
idx::NTuple{N_in, <:Integer},
derivative_orders::NTuple{N_in, <:Integer},
multi_point_index
) where {N_in, N_out, ID <: ConstantInterpolationDimension}
if any(>(0), derivative_orders)
return if any(i -> !isempty(searchsorted(A.interp_dims[i].t, t[i])), 1:N_in)
typed_nan(out)
if !(cache isa EmptyCache)
if iszero(N_out)
out /= denom
else
out
out ./= denom
end
end
idx = ntuple(
i -> t[i] >= A.interp_dims[i].t[end] ? length(A.interp_dims[i].t) : idx[i], N_in)
if iszero(N_out)
out = A.u[idx...]
else
out .= A.u[idx...]
end

return out
end

# BSpline evaluation
function _interpolate!(
out,
A::NDInterpolation{N_in, N_out, ID},
t::Tuple{Vararg{Number, N_in}},
idx::NTuple{N_in, <:Integer},
derivative_orders::NTuple{N_in, <:Integer},
multi_point_index
) where {N_in, N_out, ID <: BSplineInterpolationDimension}
(; interp_dims) = A

out = make_zero!!(out)
degrees = ntuple(dim_in -> interp_dims[dim_in].degree, N_in)
basis_function_vals = get_basis_function_values_all(
A, t, idx, derivative_orders, multi_point_index
)

for I in CartesianIndices(ntuple(dim_in -> 1:(degrees[dim_in] + 1), N_in))
B_product = prod(dim_in -> basis_function_vals[dim_in][I[dim_in]], 1:N_in)
cp_index = ntuple(
dim_in -> idx[dim_in] + I[dim_in] - degrees[dim_in] - 1, N_in)
if iszero(N_out)
out += B_product * A.u[cp_index...]
function check_derivative_order(dims::Tuple, derivative_orders::Tuple, ts::Tuple, out)
itr = map(tuple, dims, derivative_orders, ts)
# Fold over itr for all dims, combining out and valid
foldl(itr; init=(out, true)) do (acc_out, acc_valid), (d, d_o, t)
dim_out, dim_valid = check_derivative_order(d, d_o, t, acc_out)
dim_out, dim_valid & acc_valid
end
end
check_derivative_order(::AbstractInterpolationDimension, d_o, t, out) = (out, true)
check_derivative_order(::LinearInterpolationDimension, d_o, t, out) = (out, d_o <= 1)
function check_derivative_order(d::ConstantInterpolationDimension, d_o, t, out)
if d_o > 0
# Check if t is on the boundary between constant steps and if so return nans
return if isempty(searchsorted(d.t, t))
(out, false)
else
out .+= B_product * view(A.u, cp_index..., ..)
(typed_nan(out), false)
end
else
(out, true)
end

return out
end

# NURBS evaluation
function _interpolate!(
out,
A::NDInterpolation{N_in, N_out, ID, <:NURBSWeights},
t::Tuple{Vararg{Number, N_in}},
idx::NTuple{N_in, <:Integer},
derivative_orders::NTuple{N_in, <:Integer},
multi_point_index
) where {N_in, N_out, ID <: BSplineInterpolationDimension}
(; interp_dims, cache) = A

out = make_zero!!(out)
degrees = ntuple(dim_in -> interp_dims[dim_in].degree, N_in)
basis_function_vals = get_basis_function_values_all(
A, t, idx, derivative_orders, multi_point_index
function prepare(d::LinearInterpolationDimension, derivative_order, multi_point_index, t, i)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you motivate the need for these preparations? I'm generally not a fan of passing around NamedTuples.

Copy link
Collaborator Author

@rafaqz rafaqz Oct 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need a way to do out of loop precalulation of arbitrary things for each dimension, and different interpolatiom methods need different fields.

We can of course use structs specific to each interpolation type, NamedTuple is just easy for iterative development of a new idea.

t₁ = d.t[i]
t₂ = d.t[i + 1]
t_vol_inv = inv(t₂ - t₁)
return (; t, t₁, t₂, t_vol_inv, derivative_order)
end
prepare(::ConstantInterpolationDimension, derivative_orders, multi_point_index, t, i) = (;)
prepare(::NoInterpolationDimension, derivative_orders, multi_point_index, t, i) = (;)
function prepare(d::BSplineInterpolationDimension, derivative_order, multi_point_index, t, i)
# TODO the dim_in arg isn't really needed, so drop it. Currently just 0
basis_function_values = get_basis_function_values(
d, t, i, derivative_order, multi_point_index
)
return (; basis_function_values)
end

denom = zero(eltype(t))

for I in CartesianIndices(ntuple(dim_in -> 1:(degrees[dim_in] + 1), N_in))
B_product = prod(dim_in -> basis_function_vals[dim_in][I[dim_in]], 1:N_in)
cp_index = ntuple(
dim_in -> idx[dim_in] + I[dim_in] - degrees[dim_in] - 1, N_in)
weight = cache.weights[cp_index...]
product = weight * B_product
denom += product
if iszero(N_out)
out += product * A.u[cp_index...]
else
out .+= product * view(A.u, cp_index..., ..)
end
end
stencil(::LinearInterpolationDimension) = (false, true)
stencil(::ConstantInterpolationDimension) = 1
stencil(::NoInterpolationDimension) = 1
stencil(d::BSplineInterpolationDimension) = 1:d.degree + 1

if iszero(N_out)
out /= denom
function weight(::LinearInterpolationDimension, prep::NamedTuple, right_point::Bool)
(; t, t₁, t₂, t_vol_inv, derivative_order) = prep
if right_point
iszero(derivative_order) ? t - t₁ : one(t)
else
out ./= denom
end

return out
iszero(derivative_order) ? t₂ - t : -one(t)
end * t_vol_inv
end
weight(::ConstantInterpolationDimension, prep::NamedTuple, i) = 1
weight(::NoInterpolationDimension, prep::NamedTuple, i) = 1
Comment on lines +102 to +103
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
weight(::ConstantInterpolationDimension, prep::NamedTuple, i) = 1
weight(::NoInterpolationDimension, prep::NamedTuple, i) = 1
weight(::ConstantInterpolationDimension, prep::NamedTuple, i) = true
weight(::NoInterpolationDimension, prep::NamedTuple, i) = true

weight(::BSplineInterpolationDimension, prep::NamedTuple, i) = prep.basis_function_values[i]

index(::LinearInterpolationDimension, t, idx, i) = idx + i
index(d::ConstantInterpolationDimension, t, idx, i) = t >= d.t[end] ? length(d.t) : idx[i]
index(::NoInterpolationDimension, t, idx, i) = Colon()
index(d::BSplineInterpolationDimension, t, idx, i) = idx + i - d.degree - 1
Loading
Loading