Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"

[weakdeps]
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"

[extensions]
DataInterpolationsNDSymbolicsExt = "Symbolics"

[compat]
Adapt = "4.3.0"
Aqua = "0.8"
Expand All @@ -18,6 +24,7 @@ KernelAbstractions = "0.9.34"
Random = "1"
RecipesBase = "1.3.4"
SafeTestsets = "0.1"
Symbolics = "6.57"
Test = "1"
julia = "1"

Expand All @@ -28,7 +35,9 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
SymbolicUtils = "d1185830-fcd6-423d-90d6-eec64667417b"
Symbolics = "0c5d862f-8b57-4792-8d23-62f2024744c7"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Aqua", "DataInterpolations", "ForwardDiff", "Pkg", "Random", "SafeTestsets", "Test"]
test = ["Aqua", "DataInterpolations", "ForwardDiff", "Pkg", "Random", "SafeTestsets", "Symbolics", "Test", "SymbolicUtils"]
85 changes: 85 additions & 0 deletions ext/DataInterpolationsNDSymbolicsExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
module DataInterpolationsNDSymbolicsExt

import DataInterpolationsND
using DataInterpolationsND: NDInterpolation
using Symbolics
using Symbolics: Num, unwrap, SymbolicUtils, Symbolic

struct DifferentiatedNDInterpolation{N_in, N_out, I <: NDInterpolation{N_in, N_out}}
interp::I
derivative_orders::NTuple{N_in, Int}
end

function (interp::DifferentiatedNDInterpolation)(args...)
return interp.interp(args; derivative_orders = interp.derivative_orders)
end

Base.nameof(::NDInterpolation) = :NDInterpolation
Base.nameof(::DifferentiatedNDInterpolation) = :DifferentiatedNDInterpolation

for symT in [Num, Symbolic{<:Real}]
@eval function (interp::NDInterpolation{N_in, N_out})(t::Vararg{
$symT, N_in}) where {N_in, N_out}
if $(symT === Num)
t = unwrap.(t)
end
res = if N_out == 0
SymbolicUtils.term(interp, t...; type = Real)
else
Symbolics.array_term(
interp, t...; eltype = Real, container_type = Array, ndims = N_out,
size = DataInterpolationsND.get_output_size(interp))
end
if $(symT === Num)
if N_out == 0
res = Num(res)
else
res = Symbolics.Arr{Num, N_out}(res)
end
end
return res
end
@eval function (interp::DifferentiatedNDInterpolation{N_in, N_out})(t::Vararg{
$symT, N_in}) where {N_in, N_out}
if $(symT === Num)
t = unwrap.(t)
end
res = if N_out == 0
SymbolicUtils.term(interp, t...; type = Real)
else
Symbolics.array_term(
interp, t...; eltype = Real, container_type = Array, ndims = N_out,
size = DataInterpolationsND.get_output_size(interp.interp))
end
if $(symT === Num)
if N_out == 0
res = Num(res)
else
res = Symbolics.Arr{Num, N_out}(res)
end
end
return res
end
end
function SymbolicUtils.promote_symtype(::NDInterpolation{N_in, N_out}, ::Vararg) where {
N_in, N_out}
N_out == 0 ? Real : Array{Real, N_out}
end

function Symbolics.derivative(interp::NDInterpolation{N_in, N_out},
args::NTuple{N_in, Any}, ::Val{I}) where {N_in, N_out, I}
@assert I <= N_in
orders = ntuple(Int ∘ isequal(I), Val{N_in}())
dinterp = DifferentiatedNDInterpolation{N_in, N_out, typeof(interp)}(interp, orders)
return dinterp(args...)
end

function Symbolics.derivative(interp::DifferentiatedNDInterpolation{N_in, N_out},
args::NTuple{N_in, Any}, ::Val{I}) where {N_in, N_out, I}
@assert I <= N_in
orders_offset = ntuple(Int ∘ isequal(I), Val{N_in}())
orders = interp.derivative_orders .+ orders_offset
return typeof(interp)(interp.interp, orders)(args...)
end

end # module
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ if GROUP == "All" || GROUP == "Core"
@safetestset "Interpolations" include("test_interpolations.jl")
@safetestset "Derivatives" include("test_derivatives.jl")
@safetestset "DataInterpolations" include("test_datainterpolations_comparison.jl")
elseif GROUP == "Extensions"
@safetestset "Symbolics Extension" include("test_symbolics_ext.jl")
elseif GROUP == "QA"
@safetestset "Aqua" include("aqua.jl")
elseif GROUP == "GPU"
Expand Down
58 changes: 58 additions & 0 deletions test/test_symbolics_ext.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
using DataInterpolationsND
using Symbolics
import SymbolicUtils as SU
using Symbolics: unwrap
using Test

t1 = cumsum(rand(5))
t2 = cumsum(rand(7))

interpolation_dimensions = (
LinearInterpolationDimension(t1),
LinearInterpolationDimension(t2)
)

u = rand(5, 7, 2)

interp = NDInterpolation(u, interpolation_dimensions)
@variables x y

@testset "Basics" begin
ex = interp(x, y)
@test ex isa Symbolics.Arr
@test size(ex) == (2,)
@test SU.symtype(unwrap(ex)) == Vector{Real}

res = eval(quote
let x = 0.4, y = 0.8
$(SU.Code.toexpr(ex))
end
end)
@test res interp(0.4, 0.8)

ex = interp(unwrap(x), unwrap(y))
@test ex isa SU.BasicSymbolic{Vector{Real}}
end

@testset "Differentiation" begin
ex = interp(x, y)
der = Symbolics.derivative(ex[1], x)
@test size(der) == ()
@test SU.symtype(unwrap(der)) == Real
res = eval(quote
let x = 0.4, y = 0.8
$(SU.Code.toexpr(der))
end
end)
@test res interp(0.4, 0.8; derivative_orders = (1, 0))[1]

der = Symbolics.derivative(ex[1], y)
@test size(der) == ()
@test SU.symtype(unwrap(der)) == Real
res = eval(quote
let x = 0.4, y = 0.8
$(SU.Code.toexpr(der))
end
end)
@test res interp(0.4, 0.8; derivative_orders = (0, 1))[1]
end
Loading