Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Use scalar parameters with ParameterHandling #397

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
ParameterHandling = "2412ca09-6db7-441c-8e3a-88d5709968c5"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Expand All @@ -29,6 +30,7 @@ FillArrays = "0.10, 0.11, 0.12"
Functors = "0.1, 0.2"
IrrationalConstants = "0.1"
LogExpFunctions = "0.2.1, 0.3"
ParameterHandling = "0.4"
Requires = "1.0.1"
SpecialFunctions = "0.8, 0.9, 0.10, 1"
StatsBase = "0.32, 0.33"
Expand Down
6 changes: 5 additions & 1 deletion src/KernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ export MOInput, prepare_isotopic_multi_output_data, prepare_heterotopic_multi_ou
export IndependentMOKernel,
LatentFactorMOKernel, IntrinsicCoregionMOKernel, LinearMixingModelKernel

export ParameterKernel

# Reexports
export tensor, ⊗, compose

Expand All @@ -51,11 +53,12 @@ using CompositionsBase
using Distances
using FillArrays
using Functors
using ParameterHandling
using LinearAlgebra
using Requires
using SpecialFunctions: loggamma, besselk, polygamma
using IrrationalConstants: logtwo, twoπ, invsqrt2
using LogExpFunctions: softplus
using LogExpFunctions: logit, logistic, softplus
using StatsBase
using TensorCore
using ZygoteRules: ZygoteRules, AContext, literal_getproperty, literal_getfield
Expand Down Expand Up @@ -107,6 +110,7 @@ include("kernels/kernelproduct.jl")
include("kernels/kerneltensorproduct.jl")
include("kernels/overloads.jl")
include("kernels/neuralkernelnetwork.jl")
include("kernels/parameterkernel.jl")
include("approximations/nystrom.jl")
include("generic.jl")

Expand Down
27 changes: 19 additions & 8 deletions src/basekernels/constant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@ See also: [`ConstantKernel`](@ref)
"""
struct ZeroKernel <: SimpleKernel end

kappa(κ::ZeroKernel, d::T) where {T<:Real} = zero(T)
@noparams ZeroKernel

kappa(::ZeroKernel, d::Real) = zero(d)

metric(::ZeroKernel) = Delta()

Expand All @@ -35,6 +37,8 @@ k(x, x') = \\delta(x, x').
"""
struct WhiteKernel <: SimpleKernel end

@noparams WhiteKernel

"""
EyeKernel()

Expand Down Expand Up @@ -62,19 +66,26 @@ k(x, x') = c.

See also: [`ZeroKernel`](@ref)
"""
struct ConstantKernel{Tc<:Real} <: SimpleKernel
c::Vector{Tc}
struct ConstantKernel{T<:Real} <: SimpleKernel
c::T

function ConstantKernel(; c::Real=1.0)
function ConstantKernel(c::Real)
@check_args(ConstantKernel, c, c >= zero(c), "c ≥ 0")
return new{typeof(c)}([c])
return new{typeof(c)}(c)
end
end

@functor ConstantKernel
ConstantKernel(; c::Real=1.0) = ConstantKernel(c)

# Flatten a `ConstantKernel` into a one-element vector with element type `T` that
# holds `log(k.c)`, and return it together with a closure that rebuilds the kernel.
# The closure exponentiates the single entry and converts the result back to `S`.
# NOTE(review): the constructor check admits `c == 0`, for which `log(k.c)` is `-Inf`.
function ParameterHandling.flatten(::Type{T}, k::ConstantKernel{S}) where {T<:Real,S}
    flat = T[log(k.c)]
    # `only` requires the vector to contain exactly one element.
    unflatten_to_constantkernel(v::Vector{T}) = ConstantKernel(S(exp(only(v))))
    return flat, unflatten_to_constantkernel
end

kappa(κ::ConstantKernel, x::Real) = first(κ.c) * one(x)
kappa(κ::ConstantKernel, x::Real) = κ.c * one(x)

metric(::ConstantKernel) = Delta()

