Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate use of Mahalanobis distance #225

Merged
merged 12 commits into from
Jan 15, 2021
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ jobs:
- name: Format check
run: |
CHANGED="$(git diff --name-only)"
if [ ! -z $CHANGED ]; then
if [ ! -z "$CHANGED" ]; then
>&2 echo "Some files have not been formatted !!!"
echo "$CHANGED"
exit 1
Expand Down
16 changes: 9 additions & 7 deletions docs/src/kernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -153,17 +153,19 @@ where $r$ has the same dimension as $x$ and $r_i > 0$.

## Piecewise Polynomial Kernel

The [`PiecewisePolynomialKernel`](@ref) is defined for $x, x'\in \mathbb{R}^D$, a positive-definite matrix $P \in \mathbb{R}^{D \times D}$, and $V \in \{0,1,2,3\}$ as
The [`PiecewisePolynomialKernel`](@ref) is defined for $x, x' \in \mathbb{R}^d$ and
devmotion marked this conversation as resolved.
Show resolved Hide resolved
$v \in \{0,1,2,3\}$ as
```math
k(x,x'; P, V) = \max(1 - \sqrt{x^\top P x'}, 0)^{j + V} f_V(\sqrt{x^\top P x'}, j),
devmotion marked this conversation as resolved.
Show resolved Hide resolved
k(x, x'; v) = \max(1 - \|x - x'\|, 0)^{j + v} f_v(\|x - x'\|, j),
```
where $j = \lfloor \frac{D}{2}\rfloor + V + 1$, and $f_V$ are polynomials defined as follows:
where $j = \lfloor \frac{d}{2}\rfloor + v + 1$, and $f_v$ are polynomials defined as
follows:
```math
\begin{aligned}
f_0(r, j) &= 1, \\
f_1(r, j) &= 1 + (j + 1) r, \\
f_2(r, j) &= 1 + (j + 2) r + ((j^2 + 4j + 3) / 3) r^2, \\
f_3(r, j) &= 1 + (j + 3) r + ((6 j^2 + 36j + 45) / 15) r^2 + ((j^3 + 9 j^2 + 23j + 15) / 15) r^3.
f_0(r, j) &= 1, \\
f_1(r, j) &= 1 + (j + 1) r, \\
f_2(r, j) &= 1 + (j + 2) r + ((j^2 + 4j + 3) / 3) r^2, \\
f_3(r, j) &= 1 + (j + 3) r + ((6 j^2 + 36j + 45) / 15) r^2 + ((j^3 + 9 j^2 + 23j + 15) / 15) r^3.
\end{aligned}
```

Expand Down
5 changes: 3 additions & 2 deletions src/KernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ export FBMKernel
export MaternKernel, Matern12Kernel, Matern32Kernel, Matern52Kernel
export LinearKernel, PolynomialKernel
export RationalQuadraticKernel, GammaRationalQuadraticKernel
export MahalanobisKernel, GaborKernel, PiecewisePolynomialKernel
export GaborKernel, PiecewisePolynomialKernel
export PeriodicKernel, NeuralNetworkKernel
export KernelSum, KernelProduct
export TransformedKernel, ScaledKernel
Expand Down Expand Up @@ -90,7 +90,6 @@ include(joinpath("basekernels", "exponential.jl"))
include(joinpath("basekernels", "exponentiated.jl"))
include(joinpath("basekernels", "fbm.jl"))
include(joinpath("basekernels", "gabor.jl"))
include(joinpath("basekernels", "maha.jl"))
include(joinpath("basekernels", "matern.jl"))
include(joinpath("basekernels", "nn.jl"))
include(joinpath("basekernels", "periodic.jl"))
Expand Down Expand Up @@ -118,6 +117,8 @@ include("zygote_adjoints.jl")

include("test_utils.jl")

include("deprecated.jl")

