Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Use scalar parameters with ParameterHandling #397

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
IrrationalConstants = "92d709cd-6900-40b7-9082-c6be49f344b6"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
ParameterHandling = "2412ca09-6db7-441c-8e3a-88d5709968c5"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
Expand All @@ -30,6 +31,7 @@ FillArrays = "0.10, 0.11, 0.12, 0.13, 1"
Functors = "0.1, 0.2, 0.3, 0.4"
IrrationalConstants = "0.1, 0.2"
LogExpFunctions = "0.2.1, 0.3"
ParameterHandling = "0.4"
Requires = "1.0.1"
SpecialFunctions = "0.8, 0.9, 0.10, 1, 2"
Statistics = "1"
Expand Down
6 changes: 5 additions & 1 deletion src/KernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ export MOInput, prepare_isotopic_multi_output_data, prepare_heterotopic_multi_ou
export IndependentMOKernel,
LatentFactorMOKernel, IntrinsicCoregionMOKernel, LinearMixingModelKernel

export ParameterKernel

# Reexports
export tensor, ⊗, compose

Expand All @@ -53,11 +55,12 @@ using CompositionsBase
using Distances
using FillArrays
using Functors
using ParameterHandling
using LinearAlgebra
using Requires
using SpecialFunctions: loggamma, besselk, polygamma
using IrrationalConstants: logtwo, twoπ, invsqrt2
using LogExpFunctions: softplus
using LogExpFunctions: logit, logistic, softplus
using StatsBase
using TensorCore
using ZygoteRules: ZygoteRules, AContext, literal_getproperty, literal_getfield
Expand Down Expand Up @@ -111,6 +114,7 @@ include("kernels/kernelproduct.jl")
include("kernels/kerneltensorproduct.jl")
include("kernels/overloads.jl")
include("kernels/neuralkernelnetwork.jl")
include("kernels/parameterkernel.jl")
include("approximations/nystrom.jl")
include("generic.jl")

Expand Down
6 changes: 6 additions & 0 deletions src/TestUtils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using Distances
using LinearAlgebra
using KernelFunctions
using ParameterHandling
using Random
using Test

Expand Down Expand Up @@ -84,6 +85,11 @@
tmp_diag = Vector{Float64}(undef, length(x0))
@test kernelmatrix_diag!(tmp_diag, k, x0) ≈ kernelmatrix_diag(k, x0)
@test kernelmatrix_diag!(tmp_diag, k, x0, x1) ≈ kernelmatrix_diag(k, x0, x1)

# Check flatten/unflatten
ParameterHandling.TestUtils.test_flatten_interface(k)

Check warning on line 90 in src/TestUtils.jl

View check run for this annotation

Codecov / codecov/patch

src/TestUtils.jl#L90

Added line #L90 was not covered by tests

return nothing

Check warning on line 92 in src/TestUtils.jl

View check run for this annotation

Codecov / codecov/patch

src/TestUtils.jl#L92

Added line #L92 was not covered by tests
end

"""
Expand Down
44 changes: 28 additions & 16 deletions src/basekernels/constant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,11 @@
"""
struct ZeroKernel <: SimpleKernel end

# No trainable parameters — presumably registers a trivial flatten/unflatten pair
# for ParameterHandling; confirm against the @noparams macro definition.
@noparams ZeroKernel

# SimpleKernel interface
# Returns `false`, which in Julia is a "strong zero" (false * x == 0 even for Inf/NaN).
kappa(::ZeroKernel, ::Real) = false

metric(::ZeroKernel) = Delta()

# Optimizations
Expand Down Expand Up @@ -68,6 +71,8 @@
"""
struct WhiteKernel <: SimpleKernel end

# No trainable parameters — presumably registers a trivial flatten/unflatten pair
# for ParameterHandling; confirm against the @noparams macro definition.
@noparams WhiteKernel

"""
EyeKernel()

Expand Down Expand Up @@ -95,52 +100,59 @@