Base.show(io::IO, κ::ConstantKernel) = print(io, "Constant Kernel (c = ", first(κ.c), ")")
Base.show(io::IO, κ::ConstantKernel) = print(io, "Constant Kernel (c = ", κ.c, ")")
2 changes: 2 additions & 0 deletions src/basekernels/cosine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ end

CosineKernel(; metric=Euclidean()) = CosineKernel(metric)

@noparams CosineKernel

kappa(::CosineKernel, d::Real) = cospi(d)

metric(k::CosineKernel) = k.metric
Expand Down
27 changes: 19 additions & 8 deletions src/basekernels/exponential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ end

SqExponentialKernel(; metric=Euclidean()) = SqExponentialKernel(metric)

@noparams SqExponentialKernel

kappa(::SqExponentialKernel, d::Real) = exp(-d^2 / 2)
kappa(::SqExponentialKernel{<:Euclidean}, d²::Real) = exp(-d² / 2)

Expand Down Expand Up @@ -76,6 +78,8 @@ end

ExponentialKernel(; metric=Euclidean()) = ExponentialKernel(metric)

@noparams ExponentialKernel

kappa(::ExponentialKernel, d::Real) = exp(-d)

metric(k::ExponentialKernel) = k.metric
Expand Down Expand Up @@ -121,30 +125,37 @@ See also: [`ExponentialKernel`](@ref), [`SqExponentialKernel`](@ref)

[^RW]: C. E. Rasmussen & C. K. I. Williams (2006). Gaussian Processes for Machine Learning.
"""
struct GammaExponentialKernel{<:Real,M} <: SimpleKernel
γ::Vector{Tγ}
struct GammaExponentialKernel{T<:Real,M} <: SimpleKernel
γ::T
metric::M

function GammaExponentialKernel(γ::Real, metric)
@check_args(GammaExponentialKernel, γ, zero(γ) < γ ≤ 2, "γ ∈ (0, 2]")
return new{typeof(γ),typeof(metric)}([γ], metric)
return new{typeof(γ),typeof(metric)}(γ, metric)
end
end

# Keyword constructor. `gamma` is an ASCII spelling of the parameter: it only
# supplies the default for `γ`, so an explicitly passed `γ` takes precedence.
# Both arguments are forwarded to the positional constructor.
function GammaExponentialKernel(; gamma::Real=1.0, γ::Real=gamma, metric=Euclidean())
return GammaExponentialKernel(γ, metric)
end

@functor GammaExponentialKernel
# Flatten a `GammaExponentialKernel`: the parameter `γ` is stored as `logit(γ / 2)`
# in a one-element vector with element type `T`. The metric is captured by the
# returned closure and reused unchanged when the kernel is rebuilt.
function ParameterHandling.flatten(
    ::Type{T}, k::GammaExponentialKernel{S}
) where {T<:Real,S<:Real}
    kernel_metric = k.metric
    flat = T[logit(k.γ / 2)]
    function unflatten_to_gammaexponentialkernel(v::Vector{T})
        # Invert `logit(γ / 2)`: apply `logistic`, rescale by 2, convert back to `S`.
        return GammaExponentialKernel(S(2 * logistic(only(v))), kernel_metric)
    end
    return flat, unflatten_to_gammaexponentialkernel
end

kappa(κ::GammaExponentialKernel, d::Real) = exp(-d^first(κ.γ))
kappa(κ::GammaExponentialKernel, d::Real) = exp(-d^κ.γ)

metric(k::GammaExponentialKernel) = k.metric

iskroncompatible(::GammaExponentialKernel) = true

function Base.show(io::IO, κ::GammaExponentialKernel)
return print(
io, "Gamma Exponential Kernel (γ = ", first(κ.γ), ", metric = ", κ.metric, ")"
)
return print(io, "Gamma Exponential Kernel (γ = ", κ.γ, ", metric = ", κ.metric, ")")
end
2 changes: 2 additions & 0 deletions src/basekernels/exponentiated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ k(x, x') = \\exp(x^\\top x').
"""
struct ExponentiatedKernel <: SimpleKernel end

@noparams ExponentiatedKernel

