Skip to content

Make spectral_mixture_kernel type stable for StaticArrays #500

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 12 commits into from
Closed
5 changes: 3 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "KernelFunctions"
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.10.53"
version = "0.10.55"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -15,6 +15,7 @@ LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
TensorCore = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50"
Expand All @@ -35,4 +36,4 @@ SpecialFunctions = "0.8, 0.9, 0.10, 1, 2"
StatsBase = "0.32, 0.33"
TensorCore = "0.1"
ZygoteRules = "0.2"
julia = "1.3"
julia = "1.6"
4 changes: 2 additions & 2 deletions src/basekernels/constant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ See also: [`ZeroKernel`](@ref)
struct ConstantKernel{Tc<:Real} <: SimpleKernel
c::Vector{Tc}

function ConstantKernel(; c::Real=1.0)
@check_args(ConstantKernel, c, c >= zero(c), "c ≥ 0")
function ConstantKernel(; c::Real=1.0, check_args::Bool=true)
@check_args(ConstantKernel, (c, c >= zero(c), "c ≥ 0"))
return new{typeof(c)}([c])
end
end
Expand Down
10 changes: 6 additions & 4 deletions src/basekernels/exponential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,14 +125,16 @@ struct GammaExponentialKernel{Tγ<:Real,M} <: SimpleKernel
γ::Vector{Tγ}
metric::M

function GammaExponentialKernel(γ::Real, metric)
@check_args(GammaExponentialKernel, γ, zero(γ) < γ ≤ 2, "γ ∈ (0, 2]")
function GammaExponentialKernel(γ::Real, metric; check_args::Bool=true)
@check_args(GammaExponentialKernel, (γ, zero(γ) < γ ≤ 2, "γ ∈ (0, 2]"))
return new{typeof(γ),typeof(metric)}([γ], metric)
end
end

function GammaExponentialKernel(; gamma::Real=1.0, γ::Real=gamma, metric=Euclidean())
return GammaExponentialKernel(γ, metric)
function GammaExponentialKernel(;
gamma::Real=1.0, γ::Real=gamma, metric=Euclidean(), check_args::Bool=true
)
return GammaExponentialKernel(γ, metric; check_args)
end