See also: [`ZeroKernel`](@ref)
"""
# Constant kernel: k(x, x') = c for all inputs.
# The parameter is stored as a plain scalar (not a one-element `Vector`);
# trainability is provided via `ParameterHandling.flatten` below instead of `@functor`.
struct ConstantKernel{T<:Real} <: SimpleKernel
    c::T

    function ConstantKernel(c::Real)
        @check_args(ConstantKernel, c, c >= zero(c), "c ≥ 0")
        return new{typeof(c)}(c)
    end
end

# Keyword convenience constructor; preserves the historical default `c = 1.0`.
ConstantKernel(; c::Real=1.0) = ConstantKernel(c)

Check warning on line 112 in src/basekernels/constant.jl

View check run for this annotation

Codecov / codecov/patch

src/basekernels/constant.jl#L112

Added line #L112 was not covered by tests

"""
    ParameterHandling.flatten(::Type{T}, k::ConstantKernel)

Flatten a [`ConstantKernel`](@ref) into a single unconstrained parameter.

The positivity constraint `c ≥ 0` is enforced by storing `log(c)` in the flat
vector and mapping back through `exp` in the returned unflatten closure. The
original element type `S` of the kernel parameter is restored on unflattening.
"""
function ParameterHandling.flatten(::Type{T}, k::ConstantKernel{S}) where {T<:Real,S}
    function unflatten_to_constantkernel(v::Vector{T})
        # `only` enforces that exactly one parameter comes back.
        return ConstantKernel(; c=S(exp(only(v))))
    end
    return T[log(k.c)], unflatten_to_constantkernel
end

# SimpleKernel interface
kappa(κ::ConstantKernel, ::Real) = κ.c

metric(::ConstantKernel) = Delta()

# Optimizations: the kernel value never depends on the inputs, so evaluation is
# constant and the kernel matrix is a lazy `Fill` array (no allocation).
(k::ConstantKernel)(x, y) = k.c
kernelmatrix(k::ConstantKernel, x::AbstractVector) = Fill(k.c, length(x), length(x))

Check warning on line 127 in src/basekernels/constant.jl

View check run for this annotation

Codecov / codecov/patch

src/basekernels/constant.jl#L126-L127

Added lines #L126 - L127 were not covered by tests
# Every entry of a ConstantKernel Gram matrix equals `k.c`, so the out-of-place
# methods return lazy `Fill` arrays and the in-place methods reduce to `fill!`.
# Input/dimension validation is kept so these shortcuts raise the same errors as
# the generic fallbacks would.
function kernelmatrix(k::ConstantKernel, x::AbstractVector, y::AbstractVector)
    validate_inputs(x, y)
    return Fill(k.c, length(x), length(y))
end

function kernelmatrix!(K::AbstractMatrix, k::ConstantKernel, x::AbstractVector)
    validate_inplace_dims(K, x)
    return fill!(K, k.c)
end

function kernelmatrix!(
    K::AbstractMatrix, k::ConstantKernel, x::AbstractVector, y::AbstractVector
)
    validate_inplace_dims(K, x, y)
    return fill!(K, k.c)
end

kernelmatrix_diag(k::ConstantKernel, x::AbstractVector) = Fill(k.c, length(x))

function kernelmatrix_diag(k::ConstantKernel, x::AbstractVector, y::AbstractVector)
    validate_inputs(x, y)
    # After validation x and y have matching structure, so length(x) is the diag length.
    return Fill(k.c, length(x))
end

function kernelmatrix_diag!(K::AbstractVector, k::ConstantKernel, x::AbstractVector)
    validate_inplace_dims(K, x)
    return fill!(K, k.c)
end

function kernelmatrix_diag!(
    K::AbstractVector, k::ConstantKernel, x::AbstractVector, y::AbstractVector
)
    validate_inplace_dims(K, x, y)
    return fill!(K, k.c)
end

# Compact REPL/print representation, e.g. `Constant Kernel (c = 1.0)`.
Base.show(io::IO, κ::ConstantKernel) = print(io, "Constant Kernel (c = ", κ.c, ")")
2 changes: 2 additions & 0 deletions src/basekernels/cosine.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ end

CosineKernel(; metric=Euclidean()) = CosineKernel(metric)

# No trainable parameters (the metric is configuration, not a parameter) — TODO
# confirm against the @noparams macro definition.
@noparams CosineKernel

# k(d) = cos(π d); `cospi` is more accurate than `cos(pi * d)` for large `d`.
kappa(::CosineKernel, d::Real) = cospi(d)

metric(k::CosineKernel) = k.metric
Expand Down
27 changes: 19 additions & 8 deletions src/basekernels/exponential.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@

SqExponentialKernel(; metric=Euclidean()) = SqExponentialKernel(metric)

# No trainable parameters — TODO confirm against the @noparams macro definition.
@noparams SqExponentialKernel

kappa(::SqExponentialKernel, d::Real) = exp(-d^2 / 2)
# For the Euclidean specialization the argument is presumably already the
# squared distance d² (avoiding a sqrt/re-square round trip) — verify against
# the metric-handling code that calls `kappa`.
kappa(::SqExponentialKernel{<:Euclidean}, d²::Real) = exp(-d² / 2)

Expand Down Expand Up @@ -76,6 +78,8 @@

ExponentialKernel(; metric=Euclidean()) = ExponentialKernel(metric)

# No trainable parameters — TODO confirm against the @noparams macro definition.
@noparams ExponentialKernel

# k(d) = exp(-d)
kappa(::ExponentialKernel, d::Real) = exp(-d)

metric(k::ExponentialKernel) = k.metric
Expand Down Expand Up @@ -121,30 +125,37 @@

[^RW]: C. E. Rasmussen & C. K. I. Williams (2006). Gaussian Processes for Machine Learning.
"""
# γ-exponential kernel: k(d) = exp(-dᵞ) with γ ∈ (0, 2].
# γ is stored as a plain scalar (not a one-element `Vector`); trainability is
# provided via `ParameterHandling.flatten` instead of `@functor`.
struct GammaExponentialKernel{T<:Real,M} <: SimpleKernel
    γ::T
    metric::M

    function GammaExponentialKernel(γ::Real, metric)
        @check_args(GammaExponentialKernel, γ, zero(γ) < γ ≤ 2, "γ ∈ (0, 2]")
        return new{typeof(γ),typeof(metric)}(γ, metric)
    end