kappa(::ExponentiatedKernel, xᵀy::Real) = exp(xᵀy)

metric(::ExponentiatedKernel) = DotProduct()
Expand Down
13 changes: 10 additions & 3 deletions src/basekernels/fbm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,23 @@ k(x, x'; h) = \\frac{\\|x\\|_2^{2h} + \\|x'\\|_2^{2h} - \\|x - x'\\|^{2h}}{2}.
```
"""
struct FBMKernel{T<:Real} <: Kernel
h::Vector{T}
h::T

function FBMKernel(h::Real)
@check_args(FBMKernel, h, zero(h) ≤ h ≤ one(h), "h ∈ [0, 1]")
return new{typeof(h)}([h])
return new{typeof(h)}(h)
end
end

FBMKernel(; h::Real=0.5) = FBMKernel(h)

@functor FBMKernel
# Flatten an `FBMKernel`: the parameter `h` is stored as `logit(k.h)` in a
# one-element vector with element type `T`; the returned closure applies `logistic`
# to the single entry, converts it back to `S` and rebuilds the kernel.
# NOTE(review): the constructor check admits `h == 0` and `h == 1`, where `logit`
# is infinite.
function ParameterHandling.flatten(::Type{T}, k::FBMKernel{S}) where {T<:Real,S<:Real}
    flat = T[logit(k.h)]
    unflatten_to_fbmkernel(v::Vector{T}) = FBMKernel(S(logistic(only(v))))
    return flat, unflatten_to_fbmkernel
end

function (κ::FBMKernel)(x::AbstractVector{<:Real}, y::AbstractVector{<:Real})
modX = sum(abs2, x)
Expand Down
16 changes: 12 additions & 4 deletions src/basekernels/matern.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,23 @@ differentiable in the mean-square sense.

See also: [`Matern12Kernel`](@ref), [`Matern32Kernel`](@ref), [`Matern52Kernel`](@ref)
"""
struct MaternKernel{<:Real,M} <: SimpleKernel
ν::Vector{Tν}
struct MaternKernel{T<:Real,M} <: SimpleKernel
ν::T
metric::M

function MaternKernel(ν::Real, metric)
@check_args(MaternKernel, ν, ν > zero(ν), "ν > 0")
return new{typeof(ν),typeof(metric)}([ν], metric)
return new{typeof(ν),typeof(metric)}(ν, metric)
end
end

MaternKernel(; nu::Real=1.5, ν::Real=nu, metric=Euclidean()) = MaternKernel(ν, metric)

@functor MaternKernel
# Flatten a `MaternKernel`: the parameter `ν` is stored as `log(k.ν)` in a
# one-element vector with element type `T`. The metric is captured by the returned
# closure and reused unchanged when the kernel is rebuilt.
function ParameterHandling.flatten(::Type{T}, k::MaternKernel{S}) where {T<:Real,S<:Real}
    metric = k.metric
    function unflatten_to_maternkernel(v::Vector{T})
        # `only` (rather than `first`) rejects vectors that do not contain exactly
        # one element, matching the other kernels' `flatten` methods.
        return MaternKernel(S(exp(only(v))), metric)
    end
    return T[log(k.ν)], unflatten_to_maternkernel
end

@inline function kappa(κ::MaternKernel, d::Real)
result = _matern(first(κ.ν), d)
Expand Down Expand Up @@ -73,6 +77,8 @@ end

Matern32Kernel(; metric=Euclidean()) = Matern32Kernel(metric)

@noparams Matern32Kernel

kappa(::Matern32Kernel, d::Real) = (1 + sqrt(3) * d) * exp(-sqrt(3) * d)

metric(k::Matern32Kernel) = k.metric
Expand Down Expand Up @@ -104,6 +110,8 @@ end

Matern52Kernel(; metric=Euclidean()) = Matern52Kernel(metric)

@noparams Matern52Kernel

kappa(::Matern52Kernel, d::Real) = (1 + sqrt(5) * d + 5 * d^2 / 3) * exp(-sqrt(5) * d)

metric(k::Matern52Kernel) = k.metric
Expand Down
2 changes: 2 additions & 0 deletions src/basekernels/nn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ for inputs ``x, x' \\in \\mathbb{R}^d``.[^CW]
"""
struct NeuralNetworkKernel <: Kernel end

@noparams NeuralNetworkKernel

# Evaluate the kernel on `x` and `y`: the arcsine of their dot product divided by
# `sqrt((1 + sum(abs2, x)) * (1 + sum(abs2, y)))`.
function (κ::NeuralNetworkKernel)(x, y)
    scale_x = 1 + sum(abs2, x)
    scale_y = 1 + sum(abs2, y)
    return asin(dot(x, y) / sqrt(scale_x * scale_y))
end
Expand Down
15 changes: 12 additions & 3 deletions src/basekernels/periodic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,25 @@ struct PeriodicKernel{T} <: SimpleKernel
end
end

"""
PeriodicKernel(dims::Int)

Create a [`PeriodicKernel`](@ref) with parameter `r=ones(Float64, dims)`.
"""
PeriodicKernel(dims::Int) = PeriodicKernel(Float64, dims)

"""
PeriodicKernel([T=Float64, dims::Int=1])
PeriodicKernel(T, dims::Int=1)

Create a [`PeriodicKernel`](@ref) with parameter `r=ones(T, dims)`.
"""
PeriodicKernel(T::DataType, dims::Int=1) = PeriodicKernel(; r=ones(T, dims))
PeriodicKernel(::Type{T}, dims::Int=1) where {T} = PeriodicKernel(; r=ones(T, dims))

@functor PeriodicKernel
# Flatten a `PeriodicKernel`: every entry of `k.r` is stored as its logarithm in a
# vector with element type `T`. The returned closure exponentiates each entry,
# collects the results with element type `S` and rebuilds the kernel from them.
function ParameterHandling.flatten(::Type{T}, k::PeriodicKernel{S}) where {T<:Real,S}
    log_r = T[log(ri) for ri in k.r]
    function unflatten_to_periodickernel(v::Vector{T})
        return PeriodicKernel(; r=S[exp(vi) for vi in v])
    end
    return log_r, unflatten_to_periodickernel
end

metric(κ::PeriodicKernel) = Sinus(κ.r)

Expand Down
2 changes: 2 additions & 0 deletions src/basekernels/piecewisepolynomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ function PiecewisePolynomialKernel(; degree::Int=0, kwargs...)
return PiecewisePolynomialKernel{degree}(; kwargs...)
end

@noparams PiecewisePolynomialKernel

# Coefficient tuples selected by the `Val` degree tag (0, 1 or 2) and an integer `j`;
# the degree-0 tuple does not depend on `j`.
# NOTE(review): presumably these are the polynomial coefficients used by the
# piecewise polynomial kernel of that degree — confirm against its `kappa` method.
piecewise_polynomial_coefficients(::Val{0}, ::Int) = (1,)
piecewise_polynomial_coefficients(::Val{1}, j::Int) = (1, j + 1)
# `//` builds a `Rational`, so the last coefficient is exact rather than a float.
piecewise_polynomial_coefficients(::Val{2}, j::Int) = (1, j + 2, (j^2 + 4 * j)//3 + 1)
Expand Down
43 changes: 26 additions & 17 deletions src/basekernels/polynomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,29 @@ k(x, x'; c) = x^\\top x' + c.

See also: [`PolynomialKernel`](@ref)
"""
struct LinearKernel{Tc<:Real} <: SimpleKernel
c::Vector{Tc}
struct LinearKernel{T<:Real} <: SimpleKernel
c::T

function LinearKernel(c::Real)
@check_args(LinearKernel, c, c >= zero(c), "c ≥ 0")
return new{typeof(c)}([c])
return new{typeof(c)}(c)
end
end

LinearKernel(; c::Real=0.0) = LinearKernel(c)

@functor LinearKernel
# Flatten a `LinearKernel`: the offset `c` is stored as `log(k.c)` in a one-element
# vector with element type `T`; the returned closure exponentiates the single
# entry, converts it back to `S` and rebuilds the kernel.
# NOTE(review): the keyword constructor defaults to `c = 0.0`, for which `log(k.c)`
# is `-Inf`.
function ParameterHandling.flatten(::Type{T}, k::LinearKernel{S}) where {T<:Real,S<:Real}
    flat = T[log(k.c)]
    unflatten_to_linearkernel(v::Vector{T}) = LinearKernel(S(exp(only(v))))
    return flat, unflatten_to_linearkernel
end

kappa(κ::LinearKernel, xᵀy::Real) = xᵀy + first(κ.c)
kappa(κ::LinearKernel, xᵀy::Real) = xᵀy + κ.c

metric(::LinearKernel) = DotProduct()

Base.show(io::IO, κ::LinearKernel) = print(io, "Linear Kernel (c = ", first(κ.c), ")")
Base.show(io::IO, κ::LinearKernel) = print(io, "Linear Kernel (c = ", κ.c, ")")

"""
PolynomialKernel(; degree::Int=2, c::Real=0.0)
Expand All @@ -47,31 +52,35 @@ k(x, x'; c, \\nu) = (x^\\top x' + c)^\\nu.

See also: [`LinearKernel`](@ref)
"""
struct PolynomialKernel{Tc<:Real} <: SimpleKernel
struct PolynomialKernel{T<:Real} <: SimpleKernel
degree::Int
c::Vector{Tc}
c::T

function PolynomialKernel{Tc}(degree::Int, c::Vector{Tc}) where {Tc}
function PolynomialKernel(degree::Int, c::Real)
@check_args(PolynomialKernel, degree, degree >= one(degree), "degree ≥ 1")
@check_args(PolynomialKernel, c, first(c) >= zero(Tc), "c ≥ 0")
return new{Tc}(degree, c)
@check_args(PolynomialKernel, c, c >= zero(c), "c ≥ 0")
return new{typeof(c)}(degree, c)
end
end

# Keyword constructor: forwards `degree` and the offset `c` to the positional
# constructor, which validates both (`degree ≥ 1`, `c ≥ 0`).
# `c` is passed as a plain scalar because the `c` field is a scalar `T<:Real`;
# wrapping it in a one-element vector matches no `(degree::Int, c::Real)` method.
function PolynomialKernel(; degree::Int=2, c::Real=0.0)
    return PolynomialKernel(degree, c)
end

# The degree of the polynomial kernel is a fixed discrete parameter
function Functors.functor(::Type{<:PolynomialKernel}, x)
reconstruct_polynomialkernel(xs) = PolynomialKernel{typeof(xs.c)}(x.degree, xs.c)
return (c=x.c,), reconstruct_polynomialkernel
# Flatten a `PolynomialKernel`: only the offset `c` is exposed, stored as
# `log(k.c)` in a one-element vector with element type `T`. The degree is captured
# by the returned closure and reused unchanged when the kernel is rebuilt.
# NOTE(review): the constructor check admits `c == 0`, for which `log(k.c)` is `-Inf`.
function ParameterHandling.flatten(
    ::Type{T}, k::PolynomialKernel{S}
) where {T<:Real,S<:Real}
    fixed_degree = k.degree
    flat = T[log(k.c)]
    function unflatten_to_polynomialkernel(v::Vector{T})
        c = S(exp(only(v)))
        return PolynomialKernel(fixed_degree, c)
    end
    return flat, unflatten_to_polynomialkernel
end

kappa(κ::PolynomialKernel, xᵀy::Real) = (xᵀy + first(κ.c))^κ.degree
kappa(κ::PolynomialKernel, xᵀy::Real) = (xᵀy + κ.c)^κ.degree

metric(::PolynomialKernel) = DotProduct()

function Base.show(io::IO, κ::PolynomialKernel)
return print(io, "Polynomial Kernel (c = ", first(κ.c), ", degree = ", κ.degree, ")")
return print(io, "Polynomial Kernel (c = ", κ.c, ", degree = ", κ.degree, ")")
end
Loading