Skip to content

Commit

Permalink
Remove transform and test deprecations
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion committed Apr 11, 2021
1 parent 777f7b2 commit 7d8a968
Show file tree
Hide file tree
Showing 16 changed files with 87 additions and 49 deletions.
6 changes: 3 additions & 3 deletions docs/create_kernel_plots.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ n_grid = 101
fill(x₀, n_grid, 1)
xrange = reshape(collect(range(-3, 3; length=n_grid)), :, 1)

k = transform(SqExponentialKernel(), 1.0)
k = SqExponentialKernel() ScaleTransform(1.0)
K1 = kernelmatrix(k, xrange; obsdim=1)
p = heatmap(
K1;
Expand All @@ -35,7 +35,7 @@ p = heatmap(
)
savefig(joinpath(@__DIR__, "src", "assets", "heatmap_matern.png"))

k = transform(PolynomialKernel(; c=0.0, d=2.0), LinearTransform(randn(3, 1)))
k = PolynomialKernel(; c=0.0, d=2.0) LinearTransform(randn(3, 1))
K3 = kernelmatrix(k, xrange; obsdim=1)
p = heatmap(
K3;
Expand All @@ -47,7 +47,7 @@ p = heatmap(
savefig(joinpath(@__DIR__, "src", "assets", "heatmap_poly.png"))

k =
0.5 * SqExponentialKernel() * transform(LinearKernel(), 0.5) +
0.5 * SqExponentialKernel() * (LinearKernel() ScaleTransform(0.5)) +
0.4 * (@kernel Matern32Kernel() FunctionTransform(x -> sin.(x)))
K4 = kernelmatrix(k, xrange; obsdim=1)
p = heatmap(
Expand Down
24 changes: 14 additions & 10 deletions src/basekernels/gabor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,22 @@ end
::GaborKernel)(x, y) = κ.kernel(x, y)

function _gabor(; ell=nothing, p=nothing)
if ell === nothing
if p === nothing
return SqExponentialKernel() * CosineKernel()
else
return SqExponentialKernel() * transform(CosineKernel(), 1 ./ p)
end
elseif p === nothing
return transform(SqExponentialKernel(), 1 ./ ell) * CosineKernel()
ell_transform = if ell === nothing
IdentityTransform()
elseif ell isa Real
ScaleTransform(inv(ell))
else
return transform(SqExponentialKernel(), 1 ./ ell) *
transform(CosineKernel(), 1 ./ p)
ARDTransform(inv.(ell))
end
p_transform = if p === nothing
IdentityTransform()
elseif p isa Real
ScaleTransform(inv(p))
else
ARDTransform(inv.(p))
end

return (SqExponentialKernel() ell_transform) * (CosineKernel() p_transform)
end

function Base.getproperty(k::GaborKernel, v::Symbol)
Expand Down
6 changes: 3 additions & 3 deletions src/kernels/transformedkernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
Kernel derived from `k` for which inputs are transformed via a [`Transform`](@ref) `t`.
It is preferred to create kernels with input transformations with [`∘`](@ref), or its
alias [`compose`](@ref), instead of `TransformedKernel` directly since [`∘`](@ref)
allows optimized implementations for specific kernels and transformations.
It is preferred to create kernels with input transformations with `∘` or its alias
`compose` instead of `TransformedKernel` directly since this allows optimized
implementations for specific kernels and transformations.
# Definition
Expand Down
35 changes: 31 additions & 4 deletions src/transform/chaintransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@ julia> map(t2 ∘ t1, ColVecs(X)) == ColVecs(A * (l .* X))
true
```
"""
struct ChainTransform{V<:AbstractVector{<:Transform}} <: Transform
transforms::V
struct ChainTransform{T} <: Transform
transforms::T
end

@functor ChainTransform
Expand All @@ -33,10 +33,37 @@ function ChainTransform(v::AbstractVector{<:Type{<:Transform}}, θ::AbstractVect
return ChainTransform(v.(θ))
end

Base.:(t₁::Transform, t₂::Transform) = ChainTransform([t₂, t₁])
Base.:(t::Transform, tc::ChainTransform) = ChainTransform(vcat(tc.transforms, t))
Base.:(t₁::Transform, t₂::Transform) = ChainTransform((t₂, t₁))
function Base.:(t::Transform, tc::ChainTransform{<:AstractVector{<:Transform}})
return ChainTransform(vcat(tc.transforms, t))
end
function Base.:(tc::ChainTransform{<:AstractVector{<:Transform}}, t::Transform)
return ChainTransform(vcat(t, tc.transforms))
end
function Base.:(
tc1::ChainTransform{<:AstractVector{<:Transform}},
tc2::ChainTransform{<:AstractVector{<:Transform}},
)
return ChainTransform(vcat(tc2.transforms, tc1.transforms))
end
Base.:(tc::ChainTransform, t::Transform) = ChainTransform(vcat(t, tc.transforms))

$M.$op(k1::Kernel, k2::Kernel) = $T(k1, k2)

$M.$op(k1::$T, k2::$T) = $T(k1.kernels..., k2.kernels...)
function $M.$op(
k1::$T{<:AbstractVector{<:Kernel}}, k2::$T{<:AbstractVector{<:Kernel}}
)
return $T(vcat(k1.kernels, k2.kernels))
end

$M.$op(k::Kernel, ks::$T) = $T(k, ks.kernels...)
$M.$op(k::Kernel, ks::$T{<:AbstractVector{<:Kernel}}) = $T(vcat(k, ks.kernels))

$M.$op(ks::$T, k::Kernel) = $T(ks.kernels..., k)
$M.$op(ks::$T{<:AbstractVector{<:Kernel}}, k::Kernel) = $T(vcat(ks.kernels, k))


(t::ChainTransform)(x) = foldl((x, t) -> t(x), t.transforms; init=x)

function _map(t::ChainTransform, x::AbstractVector)
Expand Down
4 changes: 2 additions & 2 deletions test/basekernels/gabor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
k_manual = exp(-sqeuclidean(v1, v2) / (2 * k.ell^2)) * cospi(euclidean(v1, v2) / k.p)
@test k(v1, v2) k_manual atol = 1e-5

lhs_manual = transform(SqExponentialKernel(), 1 / k.ell)(v1, v2)
rhs_manual = transform(CosineKernel(), 1 / k.p)(v1, v2)
lhs_manual = (SqExponentialKernel() ScaleTransform(1 / k.ell))(v1, v2)
rhs_manual = (CosineKernel() ScaleTransform(1 / k.p))(v1, v2)
@test k(v1, v2) lhs_manual * rhs_manual atol = 1e-5

k = GaborKernel()
Expand Down
13 changes: 13 additions & 0 deletions test/deprecations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
@testset "deprecations.jl" begin
p = rand()
v = rand(3)
M = rand(3, 3)
kernel = SqExponentialKernel()

@test (@test_deprecated transform(kernel, LinearTransform(M))) ==
kernel LinearTransform(M)
@test (@test_deprecated transform(kernel ScaleTransform(p), ARDTransform(v))) ==
kernel ARDTransform(v) ScaleTransform(p)
@test (@test_deprecated transform(kernel, p)) == kernel ScaleTransform(p)
@test (@test_deprecated transform(kernel, v)) == kernel ARDTransform(v)
end
20 changes: 6 additions & 14 deletions test/kernels/transformedkernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,11 @@
k = SqExponentialKernel()
kt = TransformedKernel(k, ScaleTransform(s))
ktard = TransformedKernel(k, ARDTransform(v))
@test kt(v1, v2) == transform(k, ScaleTransform(s))(v1, v2)
@test kt(v1, v2) == transform(k, s)(v1, v2)
@test kt(v1, v2) == (k ScaleTransform(s))(v1, v2)
@test kt(v1, v2) k(s * v1, s * v2) atol = 1e-5
@test ktard(v1, v2) transform(k, ARDTransform(v))(v1, v2) atol = 1e-5
@test ktard(v1, v2) == (k ARDTransform(v))(v1, v2)
@test ktard(v1, v2) == transform(k, v)(v1, v2)
@test ktard(v1, v2) == k(v .* v1, v .* v2)
@test transform(kt, s2)(v1, v2) kt(s2 * v1, s2 * v2)
@test KernelFunctions.kernel(kt) == k
@test (kt s2)(v1, v2) kt(s2 * v1, s2 * v2)
@test repr(kt) == repr(k) * "\n\t- " * repr(ScaleTransform(s))

TestUtils.test_interface(k, Float64)
Expand Down Expand Up @@ -51,15 +46,12 @@
P = rand(3, 2)
c = Chain(Dense(3, 2))

test_params(transform(k, s), (k, [s]))
test_params(transform(k, v), (k, v))
test_params(transform(k, LinearTransform(P)), (k, P))
test_params(transform(k, LinearTransform(P) ScaleTransform(s)), (k, [s], P))
test_params(transform(k, FunctionTransform(c)), (k, c))
test_params(k ScaleTransform(s), (k, [s]))
test_params(k ARDTransform(v), (k, v))
test_params(k LinearTransform(P), (k, P))
test_params(k (LinearTransform(P) ScaleTransform(s)), (k, [s], P))
test_params(k FunctionTransform(c), (k, c))

@test (k (LinearTransform(P') ScaleTransform(s)))(v1, v2) ==
((k LinearTransform(P')) ScaleTransform(s))(v1, v2)
test_params(k LinearTransform(P), (P, k))
test_params(k LinearTransform(P) ScaleTransform(s), ([s], P, k))
test_params(k FunctionTransform(c), (c, k))
end
2 changes: 2 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ include("test_utils.jl")
include("chainrules.jl")
include("zygoterules.jl")

include("deprecations.jl")

@testset "doctests" begin
DocMeta.setdocmeta!(
KernelFunctions,
Expand Down
2 changes: 1 addition & 1 deletion test/transform/ardtransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,5 @@
@test_throws DimensionMismatch map(t, ColVecs(randn(rng, D + 1, 3)))

@test repr(t) == "ARD Transform (dims: $D)"
test_ADs(x -> transform(SEKernel(), exp.(x)), randn(rng, 3))
test_ADs(x -> SEKernel() ARDTransform(exp.(x)), randn(rng, 3))
end
2 changes: 1 addition & 1 deletion test/transform/chaintransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
# Verify printing works as expected.
@test repr(tp tf) == "Chain of 2 transforms:\n\t - $(tf) |> $(tp)"
test_ADs(
x -> transform(SEKernel(), ScaleTransform(exp(x[1])) ARDTransform(exp.(x[2:4]))),
x -> SEKernel() (ScaleTransform(exp(x[1])) ARDTransform(exp.(x[2:4]))),
randn(rng, 4),
)
end
2 changes: 1 addition & 1 deletion test/transform/functiontransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,5 @@

@test repr(FunctionTransform(sin)) == "Function Transform: $(sin)"
f(a, x) = sin.(a .* x)
test_ADs(x -> transform(SEKernel(), FunctionTransform(y -> f(x, y))), randn(rng, 3))
test_ADs(x -> SEKernel() FunctionTransform(y -> f(x, y)), randn(rng, 3))
end
2 changes: 1 addition & 1 deletion test/transform/lineartransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,5 @@
@test_throws DimensionMismatch map(t, ColVecs(randn(rng, Din + 1, Dout)))

@test repr(t) == "Linear transform (size(A) = ($Dout, $Din))"
test_ADs(x -> transform(SEKernel(), LinearTransform(x)), randn(rng, 3, 3))
test_ADs(x -> SEKernel() LinearTransform(x), randn(rng, 3, 3))
end
4 changes: 2 additions & 2 deletions test/transform/periodic_transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,10 @@
x = collect(range(0.0, 3.0 / f; length=1_000))

# Construct in the usual way.
k_eq_periodic = transform(PeriodicKernel(; r=[sqrt(0.25)]), f)
k_eq_periodic = PeriodicKernel(; r=[sqrt(0.25)]) ScaleTransform(f)

# Construct using the peridic transform.
k_eq_transform = transform(SqExponentialKernel(), PeriodicTransform(f))
k_eq_transform = SqExponentialKernel() PeriodicTransform(f)

@test kernelmatrix(k_eq_periodic, x) kernelmatrix(k_eq_transform, x)
# TODO - add interface_tests once #159 is merged.
Expand Down
2 changes: 1 addition & 1 deletion test/transform/scaletransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,5 @@
@test t.s == [s2]
@test isequal(ScaleTransform(s), ScaleTransform(s))
@test repr(t) == "Scale Transform (s = $(s2))"
test_ADs(x -> transform(SEKernel(), exp(x[1])), randn(rng, 1))
test_ADs(x -> SEKernel() ScaleTransform(exp(x[1])), randn(rng, 1))
end
10 changes: 5 additions & 5 deletions test/transform/selecttransform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
@test repr(t) == "Select Transform (dims: $(select2))"
@test repr(ts) == "Select Transform (dims: $(select_symbols2))"

test_ADs(() -> transform(SEKernel(), SelectTransform([1, 2])))
test_ADs(() -> SEKernel() SelectTransform([1, 2]))

X = randn(rng, (4, 3))
A = AxisArray(X; row=[:a, :b, :c, :d], col=[:x, :y, :z])
Expand All @@ -53,10 +53,10 @@
Z = randn(rng, (2, 3))
C = AxisArray(Z; row=[:e, :f], col=[:x, :y, :z])

tx_row = transform(SEKernel(), SelectTransform([1, 2, 4]))
ta_row = transform(SEKernel(), SelectTransform([:a, :b, :d]))
tx_col = transform(SEKernel(), SelectTransform([1, 3]))
ta_col = transform(SEKernel(), SelectTransform([:x, :z]))
tx_row = SEKernel() SelectTransform([1, 2, 4])
ta_row = SEKernel() SelectTransform([:a, :b, :d])
tx_col = SEKernel() SelectTransform([1, 3])
ta_col = SEKernel() SelectTransform([:x, :z])

@test kernelmatrix(tx_row, X; obsdim=2) kernelmatrix(ta_row, A; obsdim=2)
@test kernelmatrix(tx_col, X; obsdim=1) kernelmatrix(ta_col, A; obsdim=1)
Expand Down
2 changes: 1 addition & 1 deletion test/transform/transform.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,5 +8,5 @@
@test IdentityTransform()(x) == x
@test map(IdentityTransform(), x) == x
end
test_ADs(() -> transform(SEKernel(), IdentityTransform()))
test_ADs(() -> SEKernel() IdentityTransform())
end

0 comments on commit 7d8a968

Please sign in to comment.