diff --git a/Project.toml b/Project.toml index 450e604..7f1394d 100644 --- a/Project.toml +++ b/Project.toml @@ -24,11 +24,13 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" WeightInitializers = "d49dbf32-c5c2-4618-8acc-27bb2598ef2d" [weakdeps] +DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Metalhead = "dbeba491-748d-5e0e-a39e-b530a07fa0cc" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [extensions] +BoltzDataInterpolationsExt = "DataInterpolations" BoltzForwardDiffExt = "ForwardDiff" BoltzMetalheadExt = "Metalhead" BoltzZygoteExt = "Zygote" @@ -41,6 +43,7 @@ Artifacts = "1.10" ChainRulesCore = "1.24" ComponentArrays = "0.15.13" ConcreteStructs = "0.2.3" +DataInterpolations = "5.2.0" ExplicitImports = "1.5" ForwardDiff = "0.10.36" GPUArraysCore = "0.1.6" @@ -70,6 +73,7 @@ julia = "1.10" [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +DataInterpolations = "82cc6244-b520-54b8-b5a6-8a565e85f1d0" ExplicitImports = "7d51a73a-1435-4ff3-83d9-f097790105c7" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" LuxAMDGPU = "83120cb1-ca15-4f04-bf3b-6967d2e6b60b" @@ -83,4 +87,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "ComponentArrays", "ExplicitImports", "ForwardDiff", "LuxAMDGPU", "LuxCUDA", "LuxLib", "LuxTestUtils", "Metalhead", "Pkg", "ReTestItems", "Test", "Zygote"] +test = ["Aqua", "ComponentArrays", "DataInterpolations", "ExplicitImports", "ForwardDiff", "LuxAMDGPU", "LuxCUDA", "LuxLib", "LuxTestUtils", "Metalhead", "Pkg", "ReTestItems", "Test", "Zygote"] diff --git a/ext/BoltzDataInterpolationsExt.jl b/ext/BoltzDataInterpolationsExt.jl new file mode 100644 index 0000000..12be628 --- /dev/null +++ b/ext/BoltzDataInterpolationsExt.jl @@ -0,0 +1,32 @@ +module BoltzDataInterpolationsExt + +using Boltz: Boltz, Layers +using DataInterpolations: AbstractInterpolation + +for train_grid in (true, false) + grid_expr = train_grid ? :(grid = ps.grid) : :(grid = st.grid) + @eval function (spl::Layers.SplineLayer{$(train_grid), Basis})( + t::AbstractVector, ps, st) where {Basis <: AbstractInterpolation} + $(grid_expr) + interp = __construct_basis(Basis, ps.saved_points, grid; extrapolate=true) + sol = interp.(t) + spl.in_dims == () && return sol, st + return Boltz._stack(sol), st + end +end + +@inline function __construct_basis( + ::Type{Basis}, saved_points::AbstractVector, grid; extrapolate=false) where {Basis} + return Basis(saved_points, grid; extrapolate) +end + +@inline function __construct_basis(::Type{Basis}, saved_points::AbstractArray{T, N}, + grid; extrapolate=false) where {Basis, T, N} + return __construct_basis( + # Unfortunately DataInterpolations.jl is not very robust to different array types + # so we have to make a copy + Basis, [copy(selectdim(saved_points, N, i)) for i in 1:size(saved_points, N)], + grid; extrapolate) +end + +end diff --git a/src/layers/Layers.jl b/src/layers/Layers.jl index fda83df..37af8ac 100644 --- a/src/layers/Layers.jl +++ b/src/layers/Layers.jl @@ -5,7 +5,7 @@ using PrecompileTools: @recompile_invalidations @recompile_invalidations begin using ArgCheck: @argcheck using ADTypes: AutoForwardDiff, AutoZygote - using ..Boltz: Boltz, _fast_chunk, _should_type_assert, _stack + using ..Boltz: Boltz, _fast_chunk, _should_type_assert, _stack, __unwrap_val using ConcreteStructs: @concrete using ChainRulesCore: ChainRulesCore using Lux: Lux, StatefulLuxLayer diff --git a/src/layers/spline.jl b/src/layers/spline.jl index 8b13789..a39505a 100644 --- a/src/layers/spline.jl +++ b/src/layers/spline.jl @@ -1 +1,69 @@ +""" + SplineLayer(in_dims, grid_min, grid_max, grid_step, basis::Type{Basis}; + train_grid::Union{Val, Bool}=Val(false), init_saved_points=nothing) +Constructs a spline layer with the given basis function. + +## Arguments + + - `in_dims`: input dimensions of the layer. This must be a tuple of integers, to construct + a flat vector of saved_points pass in `()`. + + - `grid_min`: minimum value of the grid. + - `grid_max`: maximum value of the grid. + - `grid_step`: step size of the grid. + - `basis`: basis function to use for the interpolation. Currently only the basis functions + from DataInterpolations.jl are supported: + + 1. `ConstantInterpolation` + 2. `LinearInterpolation` + 3. `QuadraticInterpolation` + 4. `QuadraticSpline` + 5. `CubicSpline` + +## Keyword Arguments + + - `train_grid`: whether to train the grid or not. + - `init_saved_points`: values of the function at multiples of the time step. Initialized + by default to a random vector sampled from the unit normal. Alternatively, can take a + function with the signature + `init_saved_points(rng, in_dims, grid_min, grid_max, grid_step)`. + +!!! warning + + Currently this layer is limited since it relies on DataInterpolations.jl which doesn't + work with GPU arrays. This will be fixed in the future by extending support to different + basis functions +""" +@concrete struct SplineLayer{TG, B, T} <: AbstractExplicitLayer + grid_min::T + grid_max::T + grid_step::T + basis + in_dims + init_saved_points +end + +function SplineLayer(in_dims::Dims, grid_min, grid_max, grid_step, basis::Type{Basis}; + train_grid::Union{Val, Bool}=Val(false), init_saved_points=nothing) where {Basis} + return SplineLayer{__unwrap_val(train_grid), Basis}( + grid_min, grid_max, grid_step, basis, in_dims, init_saved_points) +end + +function LuxCore.initialparameters( + rng::AbstractRNG, layer::SplineLayer{TG, B, T}) where {TG, B, T} + if layer.init_saved_points === nothing + saved_points = rand(rng, T, layer.in_dims..., + length((layer.grid_min):(layer.grid_step):(layer.grid_max))) + else + saved_points = layer.init_saved_points( + rng, in_dims, layer.grid_min, layer.grid_max, layer.grid_step) + end + TG || return (; saved_points) + return (; + saved_points, grid=collect((layer.grid_min):(layer.grid_step):(layer.grid_max))) +end + +function LuxCore.initialstates(::AbstractRNG, layer::SplineLayer{false}) + return (; grid=collect((layer.grid_min):(layer.grid_step):(layer.grid_max)),) +end diff --git a/test/layer_tests.jl b/test/layer_tests.jl index c5f8863..b29017d 100644 --- a/test/layer_tests.jl +++ b/test/layer_tests.jl @@ -142,3 +142,54 @@ end end end end + +@testitem "Spline Layer" setup=[SharedTestSetup] tags=[:layers] begin + using ComponentArrays, DataInterpolations, ForwardDiff, Zygote + + @testset "$(mode)" for (mode, aType, dev, ongpu) in MODES + ongpu && continue + + @testset "$(spl): train_grid $(train_grid), dims $(dims)" for spl in ( + ConstantInterpolation, LinearInterpolation, + QuadraticInterpolation, QuadraticSpline, CubicSpline), + train_grid in (true, false), + dims in ((), (8,)) + + spline = Layers.SplineLayer(dims, 0.0f0, 1.0f0, 0.1f0, spl; train_grid) + ps, st = Lux.setup(Xoshiro(0), spline) |> dev + ps_ca = ComponentArray(ps |> cpu_device()) |> dev + + x = tanh.(randn(Float32, 4)) |> aType + + y, st = spline(x, ps, st) + @test size(y) == (dims..., 4) + + opt_broken = !ongpu && dims != () && spl !== ConstantInterpolation + + @jet spline(x, ps, st) opt_broken=opt_broken # See SciML/DataInterpolations.jl/issues/267 + + y, st = spline(x, ps_ca, st) + @test size(y) == (dims..., 4) + + @jet spline(x, ps_ca, st) opt_broken=opt_broken # See SciML/DataInterpolations.jl/issues/267 + + ∂x, ∂ps = Zygote.gradient((x, ps) -> sum(abs2, first(spline(x, ps, st))), x, ps) + spl !== ConstantInterpolation && @test ∂x !== nothing + @test ∂ps !== nothing + + ∂x_fd = ForwardDiff.gradient(x -> sum(abs2, first(spline(x, ps, st))), x) + ∂ps_fd = ForwardDiff.gradient(ps -> sum(abs2, first(spline(x, ps, st))), ps_ca) + + spl !== ConstantInterpolation && @test ∂x≈∂x_fd atol=1e-3 rtol=1e-3 + + @test ∂ps.saved_points≈∂ps_fd.saved_points atol=1e-3 rtol=1e-3 + if train_grid + if ∂ps.grid === nothing + @test all(Base.Fix1(isapprox, 0), ∂ps_fd.grid) + else + @test ∂ps.grid≈∂ps_fd.grid atol=1e-3 rtol=1e-3 + end + end + end + end +end