diff --git a/Project.toml b/Project.toml index ef8435a0e..32486a3d4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "KernelFunctions" uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392" -version = "0.10.53" +version = "0.10.54" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/basekernels/sm.jl b/src/basekernels/sm.jl index 833945295..430c3a998 100644 --- a/src/basekernels/sm.jl +++ b/src/basekernels/sm.jl @@ -12,6 +12,11 @@ Here, D is input dimension and A is the number of spectral components. `h` is the kernel, which defaults to [`SqExponentialKernel`](@ref) if not specified. +!!! warning + If you want to make sure that the constructor is type-stable, you should + provide [`StaticArrays`](https://github.com/JuliaArrays/StaticArrays.jl) arguments: + `αs` as a `StaticVector`, `γs` and `ωs` as `StaticMatrix`. + Generalised Spectral Mixture kernel function. This family of functions is dense in the family of stationary real-valued kernels with respect to the pointwise convergence.[1] @@ -42,11 +47,12 @@ function spectral_mixture_kernel( throw(DimensionMismatch("The dimensions of γs ans ωs do not match")) end - return sum(zip(αs, eachcol(γs), eachcol(ωs))) do (α, γ, ω) + kernels = map(zip(αs, eachcol(γs), eachcol(ωs))) do (α, γ, ω) a = TransformedKernel(h, LinearTransform(γ')) b = TransformedKernel(CosineKernel(), LinearTransform(ω')) return α * a * b end + return sum(kernels) end function spectral_mixture_kernel( diff --git a/test/Project.toml b/test/Project.toml index 768371813..e16f39a6c 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -14,6 +14,7 @@ Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" @@ -32,4 +33,5 @@ PDMats = "0.9, 0.10, 0.11" ReverseDiff = "1.2" SpecialFunctions = "0.10, 1, 2" StableRNGs = "1" +StaticArrays = "1" Zygote = "0.6.38" diff --git a/test/basekernels/sm.jl b/test/basekernels/sm.jl index b91024144..5cbc37243 100644 --- a/test/basekernels/sm.jl +++ b/test/basekernels/sm.jl @@ -48,6 +48,14 @@ TestUtils.test_interface(k1, x0, x1, x2) TestUtils.test_interface(k2, x0, x1, x2) end + + @testset "Type stability given static arrays" begin + αs = @SVector rand(3) + γs = @SMatrix rand(D_in, 3) + ωs = @SMatrix rand(D_in, 3) + @inferred spectral_mixture_kernel(αs, γs, ωs) + end + # test_ADs(x->spectral_mixture_kernel(exp.(x[1:3]), reshape(x[4:18], 5, 3), reshape(x[19:end], 5, 3)), vcat(log.(αs₁), γs[:], ωs[:]), dims = [5,5]) @test_broken "No tests passing (BaseKernel)" end diff --git a/test/runtests.jl b/test/runtests.jl index 0fe962648..f486ddd2a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -10,6 +10,7 @@ using PDMats using Random using SpecialFunctions using StableRNGs +using StaticArrays using Statistics using Test using Zygote: Zygote