function __init__()
@require Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e" begin
include(joinpath("matrix", "kernelkroneckermat.jl"))
Expand Down
75 changes: 42 additions & 33 deletions src/basekernels/piecewisepolynomial.jl
Original file line number Diff line number Diff line change
@@ -1,56 +1,65 @@
"""
devmotion marked this conversation as resolved.
Show resolved Hide resolved
PiecewisePolynomialKernel{V}(maha::AbstractMatrix)
PiecewisePolynomialKernel(; v::Int=0, d::Int)
PiecewisePolynomialKernel{v}(d::Int)

Piecewise Polynomial covariance function with compact support, V = 0,1,2,3.
The kernel functions are 2V times continuously differentiable and the corresponding
processes are hence V times mean-square differentiable. The kernel function is:
Piecewise polynomial kernel with compact support.

The kernel is defined for ``x, x' \\in \\mathbb{R}^d`` and ``v \\in \\{0,1,2,3\\}`` as
```math
k(x, x'; v) = \\max(1 - \\|x - x'\\|, 0)^{j + v} f_v(\\|x - x'\\|, j),
```
where ``j = \\lfloor \\frac{d}{2}\\rfloor + v + 1``, and ``f_v`` are polynomials defined as
follows:
```math
κ(x, y) = max(1 - r, 0)^(j + V) * f(r, j) with j = floor(D / 2) + V + 1
\\begin{aligned}
f_0(r, j) &= 1, \\\\
f_1(r, j) &= 1 + (j + 1) r, \\\\
f_2(r, j) &= 1 + (j + 2) r + ((j^2 + 4j + 3) / 3) r^2, \\\\
f_3(r, j) &= 1 + (j + 3) r + ((6 j^2 + 36j + 45) / 15) r^2 + ((j^3 + 9 j^2 + 23j + 15) / 15) r^3.
\\end{aligned}
```
where `r` is the Mahalanobis distance mahalanobis(x,y) with `maha` as the metric.

The kernel is ``2v`` times continuously differentiable and the corresponding Gaussian
process is hence ``v`` times mean-square differentiable.
"""
struct PiecewisePolynomialKernel{V,A<:AbstractMatrix{<:Real}} <: SimpleKernel
maha::A
struct PiecewisePolynomialKernel{V} <: SimpleKernel
j::Int
function PiecewisePolynomialKernel{V}(maha::AbstractMatrix{<:Real}) where {V}

function PiecewisePolynomialKernel{V}(d::Int) where {V}
V in (0, 1, 2, 3) || error("Invalid parameter V=$(V). Should be 0, 1, 2 or 3.")
LinearAlgebra.checksquare(maha)
j = div(size(maha, 1), 2) + V + 1
return new{V,typeof(maha)}(maha, j)
d > 0 || error("number of dimensions has to be positive")
j = div(d, 2) + V + 1
return new{V}(j)
end
end

function PiecewisePolynomialKernel(; v::Integer=0, maha::AbstractMatrix{<:Real})
return PiecewisePolynomialKernel{v}(maha)
end

# Have to reconstruct the type parameter
# See also https://github.com/FluxML/Functors.jl/issues/3#issuecomment-626747663
function Functors.functor(::Type{<:PiecewisePolynomialKernel{V}}, x) where {V}
function reconstruct_kernel(xs)
return PiecewisePolynomialKernel{V}(xs.maha)
# TODO: remove `maha` keyword argument in next breaking release
function PiecewisePolynomialKernel(; v::Int=0, maha=nothing, d::Int=-1)
if maha !== nothing
Base.depwarn("keyword argument `maha` is deprecated", :PiecewisePolynomialKernel)
d = size(maha, 1)
return transform(PiecewisePolynomialKernel{v}(d), LinearTransform(cholesky(maha).U))
theogf marked this conversation as resolved.
Show resolved Hide resolved
else
return PiecewisePolynomialKernel{v}(d)
end
return (maha=x.maha,), reconstruct_kernel
end

_f(κ::PiecewisePolynomialKernel{0}, r, j) = 1
_f(κ::PiecewisePolynomialKernel{1}, r, j) = 1 + (j + 1) * r
_f(κ::PiecewisePolynomialKernel{2}, r, j) = 1 + (j + 2) * r + (j^2 + 4 * j + 3) / 3 * r .^ 2
function _f(κ::PiecewisePolynomialKernel{3}, r, j)
_f(::PiecewisePolynomialKernel{1}, r, j) = 1 + (j + 1) * r
_f(::PiecewisePolynomialKernel{2}, r, j) = 1 + (j + 2) * r + (j^2 + 4 * j + 3) / 3 * r^2
function _f(::PiecewisePolynomialKernel{3}, r, j)
return 1 +
(j + 3) * r +
(6 * j^2 + 36j + 45) / 15 * r .^ 2 +
(j^3 + 9 * j^2 + 23j + 15) / 15 * r .^ 3
(j + 3) * r +
(6 * j^2 + 36j + 45) / 15 * r ^ 2 +
(j^3 + 9 * j^2 + 23j + 15) / 15 * r ^ 3
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

JuliaFormatter wants this to be formatted

    return 1 +
           (j + 3) * r +
           (6 * j^2 + 36j + 45) / 15 * r ^ 2 +
           (j^3 + 9 * j^2 + 23j + 15) / 15 * r ^ 3

This feels weird since usually we just increase the indentation level but do not align operators or lines. Maybe I missed it but I couldn't find any information about similar cases in BlueStyle - is this a bug in JuliaFormatter?

end

kappa(κ::PiecewisePolynomialKernel{0}, r) = max(1 - r, 0)^κ.j
function kappa(κ::PiecewisePolynomialKernel{V}, r) where {V}
return max(1 - r, 0)^(κ.j + V) * _f(κ, r, κ.j)
end

metric(κ::PiecewisePolynomialKernel) = Mahalanobis(κ.maha)
metric(::PiecewisePolynomialKernel) = Euclidean()