end

# Keyword constructor; `gamma` is an ASCII alias for `γ`, default 1.0.
function GammaExponentialKernel(; gamma::Real=1.0, γ::Real=gamma, metric=Euclidean())
    return GammaExponentialKernel(γ, metric)
end

"""
    ParameterHandling.flatten(::Type{T}, k::GammaExponentialKernel)

Flatten a [`GammaExponentialKernel`](@ref) into a single unconstrained parameter.

The constraint `γ ∈ (0, 2]` is handled by storing `logit(γ / 2)` and mapping
back through `2 * logistic(⋅)`. The metric is configuration, not a parameter,
and is captured by the unflatten closure unchanged.
"""
function ParameterHandling.flatten(
    ::Type{T}, k::GammaExponentialKernel{S}
) where {T<:Real,S<:Real}
    # Capture only the metric (not `k` itself) so the closure does not retain
    # the old parameter value.
    metric = k.metric
    function unflatten_to_gammaexponentialkernel(v::Vector{T})
        γ = S(2 * logistic(only(v)))
        return GammaExponentialKernel(; γ=γ, metric=metric)
    end
    return T[logit(k.γ / 2)], unflatten_to_gammaexponentialkernel
end

# k(d) = exp(-dᵞ); γ is a scalar field, so no `only` unwrapping is needed.
kappa(κ::GammaExponentialKernel, d::Real) = exp(-d^κ.γ)

metric(k::GammaExponentialKernel) = k.metric

iskroncompatible(::GammaExponentialKernel) = true

# Compact representation, e.g. `Gamma Exponential Kernel (γ = 1.0, metric = Euclidean(0.0))`.
function Base.show(io::IO, κ::GammaExponentialKernel)
    return print(io, "Gamma Exponential Kernel (γ = ", κ.γ, ", metric = ", κ.metric, ")")
