-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #35 from LuxDL/ap/spline_layer
Add Spline Layer
- Loading branch information
Showing
5 changed files
with
157 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
module BoltzDataInterpolationsExt | ||
|
||
using Boltz: Boltz, Layers | ||
using DataInterpolations: AbstractInterpolation | ||
|
||
for train_grid in (true, false) | ||
grid_expr = train_grid ? :(grid = ps.grid) : :(grid = st.grid) | ||
@eval function (spl::Layers.SplineLayer{$(train_grid), Basis})( | ||
t::AbstractVector, ps, st) where {Basis <: AbstractInterpolation} | ||
$(grid_expr) | ||
interp = __construct_basis(Basis, ps.saved_points, grid; extrapolate=true) | ||
sol = interp.(t) | ||
spl.in_dims == () && return sol, st | ||
return Boltz._stack(sol), st | ||
end | ||
end | ||
|
||
@inline function __construct_basis( | ||
::Type{Basis}, saved_points::AbstractVector, grid; extrapolate=false) where {Basis} | ||
return Basis(saved_points, grid; extrapolate) | ||
end | ||
|
||
@inline function __construct_basis(::Type{Basis}, saved_points::AbstractArray{T, N}, | ||
grid; extrapolate=false) where {Basis, T, N} | ||
return __construct_basis( | ||
# Unfortunately DataInterpolations.jl is not very robust to different array types | ||
# so we have to make a copy | ||
Basis, [copy(selectdim(saved_points, N, i)) for i in 1:size(saved_points, N)], | ||
grid; extrapolate) | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,69 @@ | ||
""" | ||
SplineLayer(in_dims, grid_min, grid_max, grid_step, basis::Type{Basis}; | ||
train_grid::Union{Val, Bool}=Val(false), init_saved_points=nothing) | ||
Constructs a spline layer with the given basis function. | ||
## Arguments | ||
- `in_dims`: input dimensions of the layer. This must be a tuple of integers, to construct | ||
a flat vector of saved_points pass in `()`. | ||
- `grid_min`: minimum value of the grid. | ||
- `grid_max`: maximum value of the grid. | ||
- `grid_step`: step size of the grid. | ||
- `basis`: basis function to use for the interpolation. Currently only the basis functions | ||
from DataInterpolations.jl are supported: | ||
1. `ConstantInterpolation` | ||
2. `LinearInterpolation` | ||
3. `QuadraticInterpolation` | ||
4. `QuadraticSpline` | ||
5. `CubicSpline` | ||
## Keyword Arguments | ||
- `train_grid`: whether to train the grid or not. | ||
- `init_saved_points`: values of the function at multiples of the time step. Initialized | ||
by default to a random vector sampled from the unit normal. Alternatively, can take a | ||
function with the signature | ||
`init_saved_points(rng, in_dims, grid_min, grid_max, grid_step)`. | ||
!!! warning | ||
Currently this layer is limited since it relies on DataInterpolations.jl which doesn't | ||
work with GPU arrays. This will be fixed in the future by extending support to different | ||
basis functions | ||
""" | ||
@concrete struct SplineLayer{TG, B, T} <: AbstractExplicitLayer | ||
grid_min::T | ||
grid_max::T | ||
grid_step::T | ||
basis | ||
in_dims | ||
init_saved_points | ||
end | ||
|
||
function SplineLayer(in_dims::Dims, grid_min, grid_max, grid_step, basis::Type{Basis}; | ||
train_grid::Union{Val, Bool}=Val(false), init_saved_points=nothing) where {Basis} | ||
return SplineLayer{__unwrap_val(train_grid), Basis}( | ||
grid_min, grid_max, grid_step, basis, in_dims, init_saved_points) | ||
end | ||
|
||
function LuxCore.initialparameters( | ||
rng::AbstractRNG, layer::SplineLayer{TG, B, T}) where {TG, B, T} | ||
if layer.init_saved_points === nothing | ||
saved_points = rand(rng, T, layer.in_dims..., | ||
length((layer.grid_min):(layer.grid_step):(layer.grid_max))) | ||
else | ||
saved_points = layer.init_saved_points( | ||
rng, in_dims, layer.grid_min, layer.grid_max, layer.grid_step) | ||
end | ||
TG || return (; saved_points) | ||
return (; | ||
saved_points, grid=collect((layer.grid_min):(layer.grid_step):(layer.grid_max))) | ||
end | ||
|
||
function LuxCore.initialstates(::AbstractRNG, layer::SplineLayer{false}) | ||
return (; grid=collect((layer.grid_min):(layer.grid_step):(layer.grid_max)),) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters