Skip to content

Commit

Permalink
Make spectral_mixture_kernel type stable for StaticArrays (#501)
Browse files Browse the repository at this point in the history
* Cherry pick from branch

* Format and patch bump

* Remove StaticArrays from main Project.toml
  • Loading branch information
theogf authored Apr 18, 2023
1 parent 8746034 commit 9da7bfd
Show file tree
Hide file tree
Showing 5 changed files with 19 additions and 2 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "KernelFunctions"
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.10.53"
version = "0.10.54"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
8 changes: 7 additions & 1 deletion src/basekernels/sm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ Here, D is input dimension and A is the number of spectral components.
`h` is the kernel, which defaults to [`SqExponentialKernel`](@ref) if not specified.
!!! warning
If you want to make sure that the constructor is type-stable, you should
provide [`StaticArrays`](https://github.com/JuliaArrays/StaticArrays.jl) arguments:
`αs` as a `StaticVector`, `γs` and `ωs` as `StaticMatrix`.
Generalised Spectral Mixture kernel function. This family of functions is dense
in the family of stationary real-valued kernels with respect to the pointwise convergence.[1]
Expand Down Expand Up @@ -42,11 +47,12 @@ function spectral_mixture_kernel(
throw(DimensionMismatch("The dimensions of γs ans ωs do not match"))
end

return sum(zip(αs, eachcol(γs), eachcol(ωs))) do (α, γ, ω)
kernels = map(zip(αs, eachcol(γs), eachcol(ωs))) do (α, γ, ω)
a = TransformedKernel(h, LinearTransform'))
b = TransformedKernel(CosineKernel(), LinearTransform'))
return α * a * b
end
return sum(kernels)
end

function spectral_mixture_kernel(
Expand Down
2 changes: 2 additions & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
Expand All @@ -32,4 +33,5 @@ PDMats = "0.9, 0.10, 0.11"
ReverseDiff = "1.2"
SpecialFunctions = "0.10, 1, 2"
StableRNGs = "1"
StaticArrays = "1"
Zygote = "0.6.38"
8 changes: 8 additions & 0 deletions test/basekernels/sm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,14 @@
TestUtils.test_interface(k1, x0, x1, x2)
TestUtils.test_interface(k2, x0, x1, x2)
end

@testset "Type stability given static arrays" begin
αs = @SVector rand(3)
γs = @SMatrix rand(D_in, 3)
ωs = @SMatrix rand(D_in, 3)
@inferred spectral_mixture_kernel(αs, γs, ωs)
end

# test_ADs(x->spectral_mixture_kernel(exp.(x[1:3]), reshape(x[4:18], 5, 3), reshape(x[19:end], 5, 3)), vcat(log.(αs₁), γs[:], ωs[:]), dims = [5,5])
@test_broken "No tests passing (BaseKernel)"
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ using PDMats
using Random
using SpecialFunctions
using StableRNGs
using StaticArrays
using Statistics
using Test
using Zygote: Zygote
Expand Down

2 comments on commit 9da7bfd

@theogf
Copy link
Member Author

@theogf theogf commented on 9da7bfd Apr 18, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/81844

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.10.54 -m "<description of version>" 9da7bfd38a7a04ab4575acc882f9ce0b0a0919be
git push origin v0.10.54

Please sign in to comment.