@functor GammaExponentialKernel
Expand Down
6 changes: 3 additions & 3 deletions src/basekernels/fbm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@ k(x, x'; h) = \\frac{\\|x\\|_2^{2h} + \\|x'\\|_2^{2h} - \\|x - x'\\|^{2h}}{2}.
"""
struct FBMKernel{T<:Real} <: Kernel
h::Vector{T}
function FBMKernel(h::Real)
@check_args(FBMKernel, h, zero(h) ≤ h ≤ one(h), "h ∈ [0, 1]")
function FBMKernel(h::Real; check_args::Bool=true)
@check_args(FBMKernel, (h, zero(h) ≤ h ≤ one(h), "h ∈ [0, 1]"))
return new{typeof(h)}([h])
end
end

FBMKernel(; h::Real=0.5) = FBMKernel(h)
FBMKernel(; h::Real=0.5, check_args::Bool=true) = FBMKernel(h; check_args)

@functor FBMKernel

Expand Down
8 changes: 5 additions & 3 deletions src/basekernels/matern.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,15 @@ struct MaternKernel{Tν<:Real,M} <: SimpleKernel
ν::Vector{Tν}
metric::M

function MaternKernel(ν::Real, metric)
@check_args(MaternKernel, ν, ν > zero(ν), "ν > 0")
function MaternKernel(ν::Real, metric; check_args::Bool=true)
@check_args(MaternKernel, (ν, ν > zero(ν), "ν > 0"))
return new{typeof(ν),typeof(metric)}([ν], metric)
end
end

MaternKernel(; nu::Real=1.5, ν::Real=nu, metric=Euclidean()) = MaternKernel(ν, metric)
function MaternKernel(; nu::Real=1.5, ν::Real=nu, metric=Euclidean(), check_args::Bool=true)
return MaternKernel(ν, metric; check_args)
end

@functor MaternKernel

Expand Down
10 changes: 7 additions & 3 deletions src/basekernels/periodic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,10 @@ k(x, x'; r) = \\exp\\bigg(- \\frac{1}{2} \\sum_{i=1}^d \\bigg(\\frac{\\sin\\big(
"""
struct PeriodicKernel{T} <: SimpleKernel
r::Vector{T}
function PeriodicKernel(; r::AbstractVector{<:Real}=ones(Float64, 1))
@check_args(PeriodicKernel, r, all(ri > zero(ri) for ri in r), "r > 0")
function PeriodicKernel(;
r::AbstractVector{<:Real}=ones(Float64, 1), check_args::Bool=true
)
@check_args(PeriodicKernel, (r, all(ri > zero(ri) for ri in r), "r > 0"))
return new{eltype(r)}(r)
end
end
Expand All @@ -28,7 +30,9 @@ PeriodicKernel(dims::Int) = PeriodicKernel(Float64, dims)

Create a [`PeriodicKernel`](@ref) with parameter `r=ones(T, dims)`.
"""
PeriodicKernel(T::DataType, dims::Int=1) = PeriodicKernel(; r=ones(T, dims))
function PeriodicKernel(T::DataType, dims::Int=1)
return PeriodicKernel(; r=ones(T, dims), check_args=false)
end

@functor PeriodicKernel

Expand Down
21 changes: 13 additions & 8 deletions src/basekernels/polynomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ See also: [`PolynomialKernel`](@ref)
struct LinearKernel{Tc<:Real} <: SimpleKernel
c::Vector{Tc}

function LinearKernel(c::Real)
@check_args(LinearKernel, c, c >= zero(c), "c ≥ 0")
function LinearKernel(c::Real; check_args::Bool=true)
@check_args(LinearKernel, (c, c >= zero(c), "c ≥ 0"))
return new{typeof(c)}([c])
end
end

LinearKernel(; c::Real=0.0) = LinearKernel(c)
LinearKernel(; c::Real=0.0, check_args::Bool=true) = LinearKernel(c; check_args)

@functor LinearKernel

Expand Down Expand Up @@ -69,15 +69,20 @@ struct PolynomialKernel{Tc<:Real} <: SimpleKernel
degree::Int
c::Vector{Tc}

function PolynomialKernel{Tc}(degree::Int, c::Vector{Tc}) where {Tc}
@check_args(PolynomialKernel, degree, degree >= one(degree), "degree ≥ 1")
@check_args(PolynomialKernel, c, only(c) >= zero(Tc), "c ≥ 0")
function PolynomialKernel{Tc}(
degree::Int, c::Vector{Tc}; check_args::Bool=true
) where {Tc}
@check_args(
PolynomialKernel,
(degree, degree >= one(degree), "degree ≥ 1"),
(c, only(c) >= zero(Tc), "c ≥ 0")
)
return new{Tc}(degree, c)
end
end

function PolynomialKernel(; degree::Int=2, c::Real=0.0)
return PolynomialKernel{typeof(c)}(degree, [c])
function PolynomialKernel(; degree::Int=2, c::Real=0.0, check_args::Bool=true)
return PolynomialKernel{typeof(c)}(degree, [c]; check_args)
end

# The degree of the polynomial kernel is a fixed discrete parameter
Expand Down
30 changes: 21 additions & 9 deletions src/basekernels/rational.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,16 @@ struct RationalKernel{Tα<:Real,M} <: SimpleKernel
α::Vector{Tα}
metric::M

function RationalKernel(α::Real, metric)
@check_args(RationalKernel, α, α > zero(α), "α > 0")
function RationalKernel(α::Real, metric; check_args::Bool=true)
@check_args(RationalKernel, (α, α > zero(α), "α > 0"))
return new{typeof(α),typeof(metric)}([α], metric)
end
end

function RationalKernel(; alpha::Real=2.0, α::Real=alpha, metric=Euclidean())
return RationalKernel(α, metric)
function RationalKernel(;
alpha::Real=2.0, α::Real=alpha, metric=Euclidean(), check_args::Bool=true
)
return RationalKernel(α, metric; check_args)
end

@functor RationalKernel
Expand Down Expand Up @@ -83,8 +85,10 @@ struct RationalQuadraticKernel{Tα<:Real,M} <: SimpleKernel
α::Vector{Tα}
metric::M

function RationalQuadraticKernel(; alpha::Real=2.0, α::Real=alpha, metric=Euclidean())
@check_args(RationalQuadraticKernel, α, α > zero(α), "α > 0")
function RationalQuadraticKernel(;
alpha::Real=2.0, α::Real=alpha, metric=Euclidean(), check_args::Bool=true
)
@check_args(RationalQuadraticKernel, (α, α > zero(α), "α > 0"))
return new{typeof(α),typeof(metric)}([α], metric)
end
end
Expand Down Expand Up @@ -172,10 +176,18 @@ struct GammaRationalKernel{Tα<:Real,Tγ<:Real,M} <: SimpleKernel
metric::M

function GammaRationalKernel(;
alpha::Real=2.0, gamma::Real=1.0, α::Real=alpha, γ::Real=gamma, metric=Euclidean()
alpha::Real=2.0,
gamma::Real=1.0,
α::Real=alpha,
γ::Real=gamma,
metric=Euclidean(),
check_args::Bool=true,
)
@check_args(GammaRationalKernel, α, α > zero(α), "α > 0")
@check_args(GammaRationalKernel, γ, zero(γ) < γ ≤ 2, "γ ∈ (0, 2]")
@check_args(
GammaRationalKernel,
(α, α > zero(α), "α > 0"),
(γ, zero(γ) < γ ≤ 2, "γ ∈ (0, 2]")
)
return new{typeof(α),typeof(γ),typeof(metric)}([α], [γ], metric)
end
end
Expand Down
8 changes: 7 additions & 1 deletion src/basekernels/sm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ Here, D is input dimension and A is the number of spectral components.

`h` is the kernel, which defaults to [`SqExponentialKernel`](@ref) if not specified.

!!! warning
If you want to make sure that the constructor is type-stable, you should
provide [`StaticArrays`](https://github.com/JuliaArrays/StaticArrays.jl) arguments:
`αs` as a `StaticVector`, `γs` and `ωs` as `StaticMatrix`.

Generalised Spectral Mixture kernel function. This family of functions is dense
in the family of stationary real-valued kernels with respect to the pointwise convergence.[1]

Expand Down Expand Up @@ -42,11 +47,12 @@ function spectral_mixture_kernel(
throw(DimensionMismatch("The dimensions of γs ans ωs do not match"))
end

return sum(zip(αs, eachcol(γs), eachcol(ωs))) do (α, γ, ω)
kernels = map(zip(αs, eachcol(γs), eachcol(ωs))) do (α, γ, ω)
a = TransformedKernel(h, LinearTransform(γ'))
b = TransformedKernel(CosineKernel(), LinearTransform(ω'))
return α * a * b
end
return sum(kernels)
end

function spectral_mixture_kernel(
Expand Down
4 changes: 2 additions & 2 deletions src/basekernels/wiener.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ The [`WhiteKernel`](@ref) is recovered for ``i = -1``.
[^SDH]: Schober, Duvenaud & Hennig (2014). Probabilistic ODE Solvers with Runge-Kutta Means.
"""
struct WienerKernel{I} <: Kernel
function WienerKernel{I}() where {I}
@check_args(WienerKernel, I, I ∈ (-1, 0, 1, 2, 3), "I ∈ {-1, 0, 1, 2, 3}")
function WienerKernel{I}(; check_args::Bool=true) where {I}
@check_args(WienerKernel, (I, I ∈ (-1, 0, 1, 2, 3), "I ∈ {-1, 0, 1, 2, 3}"))
if I == -1
return WhiteKernel()
end
Expand Down
6 changes: 4 additions & 2 deletions src/kernels/scaledkernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@ struct ScaledKernel{Tk<:Kernel,Tσ²<:Real} <: Kernel
σ²::Vector{Tσ²}
end

function ScaledKernel(kernel::Tk, σ²::Tσ²=1.0) where {Tk<:Kernel,Tσ²<:Real}
@check_args(ScaledKernel, σ², σ² > zero(Tσ²), "σ² > 0")
function ScaledKernel(
kernel::Tk, σ²::Tσ²=1.0; check_args::Bool=true
) where {Tk<:Kernel,Tσ²<:Real}
@check_args(ScaledKernel, (σ², σ² > zero(Tσ²), "σ² > 0"))
return ScaledKernel{Tk,Tσ²}(kernel, [σ²])
end

Expand Down
18 changes: 8 additions & 10 deletions src/mokernels/independent.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,16 +39,14 @@ function kernelmatrix(
return _kernelmatrix_kron_helper(MOI, Kfeatures, Koutputs)
end

if VERSION >= v"1.6"
function kernelmatrix!(
K::AbstractMatrix, k::IndependentMOKernel, x::MOI, y::MOI
) where {MOI<:IsotopicMOInputsUnion}
x.out_dim == y.out_dim ||
throw(DimensionMismatch("`x` and `y` must have the same `out_dim`"))
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
Koutputs = _mo_output_covariance(k, x.out_dim)
return _kernelmatrix_kron_helper!(K, MOI, Kfeatures, Koutputs)
end
function kernelmatrix!(
K::AbstractMatrix, k::IndependentMOKernel, x::MOI, y::MOI
) where {MOI<:IsotopicMOInputsUnion}
x.out_dim == y.out_dim ||
throw(DimensionMismatch("`x` and `y` must have the same `out_dim`"))
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
Koutputs = _mo_output_covariance(k, x.out_dim)
return _kernelmatrix_kron_helper!(K, MOI, Kfeatures, Koutputs)
end

function Base.show(io::IO, k::IndependentMOKernel)
Expand Down
36 changes: 18 additions & 18 deletions src/mokernels/intrinsiccoregion.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,23 +19,25 @@ struct IntrinsicCoregionMOKernel{K<:Kernel,T<:AbstractMatrix} <: MOKernel
kernel::K
B::T

function IntrinsicCoregionMOKernel{K,T}(kernel::K, B::T) where {K,T}
function IntrinsicCoregionMOKernel{K,T}(
kernel::K, B::T; check_args::Bool=true
) where {K,T}
@check_args(
IntrinsicCoregionMOKernel,
B,
eigmin(B) >= 0,
"B has to be positive semi-definite"
(B, eigmin(B) >= 0, "B has to be positive semi-definite")
)
return new{K,T}(kernel, B)
end
end

function IntrinsicCoregionMOKernel(; kernel::Kernel, B::AbstractMatrix)
return IntrinsicCoregionMOKernel{typeof(kernel),typeof(B)}(kernel, B)
function IntrinsicCoregionMOKernel(;
kernel::Kernel, B::AbstractMatrix, check_args::Bool=true
)
return IntrinsicCoregionMOKernel{typeof(kernel),typeof(B)}(kernel, B; check_args)
end

function IntrinsicCoregionMOKernel(kernel::Kernel, B::AbstractMatrix)
return IntrinsicCoregionMOKernel{typeof(kernel),typeof(B)}(kernel, B)
function IntrinsicCoregionMOKernel(kernel::Kernel, B::AbstractMatrix; check_args::Bool=true)
return IntrinsicCoregionMOKernel{typeof(kernel),typeof(B)}(kernel, B; check_args)
end

function (k::IntrinsicCoregionMOKernel)((x, px)::Tuple{Any,Int}, (y, py)::Tuple{Any,Int})
Expand All @@ -57,16 +59,14 @@ function kernelmatrix(
return _kernelmatrix_kron_helper(MOI, Kfeatures, Koutputs)
end

if VERSION >= v"1.6"
function kernelmatrix!(
K::AbstractMatrix, k::IntrinsicCoregionMOKernel, x::MOI, y::MOI
) where {MOI<:IsotopicMOInputsUnion}
x.out_dim == y.out_dim ||
throw(DimensionMismatch("`x` and `y` must have the same `out_dim`"))
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
Koutputs = _mo_output_covariance(k, x.out_dim)
return _kernelmatrix_kron_helper!(K, MOI, Kfeatures, Koutputs)
end
function kernelmatrix!(
K::AbstractMatrix, k::IntrinsicCoregionMOKernel, x::MOI, y::MOI
) where {MOI<:IsotopicMOInputsUnion}
x.out_dim == y.out_dim ||
throw(DimensionMismatch("`x` and `y` must have the same `out_dim`"))
Kfeatures = kernelmatrix(k.kernel, x.x, y.x)
Koutputs = _mo_output_covariance(k, x.out_dim)
return _kernelmatrix_kron_helper!(K, MOI, Kfeatures, Koutputs)
end

function Base.show(io::IO, k::IntrinsicCoregionMOKernel)
Expand Down
20 changes: 9 additions & 11 deletions src/mokernels/mokernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,14 @@ function _kernelmatrix_kron_helper(::Type{<:MOInputIsotopicByOutputs}, Kfeatures
return kron(Koutputs, Kfeatures)
end

if VERSION >= v"1.6"
function _kernelmatrix_kron_helper!(
K, ::Type{<:MOInputIsotopicByFeatures}, Kfeatures, Koutputs
)
return kron!(K, Kfeatures, Koutputs)
end
function _kernelmatrix_kron_helper!(
K, ::Type{<:MOInputIsotopicByFeatures}, Kfeatures, Koutputs
)
return kron!(K, Kfeatures, Koutputs)
end

function _kernelmatrix_kron_helper!(
K, ::Type{<:MOInputIsotopicByOutputs}, Kfeatures, Koutputs
)
return kron!(K, Koutputs, Kfeatures)
end
function _kernelmatrix_kron_helper!(
K, ::Type{<:MOInputIsotopicByOutputs}, Kfeatures, Koutputs
)
return kron!(K, Koutputs, Kfeatures)
end
Loading