Skip to content

Commit

Permalink
Add gaborkernel and simplify kernels with IdentityTransforms (#285)
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored May 7, 2021
1 parent 3d8b245 commit a99a3e7
Show file tree
Hide file tree
Showing 7 changed files with 92 additions and 22 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "KernelFunctions"
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.9.6"
version = "0.9.7"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
1 change: 1 addition & 0 deletions docs/src/kernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ FBMKernel
### Gabor Kernel

```@docs
gaborkernel
GaborKernel
```

Expand Down
3 changes: 2 additions & 1 deletion src/KernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ export Transform,

export NystromFact, nystrom

export gaborkernel
export spectral_mixture_kernel, spectral_mixture_product_kernel

export ColVecs, RowVecs
Expand Down Expand Up @@ -73,6 +74,7 @@ include(joinpath("transform", "functiontransform.jl"))
include(joinpath("transform", "selecttransform.jl"))
include(joinpath("transform", "chaintransform.jl"))
include(joinpath("transform", "periodic_transform.jl"))
include(joinpath("kernels", "transformedkernel.jl"))

include(joinpath("basekernels", "constant.jl"))
include(joinpath("basekernels", "cosine.jl"))
Expand All @@ -89,7 +91,6 @@ include(joinpath("basekernels", "rational.jl"))
include(joinpath("basekernels", "sm.jl"))
include(joinpath("basekernels", "wiener.jl"))

include(joinpath("kernels", "transformedkernel.jl"))
include(joinpath("kernels", "scaledkernel.jl"))
include(joinpath("kernels", "normalizedkernel.jl"))
include(joinpath("matrix", "kernelmatrix.jl"))
Expand Down
45 changes: 40 additions & 5 deletions src/basekernels/gabor.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,30 @@
"""
gaborkernel(;
sqexponential_transform=IdentityTransform(), cosine_tranform=IdentityTransform()
)
Construct a Gabor kernel with transformations `sqexponential_transform` and
`cosine_transform` of the inputs of the underlying squared exponential and cosine kernel,
respectively.
# Definition
For inputs ``x, x' \\in \\mathbb{R}^d``, the Gabor kernel with transformations ``f``
and ``g`` of the inputs to the squared exponential and cosine kernel, respectively,
is defined as
```math
k(x, x'; f, g) = \\exp\\bigg(- \\frac{\\| f(x) - f(x')\\|_2^2}{2}\\bigg)
\\cos\\big(\\pi \\|g(x) - g(x')\\|_2 \\big).
```
"""
function gaborkernel(;
sqexponential_transform=IdentityTransform(), cosine_transform=IdentityTransform()
)
return (SqExponentialKernel() sqexponential_transform) *
(CosineKernel() cosine_transform)
end

# everything below will be removed
"""
GaborKernel(; ell::Real=1.0, p::Real=1.0)
Expand All @@ -11,11 +38,20 @@ and period ``p_i > 0`` is defined as
k(x, x'; l, p) = \\exp\\bigg(- \\sum_{i=1}^d \\frac{(x_i - x'_i)^2}{2l_i^2}\\bigg)
\\cos\\bigg(\\pi \\bigg(\\sum_{i=1}^d \\frac{(x_i - x'_i)^2}{p_i^2} \\bigg)^{1/2}\\bigg).
```
!!! note
`GaborKernel` is deprecated and will be removed. Gabor kernels should be
constructed with [`gaborkernel`](@ref) instead.
"""
struct GaborKernel{K<:Kernel} <: Kernel
kernel::K

function GaborKernel(; ell=nothing, p=nothing)
Base.depwarn(
"`GaborKernel` is deprecated and will be removed. Gabor kernels should be " *
"constructed with `gaborkernel` instead.",
:GaborKernel,
)
ell_transform = _lengthscale_transform(ell)
p_transform = _lengthscale_transform(p)
k = (SqExponentialKernel() ell_transform) * (CosineKernel() p_transform)
Expand All @@ -31,19 +67,18 @@ _lengthscale_transform(::Nothing) = IdentityTransform()
_lengthscale_transform(x::Real) = ScaleTransform(inv(x))
_lengthscale_transform(x::AbstractVector) = ARDTransform(map(inv, x))

_lengthscale(::IdentityTransform) = 1
_lengthscale(x) = 1
_lengthscale(k::TransformedKernel) = _lengthscale(k.transform)
_lengthscale(t::ScaleTransform) = inv(first(t.s))
_lengthscale(t::ARDTransform) = map(inv, t.v)

function Base.getproperty(k::GaborKernel, v::Symbol)
if v == :kernel
return getfield(k, v)
elseif v == :ell
ell_transform = k.kernel.kernels[1].transform
return _lengthscale(ell_transform)
return _lengthscale(k.kernel.kernels[1])
elseif v == :p
p_transform = k.kernel.kernels[2].transform
return _lengthscale(p_transform)
return _lengthscale(k.kernel.kernels[2])
else
error("Invalid Property")
end
Expand Down
4 changes: 4 additions & 0 deletions src/kernels/transformedkernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@ See also: [`TransformedKernel`](@ref)
Base.:(k::Kernel, t::Transform) = TransformedKernel(k, t)
Base.:(k::TransformedKernel, t::Transform) = TransformedKernel(k.kernel, k.transform t)

# Simplify kernels with identity transformation of the inputs
Base.:(k::Kernel, ::IdentityTransform) = k
Base.:(k::TransformedKernel, ::IdentityTransform) = k

Base.show(io::IO, κ::TransformedKernel) = printshifted(io, κ, 0)

function printshifted(io::IO, κ::TransformedKernel, shift::Int)
Expand Down
55 changes: 40 additions & 15 deletions test/basekernels/gabor.jl
Original file line number Diff line number Diff line change
@@ -1,25 +1,50 @@
@testset "Gabor" begin
v1 = rand(3)
v2 = rand(3)
ell = abs(rand())
p = abs(rand())
k = GaborKernel(; ell=ell, p=p)
@test k.ell ell atol = 1e-5
@test k.p p atol = 1e-5
ell = rand()
p = rand()
k = gaborkernel(;
sqexponential_transform=ScaleTransform(inv(ell)),
cosine_transform=ScaleTransform(inv(p)),
)
@test k isa KernelProduct{
<:Tuple{
TransformedKernel{SqExponentialKernel,<:ScaleTransform},
TransformedKernel{CosineKernel,<:ScaleTransform},
},
}
@test k.kernels[1].transform.s[1] == inv(ell)
@test k.kernels[2].transform.s[1] == inv(p)

k_manual = exp(-sqeuclidean(v1, v2) / (2 * k.ell^2)) * cospi(euclidean(v1, v2) / k.p)
@test k(v1, v2) k_manual atol = 1e-5
k_manual = exp(-sqeuclidean(v1, v2) / (2 * ell^2)) * cospi(euclidean(v1, v2) / p)
@test k_manual k(v1, v2) atol = 1e-5

lhs_manual = (SqExponentialKernel() ScaleTransform(1 / k.ell))(v1, v2)
rhs_manual = (CosineKernel() ScaleTransform(1 / k.p))(v1, v2)
@test k(v1, v2) lhs_manual * rhs_manual atol = 1e-5
lhs_manual = (SqExponentialKernel() ScaleTransform(1 / ell))(v1, v2)
rhs_manual = (CosineKernel() ScaleTransform(1 / p))(v1, v2)
@test lhs_manual * rhs_manual k(v1, v2) atol = 1e-5

k = GaborKernel()
@test k.ell 1.0 atol = 1e-5
@test k.p 1.0 atol = 1e-5
@test repr(k) == "Gabor Kernel (ell = 1, p = 1)"
@test gaborkernel() isa KernelProduct{Tuple{SqExponentialKernel,CosineKernel}}

test_interface(k, Vector{Float64})
test_ADs(
x -> gaborkernel(;
sqexponential_transform=ScaleTransform(x[1]),
cosine_transform=ScaleTransform(x[2]),
),
[ell, p],
)

# deprecated `GaborKernel`
k2 = @test_deprecated GaborKernel(; ell=ell, p=p)
@test k2.ell ell atol = 1e-5
@test k2.p p atol = 1e-5
@test k2(v1, v2) k(v1, v2)

k3 = @test_deprecated GaborKernel()
@test k3.ell 1.0 atol = 1e-5
@test k3.p 1.0 atol = 1e-5
@test repr(k3) == "Gabor Kernel (ell = 1, p = 1)"

test_interface(k3, Vector{Float64})

test_ADs(x -> GaborKernel(; ell=x[1], p=x[2]), [ell, p]; ADs=[:Zygote])

Expand Down
4 changes: 4 additions & 0 deletions test/kernels/transformedkernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,12 @@
v = rand(rng, 3)
P = rand(rng, 3, 2)
k = SqExponentialKernel()
@test k IdentityTransform() === k

kt = TransformedKernel(k, ScaleTransform(s))
ktard = TransformedKernel(k, ARDTransform(v))
@test kt IdentityTransform() === kt
@test ktard IdentityTransform() === ktard
@test kt(v1, v2) == (k ScaleTransform(s))(v1, v2)
@test kt(v1, v2) k(s * v1, s * v2) atol = 1e-5
@test ktard(v1, v2) == (k ARDTransform(v))(v1, v2)
Expand Down

2 comments on commit a99a3e7

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/36274

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.9.7 -m "<description of version>" a99a3e7b1a5c3cbfb328e0e98e8b94ffa9331c6f
git push origin v0.9.7

Please sign in to comment.