From 48e16d7c8dfea0128428320ba2c070b8cb1d73ec Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Tue, 11 Jun 2024 21:07:51 -0700 Subject: [PATCH] Expand basis functions to operate on arbitrary dimensions --- Project.toml | 2 +- src/basis.jl | 82 ++++++++++++++++++++++++++++++++++++--------- test/layer_tests.jl | 34 +++++++++++++++++++ 3 files changed, 101 insertions(+), 17 deletions(-) diff --git a/Project.toml b/Project.toml index 37f5ea6..450e604 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Boltz" uuid = "4544d5e4-abc5-4dea-817f-29e4c205d9c8" authors = ["Avik Pal and contributors"] -version = "0.3.7" +version = "0.3.8" [deps] ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b" diff --git a/src/basis.jl b/src/basis.jl index 51c81c6..1d71e1e 100644 --- a/src/basis.jl +++ b/src/basis.jl @@ -1,8 +1,10 @@ module Basis +using ArgCheck: @argcheck using ..Boltz: _unsqueeze1 using ChainRulesCore: ChainRulesCore, NoTangent using ConcreteStructs: @concrete +using LuxDeviceUtils: get_device, LuxCPUDevice using Markdown: @doc_str const CRC = ChainRulesCore @@ -11,18 +13,38 @@ const CRC = ChainRulesCore @concrete struct GeneralBasisFunction{name} f n::Int + dim::Int end function Base.show(io::IO, basis::GeneralBasisFunction{name}) where {name} print(io, "Basis.$(name)(order=$(basis.n))") end -@inline function (basis::GeneralBasisFunction{name, F})(x::AbstractArray) where {name, F} - return basis.f.(1:(basis.n), _unsqueeze1(x)) +@inline function (basis::GeneralBasisFunction{name, F})(x::AbstractArray, + grid::Union{AbstractRange, AbstractVector}=1:1:(basis.n)) where {name, F} + @argcheck length(grid) == basis.n + if basis.dim == 1 # Fast path where we don't need to materialize the range + return basis.f.(grid, _unsqueeze1(x)) + end + + @argcheck ndims(x) + 1 ≥ basis.dim + new_x_size = ntuple( + i -> i == basis.dim ? 1 : (i < basis.dim ? size(x, i) : size(x, i - 1)), + ndims(x) + 1) + x_new = reshape(x, new_x_size) + if grid isa AbstractRange + dev = get_device(x) + grid = dev isa LuxCPUDevice ? collect(grid) : dev(grid) + end + grid_shape = ntuple(i -> i == basis.dim ? basis.n : 1, ndims(x) + 1) + grid_new = reshape(grid, grid_shape) + return basis.f.(grid_new, x_new) end +const DIM_KWARG_DOC = " - `dim::Int=1`: The dimension along which the basis functions are applied." + @doc doc""" - Chebyshev(n) + Chebyshev(n; dim::Int=1) Constructs a Chebyshev basis of the form $[T_{0}(x), T_{1}(x), \dots, T_{n-1}(x)]$ where $T_j(.)$ is the $j^{th}$ Chebyshev polynomial of the first kind. @@ -30,44 +52,64 @@ $T_j(.)$ is the $j^{th}$ Chebyshev polynomial of the first kind. ## Arguments - `n`: number of terms in the polynomial expansion. + +## Keyword Arguments + +$(DIM_KWARG_DOC) """ -Chebyshev(n) = GeneralBasisFunction{:Chebyshev}(__chebyshev, n) +Chebyshev(n; dim::Int=1) = GeneralBasisFunction{:Chebyshev}(__chebyshev, n, dim) @inline __chebyshev(i, x) = @fastmath cos(i * acos(x)) @doc doc""" - Sin(n) + Sin(n; dim::Int=1) Constructs a sine basis of the form $[\sin(x), \sin(2x), \dots, \sin(nx)]$. ## Arguments - `n`: number of terms in the sine expansion. + +## Keyword Arguments + +$(DIM_KWARG_DOC) """ -Sin(n) = GeneralBasisFunction{:Sin}(@fastmath(sin∘*), n) +Sin(n; dim::Int=1) = GeneralBasisFunction{:Sin}(@fastmath(sin∘*), n, dim) @doc doc""" - Cos(n) + Cos(n; dim::Int=1) Constructs a cosine basis of the form $[\cos(x), \cos(2x), \dots, \cos(nx)]$. ## Arguments - `n`: number of terms in the cosine expansion. + +## Keyword Arguments + +$(DIM_KWARG_DOC) """ -Cos(n) = GeneralBasisFunction{:Cos}(@fastmath(cos∘*), n) +Cos(n; dim::Int=1) = GeneralBasisFunction{:Cos}(@fastmath(cos∘*), n, dim) @doc doc""" - Fourier(n) + Fourier(n; dim=1) Constructs a Fourier basis of the form -$F_j(x) = j is even ? cos((j÷2)x) : sin((j÷2)x)$ => $[F_0(x), F_1(x), \dots, F_n(x)]$. + +$$F_j(x) = \begin{cases} + cos\left(\frac{j}{2}x\right) & \text{if } j \text{ is even} \\ + sin\left(\frac{j}{2}x\right) & \text{if } j \text{ is odd} +\end{cases}$$ ## Arguments - `n`: number of terms in the Fourier expansion. + +## Keyword Arguments + +$(DIM_KWARG_DOC) """ -Fourier(n) = GeneralBasisFunction{:Fourier}(__fourier, n) +Fourier(n; dim::Int=1) = GeneralBasisFunction{:Fourier}(__fourier, n, dim) @inline @fastmath function __fourier(i, x::AbstractFloat) s, c = sincos(i * x / 2) @@ -96,7 +138,7 @@ end end @doc doc""" - Legendre(n) + Legendre(n; dim::Int=1) Constructs a Legendre basis of the form $[P_{0}(x), P_{1}(x), \dots, P_{n-1}(x)]$ where $P_j(.)$ is the $j^{th}$ Legendre polynomial. @@ -104,8 +146,12 @@ $P_j(.)$ is the $j^{th}$ Legendre polynomial. ## Arguments - `n`: number of terms in the polynomial expansion. + +## Keyword Arguments + +$(DIM_KWARG_DOC) """ -Legendre(n) = GeneralBasisFunction{:Legendre}(__legendre_poly, n) +Legendre(n; dim::Int=1) = GeneralBasisFunction{:Legendre}(__legendre_poly, n, dim) ## Source: https://github.com/ranocha/PolynomialBases.jl/blob/master/src/legendre.jl @inline function __legendre_poly(i, x) @@ -124,15 +170,19 @@ Legendre(n) = GeneralBasisFunction{:Legendre}(__legendre_poly, n) end @doc doc""" - Polynomial(n) + Polynomial(n; dim::Int=1) -Constructs a Polynomial basis of the form $[1, x, \dots, x^(n-1)]$. +Constructs a Polynomial basis of the form $[1, x, \dots, x^{(n-1)}]$. ## Arguments - `n`: number of terms in the polynomial expansion. + +## Keyword Arguments + +$(DIM_KWARG_DOC) """ -Polynomial(n) = GeneralBasisFunction{:Polynomial}(__polynomial, n) +Polynomial(n; dim::Int=1) = GeneralBasisFunction{:Polynomial}(__polynomial, n, dim) @inline __polynomial(i, x) = x^(i - 1) diff --git a/test/layer_tests.jl b/test/layer_tests.jl index dce3a23..c5f8863 100644 --- a/test/layer_tests.jl +++ b/test/layer_tests.jl @@ -108,3 +108,37 @@ end end end end + +@testitem "Basis Functions" setup=[SharedTestSetup] tags=[:layers] begin + @testset "$(mode)" for (mode, aType, dev, ongpu) in MODES + @testset "$(basis)" for basis in (Basis.Chebyshev, Basis.Sin, Basis.Cos, + Basis.Fourier, Basis.Legendre, Basis.Polynomial) + x = tanh.(randn(Float32, 2, 4)) |> aType + grid = collect(1:3) |> aType + + fn = basis(3) + @test size(fn(x)) == (3, 2, 4) + @jet fn(x) + @test size(fn(x, grid)) == (3, 2, 4) + @jet fn(x, grid) + + fn = basis(3; dim=2) + @test size(fn(x)) == (2, 3, 4) + @jet fn(x) + @test size(fn(x, grid)) == (2, 3, 4) + @jet fn(x, grid) + + fn = basis(3; dim=3) + @test size(fn(x)) == (2, 4, 3) + @jet fn(x) + @test size(fn(x, grid)) == (2, 4, 3) + @jet fn(x, grid) + + fn = basis(3; dim=4) + @test_throws ArgumentError fn(x) + + grid = 1:5 |> aType + @test_throws ArgumentError fn(x, grid) + end + end +end