-
-
Couldn't load subscription status.
- Fork 6
Add MTK support via Symbolics.jl extension #30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,100 @@ | ||
| module DataInterpolationsNDSymbolicsExt | ||
|
|
||
| using DataInterpolationsND: NDInterpolation | ||
| using Symbolics | ||
| using Symbolics: Num, unwrap, SymbolicUtils | ||
|
|
||
| # Register just one symbolic function - the promote_symtype is handled by the macro | ||
| @register_symbolic (interp::NDInterpolation)(t::Real) | ||
|
|
||
| Base.nameof(interp::NDInterpolation) = :NDInterpolation | ||
|
|
||
| # Add method to handle multiple arguments symbolically | ||
| function (interp::NDInterpolation)(args::Vararg{Num}) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Easiest to only support the all Num case, can There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It also needs to define |
||
| unwrapped_args = unwrap.(args) | ||
| Symbolics.wrap(SymbolicUtils.term(interp, unwrapped_args...)) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This also isn't particularly good, it needs to pass the |
||
| end | ||
|
|
||
| # Handle direct differentiation of interpolation objects with respect to individual arguments | ||
| function Symbolics.derivative(interp::NDInterpolation, args::NTuple{N, Any}, ::Val{I}) where {N, I} | ||
| # Create a symbolic term representing the partial derivative | ||
| # The I-th argument gets differentiated (1-indexed) | ||
| derivative_orders = ntuple(j -> j == I ? 1 : 0, N) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This variable is unused. |
||
|
|
||
| # Create a symbolic function call that represents this partial derivative | ||
| # We'll use a custom function name to distinguish it from the base interpolation | ||
| symbolic_args = Symbolics.wrap.(args) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No, this needs to unwrap not wrap. |
||
| Symbolics.unwrap( | ||
| SymbolicUtils.term( | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again with passing the |
||
| PartialDerivative{I}(interp), | ||
| unwrap.(symbolic_args)... | ||
| ) | ||
| ) | ||
| end | ||
|
|
||
| # Define a partial derivative wrapper type to carry the differentiation information | ||
| struct PartialDerivative{I} | ||
| interp::NDInterpolation | ||
| end | ||
|
|
||
| # Make the partial derivative callable | ||
| function (pd::PartialDerivative{I})(args...) where {I} | ||
| derivative_orders = ntuple(j -> j == I ? 1 : 0, length(args)) | ||
| pd.interp(args...; derivative_orders = derivative_orders) | ||
| end | ||
|
|
||
| # Promote symtype for partial derivatives | ||
| SymbolicUtils.promote_symtype(::PartialDerivative, _...) = Real | ||
|
|
||
| # Name the partial derivative functions appropriately | ||
| Base.nameof(pd::PartialDerivative{I}) where {I} = Symbol("∂$(I)_NDInterpolation") | ||
|
|
||
| # Handle higher-order derivatives by chaining partial derivatives | ||
| function Symbolics.derivative(pd::PartialDerivative{J}, args::NTuple{N, Any}, ::Val{I}) where {J, N, I} | ||
| # Create a new partial derivative that represents higher-order differentiation | ||
| new_pd = MixedPartialDerivative(pd.interp, (J, I)) | ||
| symbolic_args = Symbolics.wrap.(args) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. No need to wrap here, the unwrap in the subsequent line is fine. |
||
| Symbolics.unwrap( | ||
| SymbolicUtils.term( | ||
| new_pd, | ||
| unwrap.(symbolic_args)... | ||
| ) | ||
| ) | ||
| end | ||
|
|
||
| # Define mixed partial derivatives for higher-order cases | ||
| struct MixedPartialDerivative | ||
| interp::NDInterpolation | ||
| orders::Tuple{Vararg{Int}} | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This struct is type-unstable. Instead of storing the orders like this and counting them when called, it should just store the NTuple |
||
| end | ||
|
|
||
| # Make mixed partial derivatives callable | ||
| function (mpd::MixedPartialDerivative)(args...) | ||
| derivative_orders = ntuple(length(args)) do j | ||
| count(==(j), mpd.orders) | ||
| end | ||
| mpd.interp(args...; derivative_orders = derivative_orders) | ||
| end | ||
|
|
||
| # Promote symtype for mixed partial derivatives | ||
| SymbolicUtils.promote_symtype(::MixedPartialDerivative, _...) = Real | ||
|
|
||
| # Name mixed partial derivatives | ||
| function Base.nameof(mpd::MixedPartialDerivative) | ||
| orders_str = join(mpd.orders, "_") | ||
| Symbol("∂$(orders_str)_NDInterpolation") | ||
| end | ||
|
|
||
| # Handle further differentiation of mixed partial derivatives | ||
| function Symbolics.derivative(mpd::MixedPartialDerivative, args::NTuple{N, Any}, ::Val{I}) where {N, I} | ||
| new_mpd = MixedPartialDerivative(mpd.interp, (mpd.orders..., I)) | ||
| symbolic_args = Symbolics.wrap.(args) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Again, this should not |
||
| Symbolics.unwrap( | ||
| SymbolicUtils.term( | ||
| new_mpd, | ||
| unwrap.(symbolic_args)... | ||
| ) | ||
| ) | ||
| end | ||
|
|
||
| end # module | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| using DataInterpolationsND | ||
| using Symbolics | ||
| using SafeTestsets | ||
|
|
||
| @safetestset "Symbolics Extension" begin | ||
| using Test | ||
|
|
||
| # Create a simple 2D interpolation | ||
| t1 = [1.0, 2.0, 3.0] | ||
| t2 = [0.0, 1.0, 2.0] | ||
| u = [i + j for i in t1, j in t2] # 3x3 matrix | ||
|
|
||
| itp_dims = ( | ||
| LinearInterpolationDimension(t1), | ||
| LinearInterpolationDimension(t2) | ||
| ) | ||
| itp = NDInterpolation(u, itp_dims) | ||
|
|
||
| # Test symbolic variables | ||
| @variables x y | ||
|
|
||
| # Test symbolic evaluation | ||
| result = itp(x, y) | ||
| @test result isa Symbolics.Num | ||
|
|
||
| # Test symbolic differentiation | ||
| ∂f_∂x = Symbolics.derivative(result, x) | ||
| ∂f_∂y = Symbolics.derivative(result, y) | ||
|
|
||
| @test ∂f_∂x isa Symbolics.Num | ||
| @test ∂f_∂y isa Symbolics.Num | ||
|
|
||
| # Test that we can substitute values | ||
| substituted = Symbolics.substitute(result, Dict(x => 1.5, y => 0.5)) | ||
|
|
||
| # Compare with numerical evaluation | ||
| numerical_result = itp(1.5, 0.5) | ||
| @test Float64(substituted) ≈ numerical_result | ||
| end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@SouthEndMusic is a one-arg call also supported here? I would presume so?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
interpcan be called with a tuple of numbers if that's what you mean