end
2 changes: 2 additions & 0 deletions src/basekernels/exponentiated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ k(x, x') = \\exp(x^\\top x').
"""
struct ExponentiatedKernel <: SimpleKernel end

# No trainable parameters — TODO confirm against the @noparams macro definition.
@noparams ExponentiatedKernel

# k(x, x') = exp(xᵀx'); the base distance is the dot product, not a metric.
kappa(::ExponentiatedKernel, xᵀy::Real) = exp(xᵀy)

metric(::ExponentiatedKernel) = DotProduct()
Expand Down
13 changes: 10 additions & 3 deletions src/basekernels/fbm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,23 @@
```
"""
# Fractional Brownian motion kernel with Hurst index h ∈ [0, 1].
# h is stored as a plain scalar (not a one-element `Vector`); trainability is
# provided via `ParameterHandling.flatten` instead of `@functor`.
struct FBMKernel{T<:Real} <: Kernel
    h::T

    function FBMKernel(h::Real)
        @check_args(FBMKernel, h, zero(h) ≤ h ≤ one(h), "h ∈ [0, 1]")
        return new{typeof(h)}(h)
    end
end

# Keyword convenience constructor; preserves the historical default `h = 0.5`.
FBMKernel(; h::Real=0.5) = FBMKernel(h)

"""
    ParameterHandling.flatten(::Type{T}, k::FBMKernel)

Flatten an [`FBMKernel`](@ref) into a single unconstrained parameter.

The constraint `h ∈ [0, 1]` is handled by storing `logit(h)` and mapping back
through `logistic`. Note the endpoints 0 and 1 map to ±Inf in the flat
representation.
"""
function ParameterHandling.flatten(::Type{T}, k::FBMKernel{S}) where {T<:Real,S<:Real}
    function unflatten_to_fbmkernel(v::Vector{T})
        h = S(logistic(only(v)))
        return FBMKernel(h)
    end
    return T[logit(k.h)], unflatten_to_fbmkernel
end

function (κ::FBMKernel)(x::AbstractVector{<:Real}, y::AbstractVector{<:Real})
modX = sum(abs2, x)
Expand Down
16 changes: 12 additions & 4 deletions src/basekernels/matern.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,19 +23,23 @@

See also: [`Matern12Kernel`](@ref), [`Matern32Kernel`](@ref), [`Matern52Kernel`](@ref)
"""
# Matérn kernel with smoothness parameter ν > 0.
# ν is stored as a plain scalar (not a one-element `Vector`); trainability is
# provided via `ParameterHandling.flatten` instead of `@functor`.
struct MaternKernel{T<:Real,M} <: SimpleKernel
    ν::T
    metric::M

    function MaternKernel(ν::Real, metric)
        @check_args(MaternKernel, ν, ν > zero(ν), "ν > 0")
        return new{typeof(ν),typeof(metric)}(ν, metric)
    end
end

# Keyword constructor; `nu` is an ASCII alias for `ν`, default 1.5.
MaternKernel(; nu::Real=1.5, ν::Real=nu, metric=Euclidean()) = MaternKernel(ν, metric)

"""
    ParameterHandling.flatten(::Type{T}, k::MaternKernel)

Flatten a [`MaternKernel`](@ref) into a single unconstrained parameter.