function Base.show(io::IO, κ::PiecewisePolynomialKernel{V}) where {V}
return print(
io, "Piecewise Polynomial Kernel (v = ", V, ", size(maha) = ", size(κ.maha), ")"
)
return print(io, "Piecewise Polynomial Kernel (v = ", V, ", ⌊d/2⌋ = ", κ.j - 1 - V, ")")
end
9 changes: 9 additions & 0 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# TODO: remove tests when removed
@deprecate MahalanobisKernel(; P::AbstractMatrix{<:Real}) transform(
SqExponentialKernel(), LinearTransform(sqrt(2) .* cholesky(P).U)
)

# TODO: remove keyword argument `maha` when removed
@deprecate PiecewisePolynomialKernel{V}(A::AbstractMatrix{<:Real}) where {V} transform(
PiecewisePolynomialKernel{V}(size(A, 1)), LinearTransform(cholesky(A).U)
)
38 changes: 3 additions & 35 deletions test/basekernels/maha.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
@testset "maha" begin
rng = MersenneTwister(123456)
x = 2 * rand(rng)
D_in = 3
v1 = rand(rng, D_in)
v2 = rand(rng, D_in)
Expand All @@ -9,40 +8,10 @@
P = Matrix(Cholesky(U, 'U', 0))
@assert isposdef(P)

k = MahalanobisKernel(; P=P)

@test kappa(k, x) == exp(-x)
k = @test_deprecated MahalanobisKernel(; P=P)
@test k isa TransformedKernel{SqExponentialKernel,<:LinearTransform}
@test k.transform.A ≈ sqrt(2) .* U
@test k(v1, v2) ≈ exp(-sqmahalanobis(v1, v2, P))
@test kappa(ExponentialKernel(), x) == kappa(k, x)
@test repr(k) == "Mahalanobis Kernel (size(P) = $(size(P)))"

M1, M2 = rand(rng, 3, 2), rand(rng, 3, 2)

function FiniteDifferences.to_vec(dist::SqMahalanobis)
return vec(dist.qmat), x -> SqMahalanobis(reshape(x, size(dist.qmat)...))
end
a = rand()

function test_mahakernel(U::UpperTriangular, v1::AbstractVector, v2::AbstractVector)
return MahalanobisKernel(; P=Array(U' * U))(v1, v2)
end

@test all(
FiniteDifferences.j′vp(FDM, test_mahakernel, a, U, v1, v2)[1] .≈
UpperTriangular(Zygote.pullback(test_mahakernel, U, v1, v2)[2](a)[1]),
)

function test_sqmaha(U::UpperTriangular, v1::AbstractVector, v2::AbstractVector)
return SqMahalanobis(Array(U' * U))(v1, v2)
end

@test all(
FiniteDifferences.j′vp(FDM, test_sqmaha, a, U, v1, v2)[1] .≈
UpperTriangular(Zygote.pullback(test_sqmaha, U, v1, v2)[2](a)[1]),
)

# test_ADs(U -> MahalanobisKernel(P=Array(U' * U)), U, ADs=[:Zygote])
@test_broken "Nothing passes (problem with Mahalanobis distance in Distances)"

# Standardised tests.
@testset "ColVecs" begin
Expand All @@ -57,5 +26,4 @@
x2 = RowVecs(randn(2, D_in))
TestUtils.test_interface(k, x0, x1, x2)
end
test_params(k, (P,))
end
19 changes: 12 additions & 7 deletions test/basekernels/piecewisepolynomial.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,26 @@
v2 = rand(D)
maha = Matrix{Float64}(I, D, D)
v = 3
k = PiecewisePolynomialKernel{v}(maha)

k2 = PiecewisePolynomialKernel(; v=v, maha=maha)
k = PiecewisePolynomialKernel(; v=v, d=D)
k2 = PiecewisePolynomialKernel{v}(D)
k3 = @test_deprecated PiecewisePolynomialKernel{v}(maha)
k4 = @test_deprecated PiecewisePolynomialKernel(; v=v, maha=maha)

@test k2(v1, v2) ≈ k(v1, v2) atol = 1e-5
@test k2(v1, v2) == k(v1, v2)
@test k3(v1, v2) ≈ k(v1, v2)
@test k4(v1, v2) ≈ k(v1, v2)

@test_throws ErrorException PiecewisePolynomialKernel{4}(maha)
@test_throws ErrorException PiecewisePolynomialKernel{4}(D)
@test_throws ErrorException PiecewisePolynomialKernel{v}(-1)

@test repr(k) == "Piecewise Polynomial Kernel (v = $(v), size(maha) = $(size(maha)))"
@test repr(k) == "Piecewise Polynomial Kernel (v = $(v), ⌊d/2⌋ = $(div(D, 2)))"

# Standardised tests.
TestUtils.test_interface(k, ColVecs{Float64}; dim_in=2)
TestUtils.test_interface(k, RowVecs{Float64}; dim_in=2)
# test_ADs(maha-> PiecewisePolynomialKernel(v=2, maha = maha), maha)
@test_broken "Nothing passes (problem with Mahalanobis distance in Distances)"
test_ADs(() -> PiecewisePolynomialKernel{v}(D))

test_params(k, (maha,))
test_params(k, ())
end