Skip to content

Commit

Permalink
Extension of #269: Use \circ and compose and deprecate transform (
Browse files Browse the repository at this point in the history
#276)

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: kaandocal <>
  • Loading branch information
devmotion and github-actions[bot] authored Apr 13, 2021
1 parent 55f4909 commit 589daff
Show file tree
Hide file tree
Showing 20 changed files with 122 additions and 101 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
name = "KernelFunctions"
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.9.1"
version = "0.9.2"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
CompositionsBase = "a33af91c-f02d-484b-be07-31d278c5ca2b"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand All @@ -20,6 +21,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
[compat]
ChainRulesCore = "0.9"
Compat = "3.7"
CompositionsBase = "0.1"
Distances = "0.10"
Functors = "0.1"
Requires = "1.0.1"
Expand Down
6 changes: 3 additions & 3 deletions docs/create_kernel_plots.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ n_grid = 101
fill(x₀, n_grid, 1)
xrange = reshape(collect(range(-3, 3; length=n_grid)), :, 1)

k = transform(SqExponentialKernel(), 1.0)
k = SqExponentialKernel() ScaleTransform(1.0)
K1 = kernelmatrix(k, xrange; obsdim=1)
p = heatmap(
K1;
Expand All @@ -35,7 +35,7 @@ p = heatmap(
)
savefig(joinpath(@__DIR__, "src", "assets", "heatmap_matern.png"))

k = transform(PolynomialKernel(; c=0.0, d=2.0), LinearTransform(randn(3, 1)))
k = PolynomialKernel(; c=0.0, d=2.0) LinearTransform(randn(3, 1))
K3 = kernelmatrix(k, xrange; obsdim=1)
p = heatmap(
K3;
Expand All @@ -47,7 +47,7 @@ p = heatmap(
savefig(joinpath(@__DIR__, "src", "assets", "heatmap_poly.png"))

k =
0.5 * SqExponentialKernel() * transform(LinearKernel(), 0.5) +
0.5 * SqExponentialKernel() * (LinearKernel() ScaleTransform(0.5)) +
0.4 * (@kernel Matern32Kernel() FunctionTransform(x -> sin.(x)))
K4 = kernelmatrix(k, xrange; obsdim=1)
p = heatmap(
Expand Down
4 changes: 1 addition & 3 deletions docs/src/kernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -118,9 +118,7 @@ of kernels together.

```@docs
TransformedKernel
transform(::Kernel, ::Transform)
transform(::Kernel, ::Real)
transform(::Kernel, ::AbstractVector)
∘(::Kernel, ::Transform)
ScaledKernel
KernelSum
KernelProduct
Expand Down
4 changes: 2 additions & 2 deletions docs/src/transform.md
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ LowRankTransform(rand(10, 5)) ∘ ScaleTransform(2.0)
A transformation `t` can be applied to a single input `x` with `t(x)` and to multiple inputs
`xs` with `map(t, xs)`.

Kernels can be coupled with input transformations with
[`transform`](@ref). It falls back to creating a [`TransformedKernel`](@ref) but allows more
Kernels can be coupled with input transformations with [``](@ref) or its alias `compose`. It falls
back to creating a [`TransformedKernel`](@ref) but allows more
optimized implementations for specific kernels and transformations.

## List of Input Transforms
Expand Down
5 changes: 4 additions & 1 deletion src/KernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,12 @@ export MOInput
export IndependentMOKernel, LatentFactorMOKernel

# Reexports
export tensor,
export tensor, , compose

using Compat
using ChainRulesCore: ChainRulesCore, Composite, Zero, One, DoesNotExist, NO_FIELDS
using ChainRulesCore: @thunk, InplaceableThunk
using CompositionsBase
using Requires
using Distances, LinearAlgebra
using Functors
Expand Down Expand Up @@ -106,6 +107,8 @@ include("zygoterules.jl")

include("test_utils.jl")

include("deprecations.jl")

function __init__()
@require Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e" begin
include(joinpath("matrix", "kernelkroneckermat.jl"))
Expand Down
42 changes: 15 additions & 27 deletions src/basekernels/gabor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,11 @@ k(x, x'; l, p) = \\exp\\bigg(- \\cos\\bigg(\\pi\\sum_{i=1}^d \\frac{x_i - x'_i}{
"""
struct GaborKernel{K<:Kernel} <: Kernel
kernel::K

function GaborKernel(; ell=nothing, p=nothing)
k = _gabor(; ell=ell, p=p)
ell_transform = _lengthscale_transform(ell)
p_transform = _lengthscale_transform(p)
k = (SqExponentialKernel() ell_transform) * (CosineKernel() p_transform)
return new{typeof(k)}(k)
end
end
Expand All @@ -24,38 +27,23 @@ end

::GaborKernel)(x, y) = κ.kernel(x, y)

function _gabor(; ell=nothing, p=nothing)
if ell === nothing
if p === nothing
return SqExponentialKernel() * CosineKernel()
else
return SqExponentialKernel() * transform(CosineKernel(), 1 ./ p)
end
elseif p === nothing
return transform(SqExponentialKernel(), 1 ./ ell) * CosineKernel()
else
return transform(SqExponentialKernel(), 1 ./ ell) *
transform(CosineKernel(), 1 ./ p)
end
end
_lengthscale_transform(::Nothing) = IdentityTransform()
_lengthscale_transform(x::Real) = ScaleTransform(inv(x))
_lengthscale_transform(x::AbstractVector) = ARDTransform(map(inv, x))

_lengthscale(::IdentityTransform) = 1
_lengthscale(t::ScaleTransform) = inv(first(t.s))
_lengthscale(t::ARDTransform) = map(inv, t.v)

function Base.getproperty(k::GaborKernel, v::Symbol)
if v == :kernel
return getfield(k, v)
elseif v == :ell
kernel1 = k.kernel.kernels[1]
if kernel1 isa TransformedKernel
return 1 ./ kernel1.transform.s[1]
else
return 1.0
end
ell_transform = k.kernel.kernels[1].transform
return _lengthscale(ell_transform)
elseif v == :p
kernel2 = k.kernel.kernels[2]
if kernel2 isa TransformedKernel
return 1 ./ kernel2.transform.s[1]
else
return 1.0
end
p_transform = k.kernel.kernels[2].transform
return _lengthscale(p_transform)
else
error("Invalid Property")
end
Expand Down
4 changes: 4 additions & 0 deletions src/deprecations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
@deprecate transform(k::Kernel, t::Transform) k t
@deprecate transform(k::TransformedKernel, t::Transform) k.kernel t k.transform
@deprecate transform(k::Kernel, ρ::Real) k ScaleTransform(ρ)
@deprecate transform(k::Kernel, ρ::AbstractVector) k ARDTransform(ρ)
65 changes: 33 additions & 32 deletions src/kernels/transformedkernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,11 @@
Kernel derived from `k` for which inputs are transformed via a [`Transform`](@ref) `t`.
It is preferred to create kernels with input transformations with [`transform`](@ref)
instead of `TransformedKernel` directly since [`transform`](@ref) allows optimized
implementations for specific kernels and transformations.
The preferred way to create kernels with input transformations is to use the composition
operator [`∘`](@ref) or its alias `compose` instead of `TransformedKernel` directly since
this allows optimized implementations for specific kernels and transformations.
# Definition
For inputs ``x, x'``, the transformed kernel ``\\widetilde{k}`` derived from kernel ``k`` by
input transformation ``t`` is defined as
```math
\\widetilde{k}(x, x'; k, t) = k\\big(t(x), t(x')\\big).
```
See also: [`∘`](@ref)
"""
struct TransformedKernel{Tk<:Kernel,Tr<:Transform} <: Kernel
kernel::Tk
Expand Down Expand Up @@ -42,30 +36,37 @@ end
_scale(t::ScaleTransform, metric, x, y) = evaluate(metric, t(x), t(y))

"""
transform(k::Kernel, t::Transform)
kernel ∘ transform
∘(kernel, transform)
compose(kernel, transform)
Create a [`TransformedKernel`](@ref) for kernel `k` and transform `t`.
"""
transform(k::Kernel, t::Transform) = TransformedKernel(k, t)
function transform(k::TransformedKernel, t::Transform)
return TransformedKernel(k.kernel, t k.transform)
end
Compose a `kernel` with a transformation `transform` of its inputs.
"""
transform(k::Kernel, ρ::Real)
The prefix forms support chains of multiple transformations:
`∘(kernel, transform1, transform2) = kernel ∘ transform1 ∘ transform2`.
Create a [`TransformedKernel`](@ref) for kernel `k` and inverse lengthscale `ρ`.
"""
transform(k::Kernel, ρ::Real) = transform(k, ScaleTransform(ρ))
# Definition
"""
transform(k::Kernel, ρ::AbstractVector)
For inputs ``x, x'``, the transformed kernel ``\\widetilde{k}`` derived from kernel ``k`` by
input transformation ``t`` is defined as
```math
\\widetilde{k}(x, x'; k, t) = k\\big(t(x), t(x')\\big).
```
Create a [`TransformedKernel`](@ref) for kernel `k` and inverse lengthscales `ρ`.
"""
transform(k::Kernel, ρ::AbstractVector) = transform(k, ARDTransform(ρ))
# Examples
```jldoctest
julia> (SqExponentialKernel() ∘ ScaleTransform(0.5))(0, 2) == exp(-0.5)
true
kernel(κ) = κ.kernel
julia> ∘(ExponentialKernel(), ScaleTransform(2), ScaleTransform(0.5))(1, 2) == exp(-1)
true
```
See also: [`TransformedKernel`](@ref)
"""
Base.:(k::Kernel, t::Transform) = TransformedKernel(k, t)
Base.:(k::TransformedKernel, t::Transform) = TransformedKernel(k.kernel, k.transform t)

Base.show(io::IO, κ::TransformedKernel) = printshifted(io, κ, 0)

Expand All @@ -87,13 +88,13 @@ function kernelmatrix_diag!(
end

function kernelmatrix!(K::AbstractMatrix, κ::TransformedKernel, x::AbstractVector)
return kernelmatrix!(K, kernel(κ), _map.transform, x))
return kernelmatrix!(K, κ.kernel, _map.transform, x))
end

function kernelmatrix!(
K::AbstractMatrix, κ::TransformedKernel, x::AbstractVector, y::AbstractVector
)
return kernelmatrix!(K, kernel(κ), _map.transform, x), _map.transform, y))
return kernelmatrix!(K, κ.kernel, _map.transform, x), _map.transform, y))
end

function kernelmatrix_diag::TransformedKernel, x::AbstractVector)
Expand All @@ -105,9 +106,9 @@ function kernelmatrix_diag(κ::TransformedKernel, x::AbstractVector, y::Abstract
end

function kernelmatrix::TransformedKernel, x::AbstractVector)
return kernelmatrix(kernel(κ), _map.transform, x))
return kernelmatrix(κ.kernel, _map.transform, x))
end

function kernelmatrix::TransformedKernel, x::AbstractVector, y::AbstractVector)
return kernelmatrix(kernel(κ), _map.transform, x), _map.transform, y))
return kernelmatrix(κ.kernel, _map.transform, x), _map.transform, y))
end
6 changes: 3 additions & 3 deletions test/basekernels/gabor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,14 @@
k_manual = exp(-sqeuclidean(v1, v2) / (2 * k.ell^2)) * cospi(euclidean(v1, v2) / k.p)
@test k(v1, v2) k_manual atol = 1e-5

lhs_manual = transform(SqExponentialKernel(), 1 / k.ell)(v1, v2)
rhs_manual = transform(CosineKernel(), 1 / k.p)(v1, v2)
lhs_manual = (SqExponentialKernel() ScaleTransform(1 / k.ell))(v1, v2)
rhs_manual = (CosineKernel() ScaleTransform(1 / k.p))(v1, v2)
@test k(v1, v2) lhs_manual * rhs_manual atol = 1e-5

k = GaborKernel()
@test k.ell 1.0 atol = 1e-5
@test k.p 1.0 atol = 1e-5
@test repr(k) == "Gabor Kernel (ell = 1.0, p = 1.0)"
@test repr(k) == "Gabor Kernel (ell = 1, p = 1)"

test_interface(k, Vector{Float64})

Expand Down
20 changes: 20 additions & 0 deletions test/deprecations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
@testset "deprecations.jl" begin
p = rand()
v = rand(3)
M = rand(3, 3)
v1 = rand(3)
v2 = rand(3)
kernel = SqExponentialKernel()

k1 = @test_deprecated transform(kernel, LinearTransform(M))
@test k1(v1, v2) == (kernel LinearTransform(M))(v1, v2)

k2 = @test_deprecated transform(kernel ScaleTransform(p), ARDTransform(v))
@test k2(v1, v2) == (kernel ARDTransform(v) ScaleTransform(p))(v1, v2)

k3 = @test_deprecated transform(kernel, p)
@test k3(v1, v2) == (kernel ScaleTransform(p))(v1, v2)

k4 = @test_deprecated transform(kernel, v)
@test k4(v1, v2) == (kernel ARDTransform(v))(v1, v2)
end
35 changes: 19 additions & 16 deletions test/kernels/transformedkernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,26 +5,27 @@
v2 = rand(rng, 3)

s = rand(rng)
s2 = rand(rng)
v = rand(rng, 3)
P = rand(rng, 3, 2)
k = SqExponentialKernel()
kt = TransformedKernel(k, ScaleTransform(s))
ktard = TransformedKernel(k, ARDTransform(v))
@test kt(v1, v2) == transform(k, ScaleTransform(s))(v1, v2)
@test kt(v1, v2) == transform(k, s)(v1, v2)
@test kt(v1, v2) == (k ScaleTransform(s))(v1, v2)
@test kt(v1, v2) k(s * v1, s * v2) atol = 1e-5
@test ktard(v1, v2) transform(k, ARDTransform(v))(v1, v2) atol = 1e-5
@test ktard(v1, v2) == transform(k, v)(v1, v2)
@test ktard(v1, v2) == (k ARDTransform(v))(v1, v2)
@test ktard(v1, v2) == k(v .* v1, v .* v2)
@test transform(kt, s2)(v1, v2) kt(s2 * v1, s2 * v2)
@test KernelFunctions.kernel(kt) == k
@test (k LinearTransform(P') ScaleTransform(s))(v1, v2) ==
((k LinearTransform(P')) ScaleTransform(s))(v1, v2) ==
(k (LinearTransform(P') ScaleTransform(s)))(v1, v2)

@test repr(kt) == repr(k) * "\n\t- " * repr(ScaleTransform(s))

TestUtils.test_interface(k, Float64)
test_ADs(x -> transform(SqExponentialKernel(), x[1]), rand(1))# ADs = [:ForwardDiff, :ReverseDiff])
test_ADs(x -> SqExponentialKernel() ScaleTransform(x[1]), rand(1))

# Test implicit gradients
@testset "Implicit gradients" begin
k = transform(SqExponentialKernel(), 2.0)
k = SqExponentialKernel() ScaleTransform(2.0)
ps = Flux.params(k)
X = rand(10, 1)
x = vec(X)
Expand All @@ -46,12 +47,14 @@
@test g1[first(ps)] g3[first(ps)]
end

P = rand(3, 2)
c = Chain(Dense(3, 2))
@testset "Parameters" begin
k = ConstantKernel(; c=rand(rng))
c = Chain(Dense(3, 2))

test_params(transform(k, s), (k, [s]))
test_params(transform(k, v), (k, v))
test_params(transform(k, LinearTransform(P)), (k, P))
test_params(transform(k, LinearTransform(P) ScaleTransform(s)), (k, [s], P))
test_params(transform(k, FunctionTransform(c)), (k, c))
test_params(k ScaleTransform(s), (k, [s]))
test_params(k ARDTransform(v), (k, v))
test_params(k LinearTransform(P), (k, P))
test_params(k LinearTransform(P) ScaleTransform(s), (k, [s], P))
test_params(k FunctionTransform(c), (k, c))
end
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ include("test_utils.jl")
include("chainrules.jl")
include("zygoterules.jl")

include("deprecations.jl")

@testset "doctests" begin
DocMeta.setdocmeta!(
KernelFunctions,
Expand Down
2 changes: 1 addition & 1 deletion test/transform/ardtransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,5 @@
@test_throws DimensionMismatch map(t, ColVecs(randn(rng, D + 1, 3)))

@test repr(t) == "ARD Transform (dims: $D)"
test_ADs(x -> transform(SEKernel(), exp.(x)), randn(rng, 3))
test_ADs(x -> SEKernel() ARDTransform(exp.(x)), randn(rng, 3))
end
2 changes: 1 addition & 1 deletion test/transform/chaintransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
# Verify printing works as expected.
@test repr(tp tf) == "Chain of 2 transforms:\n\t - $(tf) |> $(tp)"
test_ADs(
x -> transform(SEKernel(), ScaleTransform(exp(x[1])) ARDTransform(exp.(x[2:4]))),
x -> SEKernel() (ScaleTransform(exp(x[1])) ARDTransform(exp.(x[2:4]))),
randn(rng, 4),
)
end
2 changes: 1 addition & 1 deletion test/transform/functiontransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,5 @@

@test repr(FunctionTransform(sin)) == "Function Transform: $(sin)"
f(a, x) = sin.(a .* x)
test_ADs(x -> transform(SEKernel(), FunctionTransform(y -> f(x, y))), randn(rng, 3))
test_ADs(x -> SEKernel() FunctionTransform(y -> f(x, y)), randn(rng, 3))
end
Loading

2 comments on commit 589daff

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/34165

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.9.2 -m "<description of version>" 589daffaf2f76b519847ca36d344f9e029c2ab93
git push origin v0.9.2

Please sign in to comment.