The positivity constraint `ν > 0` is handled by storing `log(ν)` and mapping
back through `exp`. The metric is captured by the closure unchanged.
"""
function ParameterHandling.flatten(::Type{T}, k::MaternKernel{S}) where {T<:Real,S<:Real}
    metric = k.metric
    # `only` (rather than `first`) for consistency with the other flatten
    # implementations: it errors if the flat vector does not have length 1.
    unflatten_to_maternkernel(v::Vector{T}) = MaternKernel(S(exp(only(v))), metric)
    return T[log(k.ν)], unflatten_to_maternkernel
end

# ν is now a scalar field, so the leftover `only(k.ν)` unwrap (from the old
# `ν::Vector` representation) is unnecessary.
@inline kappa(k::MaternKernel, d::Real) = _matern(k.ν, d)

Expand Down Expand Up @@ -80,6 +84,8 @@

Matern32Kernel(; metric=Euclidean()) = Matern32Kernel(metric)

# No trainable parameters — TODO confirm against the @noparams macro definition.
@noparams Matern32Kernel

# Closed-form Matérn with ν = 3/2: k(d) = (1 + √3 d) exp(-√3 d)
kappa(::Matern32Kernel, d::Real) = (1 + sqrt(3) * d) * exp(-sqrt(3) * d)

metric(k::Matern32Kernel) = k.metric
Expand Down Expand Up @@ -111,6 +117,8 @@

Matern52Kernel(; metric=Euclidean()) = Matern52Kernel(metric)

# No trainable parameters — TODO confirm against the @noparams macro definition.
@noparams Matern52Kernel

# Closed-form Matérn with ν = 5/2: k(d) = (1 + √5 d + 5d²/3) exp(-√5 d)
kappa(::Matern52Kernel, d::Real) = (1 + sqrt(5) * d + 5 * d^2 / 3) * exp(-sqrt(5) * d)

metric(k::Matern52Kernel) = k.metric
Expand Down
2 changes: 2 additions & 0 deletions src/basekernels/nn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ for inputs ``x, x' \\in \\mathbb{R}^d``.[^CW]
"""
struct NeuralNetworkKernel <: Kernel end

# No trainable parameters — TODO confirm against the @noparams macro definition.
@noparams NeuralNetworkKernel

# arcsine kernel: asin(⟨x, y⟩ / √((1 + ‖x‖²)(1 + ‖y‖²)))
function (κ::NeuralNetworkKernel)(x, y)
    return asin(dot(x, y) / sqrt((1 + sum(abs2, x)) * (1 + sum(abs2, y))))
end
Expand Down
15 changes: 12 additions & 3 deletions src/basekernels/periodic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,25 @@
end
end

"""
    PeriodicKernel(dims::Int)

Create a [`PeriodicKernel`](@ref) with parameter `r=ones(Float64, dims)`.
"""
PeriodicKernel(dims::Int) = PeriodicKernel(Float64, dims)

"""
    PeriodicKernel(T, dims::Int=1)

Create a [`PeriodicKernel`](@ref) with parameter `r=ones(T, dims)`.
"""
PeriodicKernel(::Type{T}, dims::Int=1) where {T} = PeriodicKernel(; r=ones(T, dims))

Check warning on line 36 in src/basekernels/periodic.jl

View check run for this annotation

Codecov / codecov/patch

src/basekernels/periodic.jl#L36

Added line #L36 was not covered by tests

"""
    ParameterHandling.flatten(::Type{T}, k::PeriodicKernel)

Flatten a [`PeriodicKernel`](@ref) into one unconstrained parameter per entry of `r`.

Each period parameter is positive, so the flat vector stores element-wise
`log(r)` and the unflatten closure maps back through `exp`, restoring the
original element type `S`.
"""
function ParameterHandling.flatten(::Type{T}, k::PeriodicKernel{S}) where {T<:Real,S}
    vec = T[log(ri) for ri in k.r]
    unflatten_to_periodickernel(v::Vector{T}) = PeriodicKernel(; r=S[exp(vi) for vi in v])
    return vec, unflatten_to_periodickernel
end

# The Sinus pseudo-metric carries the per-dimension period parameters r.
metric(κ::PeriodicKernel) = Sinus(κ.r)

Expand Down
2 changes: 2 additions & 0 deletions src/basekernels/piecewisepolynomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,8 @@ function PiecewisePolynomialKernel(; degree::Int=0, kwargs...)
return PiecewisePolynomialKernel{degree}(; kwargs...)
end

# No trainable parameters — TODO confirm against the @noparams macro definition.
@noparams PiecewisePolynomialKernel

# Polynomial coefficients of the compactly-supported piecewise-polynomial
# kernel, dispatched on the degree via Val; j is the dimension-derived exponent.
piecewise_polynomial_coefficients(::Val{0}, ::Int) = (1,)
piecewise_polynomial_coefficients(::Val{1}, j::Int) = (1, j + 1)
piecewise_polynomial_coefficients(::Val{2}, j::Int) = (1, j + 2, (j^2 + 4 * j)//3 + 1)
Expand Down
Loading
Loading