Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@ EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"

[extensions]
DataInterpolationsNDSymbolicsExt = "Symbolics"

[compat]
Adapt = "4.3.0"
Aqua = "0.8"
Expand All @@ -18,6 +21,7 @@ KernelAbstractions = "0.9.34"
Random = "1"
RecipesBase = "1.3.4"
SafeTestsets = "0.1"
Symbolics = "5.29"
Test = "1"
julia = "1"

Expand All @@ -28,7 +32,8 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "DataInterpolations", "ForwardDiff", "Pkg", "Random", "SafeTestsets", "Test"]
test = ["Aqua", "DataInterpolations", "ForwardDiff", "Pkg", "Random", "SafeTestsets", "Symbolics", "Test"]
100 changes: 100 additions & 0 deletions ext/DataInterpolationsNDSymbolicsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
module DataInterpolationsNDSymbolicsExt

using DataInterpolationsND: NDInterpolation
using Symbolics
using Symbolics: Num, unwrap, SymbolicUtils

# Register just one symbolic function - the promote_symtype is handled by the macro
@register_symbolic (interp::NDInterpolation)(t::Real)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
@register_symbolic (interp::NDInterpolation)(t::Real)
@register_symbolic (interp::NDInterpolation)(t)

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@SouthEndMusic is a one-arg call also supported here? I would presume so?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The interp can be called with a tuple of numbers if that's what you mean


Base.nameof(interp::NDInterpolation) = :NDInterpolation

# Add method to handle multiple arguments symbolically
function (interp::NDInterpolation)(args::Vararg{Num})
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Easiest to only support the all Num case, can Union{Number,Num} but then all other dispatches need to ::Number.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It also needs to define ::Vararg{BasicSymbolic{<:Real}}

unwrapped_args = unwrap.(args)
Symbolics.wrap(SymbolicUtils.term(interp, unwrapped_args...))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This also isn't particularly good, it needs to pass the type kwarg to term to make sure the symtype is correct.

end

# Handle direct differentiation of interpolation objects with respect to individual arguments
function Symbolics.derivative(interp::NDInterpolation, args::NTuple{N, Any}, ::Val{I}) where {N, I}
# Create a symbolic term representing the partial derivative
# The I-th argument gets differentiated (1-indexed)
derivative_orders = ntuple(j -> j == I ? 1 : 0, N)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This variable is unused.


# Create a symbolic function call that represents this partial derivative
# We'll use a custom function name to distinguish it from the base interpolation
symbolic_args = Symbolics.wrap.(args)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, this needs to unwrap not wrap.

Symbolics.unwrap(
SymbolicUtils.term(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again with passing the type kwarg.

PartialDerivative{I}(interp),
unwrap.(symbolic_args)...
)
)
end

# Define a partial derivative wrapper type to carry the differentiation information
struct PartialDerivative{I}
interp::NDInterpolation
end

# Make the partial derivative callable
function (pd::PartialDerivative{I})(args...) where {I}
derivative_orders = ntuple(j -> j == I ? 1 : 0, length(args))
pd.interp(args...; derivative_orders = derivative_orders)
end

# Promote symtype for partial derivatives
SymbolicUtils.promote_symtype(::PartialDerivative, _...) = Real

# Name the partial derivative functions appropriately
Base.nameof(pd::PartialDerivative{I}) where {I} = Symbol("∂$(I)_NDInterpolation")

# Handle higher-order derivatives by chaining partial derivatives
function Symbolics.derivative(pd::PartialDerivative{J}, args::NTuple{N, Any}, ::Val{I}) where {J, N, I}
# Create a new partial derivative that represents higher-order differentiation
new_pd = MixedPartialDerivative(pd.interp, (J, I))
symbolic_args = Symbolics.wrap.(args)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No need to wrap here, the unwrap in the subsequent line is fine.

Symbolics.unwrap(
SymbolicUtils.term(
new_pd,
unwrap.(symbolic_args)...
)
)
end

# Define mixed partial derivatives for higher-order cases
struct MixedPartialDerivative
interp::NDInterpolation
orders::Tuple{Vararg{Int}}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This struct is type-unstable. Instead of storing the orders like this and counting them when called, it should just store the NTuple derivative_orders as defined in the call and make sure the type is parametric.

end

# Make mixed partial derivatives callable
function (mpd::MixedPartialDerivative)(args...)
derivative_orders = ntuple(length(args)) do j
count(==(j), mpd.orders)
end
mpd.interp(args...; derivative_orders = derivative_orders)
end

# Promote symtype for mixed partial derivatives
SymbolicUtils.promote_symtype(::MixedPartialDerivative, _...) = Real

# Name mixed partial derivatives
function Base.nameof(mpd::MixedPartialDerivative)
orders_str = join(mpd.orders, "_")
Symbol("∂$(orders_str)_NDInterpolation")
end

# Handle further differentiation of mixed partial derivatives
function Symbolics.derivative(mpd::MixedPartialDerivative, args::NTuple{N, Any}, ::Val{I}) where {N, I}
new_mpd = MixedPartialDerivative(mpd.interp, (mpd.orders..., I))
symbolic_args = Symbolics.wrap.(args)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Again, this should not wrap.

Symbolics.unwrap(
SymbolicUtils.term(
new_mpd,
unwrap.(symbolic_args)...
)
)
end

end # module
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ if GROUP == "All" || GROUP == "Core"
@safetestset "Interpolations" include("test_interpolations.jl")
@safetestset "Derivatives" include("test_derivatives.jl")
@safetestset "DataInterpolations" include("test_datainterpolations_comparison.jl")
elseif GROUP == "Extensions"
@safetestset "Symbolics Extension" include("test_symbolics_ext.jl")
elseif GROUP == "QA"
@safetestset "Aqua" include("aqua.jl")
elseif GROUP == "GPU"
Expand Down
39 changes: 39 additions & 0 deletions test/test_symbolics_ext.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
using DataInterpolationsND
using Symbolics
using SafeTestsets

@safetestset "Symbolics Extension" begin
using Test

# Create a simple 2D interpolation
t1 = [1.0, 2.0, 3.0]
t2 = [0.0, 1.0, 2.0]
u = [i + j for i in t1, j in t2] # 3x3 matrix

itp_dims = (
LinearInterpolationDimension(t1),
LinearInterpolationDimension(t2)
)
itp = NDInterpolation(u, itp_dims)

# Test symbolic variables
@variables x y

# Test symbolic evaluation
result = itp(x, y)
@test result isa Symbolics.Num

# Test symbolic differentiation
∂f_∂x = Symbolics.derivative(result, x)
∂f_∂y = Symbolics.derivative(result, y)

@test ∂f_∂x isa Symbolics.Num
@test ∂f_∂y isa Symbolics.Num

# Test that we can substitute values
substituted = Symbolics.substitute(result, Dict(x => 1.5, y => 0.5))

# Compare with numerical evaluation
numerical_result = itp(1.5, 0.5)
@test Float64(substituted) ≈ numerical_result
end
Loading