From 7d8a9682075e92b01156bb427fe9c3d94e299981 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Mon, 12 Apr 2021 00:06:25 +0200 Subject: [PATCH] Remove `transform` and test deprecations --- docs/create_kernel_plots.jl | 6 ++--- src/basekernels/gabor.jl | 24 +++++++++++-------- src/kernels/transformedkernel.jl | 6 ++--- src/transform/chaintransform.jl | 35 ++++++++++++++++++++++++---- test/basekernels/gabor.jl | 4 ++-- test/deprecations.jl | 13 +++++++++++ test/kernels/transformedkernel.jl | 20 +++++----------- test/runtests.jl | 2 ++ test/transform/ardtransform.jl | 2 +- test/transform/chaintransform.jl | 2 +- test/transform/functiontransform.jl | 2 +- test/transform/lineartransform.jl | 2 +- test/transform/periodic_transform.jl | 4 ++-- test/transform/scaletransform.jl | 2 +- test/transform/selecttransform.jl | 10 ++++---- test/transform/transform.jl | 2 +- 16 files changed, 87 insertions(+), 49 deletions(-) create mode 100644 test/deprecations.jl diff --git a/docs/create_kernel_plots.jl b/docs/create_kernel_plots.jl index 52b1532f5..08db1bd1e 100644 --- a/docs/create_kernel_plots.jl +++ b/docs/create_kernel_plots.jl @@ -13,7 +13,7 @@ n_grid = 101 fill(x₀, n_grid, 1) xrange = reshape(collect(range(-3, 3; length=n_grid)), :, 1) -k = transform(SqExponentialKernel(), 1.0) +k = SqExponentialKernel() ∘ ScaleTransform(1.0) K1 = kernelmatrix(k, xrange; obsdim=1) p = heatmap( K1; @@ -35,7 +35,7 @@ p = heatmap( ) savefig(joinpath(@__DIR__, "src", "assets", "heatmap_matern.png")) -k = transform(PolynomialKernel(; c=0.0, d=2.0), LinearTransform(randn(3, 1))) +k = PolynomialKernel(; c=0.0, d=2.0) ∘ LinearTransform(randn(3, 1)) K3 = kernelmatrix(k, xrange; obsdim=1) p = heatmap( K3; @@ -47,7 +47,7 @@ p = heatmap( savefig(joinpath(@__DIR__, "src", "assets", "heatmap_poly.png")) k = - 0.5 * SqExponentialKernel() * transform(LinearKernel(), 0.5) + + 0.5 * SqExponentialKernel() * (LinearKernel() ∘ ScaleTransform(0.5)) + 0.4 * (@kernel Matern32Kernel() FunctionTransform(x -> sin.(x))) K4 = kernelmatrix(k, xrange; obsdim=1) p = heatmap( diff --git a/src/basekernels/gabor.jl b/src/basekernels/gabor.jl index 311afcf1b..44c7ea38a 100644 --- a/src/basekernels/gabor.jl +++ b/src/basekernels/gabor.jl @@ -25,18 +25,22 @@ end (κ::GaborKernel)(x, y) = κ.kernel(x, y) function _gabor(; ell=nothing, p=nothing) - if ell === nothing - if p === nothing - return SqExponentialKernel() * CosineKernel() - else - return SqExponentialKernel() * transform(CosineKernel(), 1 ./ p) - end - elseif p === nothing - return transform(SqExponentialKernel(), 1 ./ ell) * CosineKernel() + ell_transform = if ell === nothing + IdentityTransform() + elseif ell isa Real + ScaleTransform(inv(ell)) else - return transform(SqExponentialKernel(), 1 ./ ell) * - transform(CosineKernel(), 1 ./ p) + ARDTransform(inv.(ell)) end + p_transform = if p === nothing + IdentityTransform() + elseif p isa Real + ScaleTransform(inv(p)) + else + ARDTransform(inv.(p)) + end + + return (SqExponentialKernel() ∘ ell_transform) * (CosineKernel() ∘ p_transform) end function Base.getproperty(k::GaborKernel, v::Symbol) diff --git a/src/kernels/transformedkernel.jl b/src/kernels/transformedkernel.jl index f1afccb55..374a17b23 100644 --- a/src/kernels/transformedkernel.jl +++ b/src/kernels/transformedkernel.jl @@ -3,9 +3,9 @@ Kernel derived from `k` for which inputs are transformed via a [`Transform`](@ref) `t`. -It is preferred to create kernels with input transformations with [`∘`](@ref), or its -alias [`compose`](@ref), instead of `TransformedKernel` directly since [`∘`](@ref) -allows optimized implementations for specific kernels and transformations. +It is preferred to create kernels with input transformations with `∘` or its alias +`compose` instead of `TransformedKernel` directly since this allows optimized +implementations for specific kernels and transformations. # Definition diff --git a/src/transform/chaintransform.jl b/src/transform/chaintransform.jl index bd4627b19..4d1f643ae 100644 --- a/src/transform/chaintransform.jl +++ b/src/transform/chaintransform.jl @@ -19,8 +19,8 @@ julia> map(t2 ∘ t1, ColVecs(X)) == ColVecs(A * (l .* X)) true ``` """ -struct ChainTransform{V<:AbstractVector{<:Transform}} <: Transform - transforms::V +struct ChainTransform{T} <: Transform + transforms::T end @functor ChainTransform @@ -33,10 +33,37 @@ function ChainTransform(v::AbstractVector{<:Type{<:Transform}}, θ::AbstractVect return ChainTransform(v.(θ)) end -Base.:∘(t₁::Transform, t₂::Transform) = ChainTransform([t₂, t₁]) -Base.:∘(t::Transform, tc::ChainTransform) = ChainTransform(vcat(tc.transforms, t)) +Base.:∘(t₁::Transform, t₂::Transform) = ChainTransform((t₂, t₁)) +function Base.:∘(t::Transform, tc::ChainTransform{<:AstractVector{<:Transform}}) + return ChainTransform(vcat(tc.transforms, t)) +end +function Base.:∘(tc::ChainTransform{<:AstractVector{<:Transform}}, t::Transform) + return ChainTransform(vcat(t, tc.transforms)) +end +function Base.:∘( + tc1::ChainTransform{<:AstractVector{<:Transform}}, + tc2::ChainTransform{<:AstractVector{<:Transform}}, +) + return ChainTransform(vcat(tc2.transforms, tc1.transforms)) +end Base.:∘(tc::ChainTransform, t::Transform) = ChainTransform(vcat(t, tc.transforms)) + $M.$op(k1::Kernel, k2::Kernel) = $T(k1, k2) + + $M.$op(k1::$T, k2::$T) = $T(k1.kernels..., k2.kernels...) + function $M.$op( + k1::$T{<:AbstractVector{<:Kernel}}, k2::$T{<:AbstractVector{<:Kernel}} + ) + return $T(vcat(k1.kernels, k2.kernels)) + end + + $M.$op(k::Kernel, ks::$T) = $T(k, ks.kernels...) + $M.$op(k::Kernel, ks::$T{<:AbstractVector{<:Kernel}}) = $T(vcat(k, ks.kernels)) + + $M.$op(ks::$T, k::Kernel) = $T(ks.kernels..., k) + $M.$op(ks::$T{<:AbstractVector{<:Kernel}}, k::Kernel) = $T(vcat(ks.kernels, k)) + + (t::ChainTransform)(x) = foldl((x, t) -> t(x), t.transforms; init=x) function _map(t::ChainTransform, x::AbstractVector) diff --git a/test/basekernels/gabor.jl b/test/basekernels/gabor.jl index 472327d87..548f2f584 100644 --- a/test/basekernels/gabor.jl +++ b/test/basekernels/gabor.jl @@ -10,8 +10,8 @@ k_manual = exp(-sqeuclidean(v1, v2) / (2 * k.ell^2)) * cospi(euclidean(v1, v2) / k.p) @test k(v1, v2) ≈ k_manual atol = 1e-5 - lhs_manual = transform(SqExponentialKernel(), 1 / k.ell)(v1, v2) - rhs_manual = transform(CosineKernel(), 1 / k.p)(v1, v2) + lhs_manual = (SqExponentialKernel() ∘ ScaleTransform(1 / k.ell))(v1, v2) + rhs_manual = (CosineKernel() ∘ ScaleTransform(1 / k.p))(v1, v2) @test k(v1, v2) ≈ lhs_manual * rhs_manual atol = 1e-5 k = GaborKernel() diff --git a/test/deprecations.jl b/test/deprecations.jl new file mode 100644 index 000000000..ff1933168 --- /dev/null +++ b/test/deprecations.jl @@ -0,0 +1,13 @@ +@testset "deprecations.jl" begin + p = rand() + v = rand(3) + M = rand(3, 3) + kernel = SqExponentialKernel() + + @test (@test_deprecated transform(kernel, LinearTransform(M))) == + kernel ∘ LinearTransform(M) + @test (@test_deprecated transform(kernel ∘ ScaleTransform(p), ARDTransform(v))) == + kernel ∘ ARDTransform(v) ∘ ScaleTransform(p) + @test (@test_deprecated transform(kernel, p)) == kernel ∘ ScaleTransform(p) + @test (@test_deprecated transform(kernel, v)) == kernel ∘ ARDTransform(v) +end diff --git a/test/kernels/transformedkernel.jl b/test/kernels/transformedkernel.jl index ef2328488..6ac8c7407 100644 --- a/test/kernels/transformedkernel.jl +++ b/test/kernels/transformedkernel.jl @@ -10,16 +10,11 @@ k = SqExponentialKernel() kt = TransformedKernel(k, ScaleTransform(s)) ktard = TransformedKernel(k, ARDTransform(v)) - @test kt(v1, v2) == transform(k, ScaleTransform(s))(v1, v2) - @test kt(v1, v2) == transform(k, s)(v1, v2) @test kt(v1, v2) == (k ∘ ScaleTransform(s))(v1, v2) @test kt(v1, v2) ≈ k(s * v1, s * v2) atol = 1e-5 - @test ktard(v1, v2) ≈ transform(k, ARDTransform(v))(v1, v2) atol = 1e-5 @test ktard(v1, v2) == (k ∘ ARDTransform(v))(v1, v2) - @test ktard(v1, v2) == transform(k, v)(v1, v2) @test ktard(v1, v2) == k(v .* v1, v .* v2) - @test transform(kt, s2)(v1, v2) ≈ kt(s2 * v1, s2 * v2) - @test KernelFunctions.kernel(kt) == k + @test (kt ∘ s2)(v1, v2) ≈ kt(s2 * v1, s2 * v2) @test repr(kt) == repr(k) * "\n\t- " * repr(ScaleTransform(s)) TestUtils.test_interface(k, Float64) @@ -51,15 +46,12 @@ P = rand(3, 2) c = Chain(Dense(3, 2)) - test_params(transform(k, s), (k, [s])) - test_params(transform(k, v), (k, v)) - test_params(transform(k, LinearTransform(P)), (k, P)) - test_params(transform(k, LinearTransform(P) ∘ ScaleTransform(s)), (k, [s], P)) - test_params(transform(k, FunctionTransform(c)), (k, c)) + test_params(k ∘ ScaleTransform(s), (k, [s])) + test_params(k ∘ ARDTransform(v), (k, v)) + test_params(k ∘ LinearTransform(P), (k, P)) + test_params(k ∘ (LinearTransform(P) ∘ ScaleTransform(s)), (k, [s], P)) + test_params(k ∘ FunctionTransform(c), (k, c)) @test (k ∘ (LinearTransform(P') ∘ ScaleTransform(s)))(v1, v2) == ((k ∘ LinearTransform(P')) ∘ ScaleTransform(s))(v1, v2) - test_params(k ∘ LinearTransform(P), (P, k)) - test_params(k ∘ LinearTransform(P) ∘ ScaleTransform(s), ([s], P, k)) - test_params(k ∘ FunctionTransform(c), (c, k)) end diff --git a/test/runtests.jl b/test/runtests.jl index d092b3901..9bd1cbfe9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -148,6 +148,8 @@ include("test_utils.jl") include("chainrules.jl") include("zygoterules.jl") + include("deprecations.jl") + @testset "doctests" begin DocMeta.setdocmeta!( KernelFunctions, diff --git a/test/transform/ardtransform.jl b/test/transform/ardtransform.jl index 4c0df4deb..26c9969c4 100644 --- a/test/transform/ardtransform.jl +++ b/test/transform/ardtransform.jl @@ -42,5 +42,5 @@ @test_throws DimensionMismatch map(t, ColVecs(randn(rng, D + 1, 3))) @test repr(t) == "ARD Transform (dims: $D)" - test_ADs(x -> transform(SEKernel(), exp.(x)), randn(rng, 3)) + test_ADs(x -> SEKernel() ∘ ARDTransform(exp.(x)), randn(rng, 3)) end diff --git a/test/transform/chaintransform.jl b/test/transform/chaintransform.jl index dc41cc7bb..6e3d8a44f 100644 --- a/test/transform/chaintransform.jl +++ b/test/transform/chaintransform.jl @@ -23,7 +23,7 @@ # Verify printing works as expected. @test repr(tp ∘ tf) == "Chain of 2 transforms:\n\t - $(tf) |> $(tp)" test_ADs( - x -> transform(SEKernel(), ScaleTransform(exp(x[1])) ∘ ARDTransform(exp.(x[2:4]))), + x -> SEKernel() ∘ (ScaleTransform(exp(x[1])) ∘ ARDTransform(exp.(x[2:4]))), randn(rng, 4), ) end diff --git a/test/transform/functiontransform.jl b/test/transform/functiontransform.jl index 1dd36b5b5..7d9581ecc 100644 --- a/test/transform/functiontransform.jl +++ b/test/transform/functiontransform.jl @@ -28,5 +28,5 @@ @test repr(FunctionTransform(sin)) == "Function Transform: $(sin)" f(a, x) = sin.(a .* x) - test_ADs(x -> transform(SEKernel(), FunctionTransform(y -> f(x, y))), randn(rng, 3)) + test_ADs(x -> SEKernel() ∘ FunctionTransform(y -> f(x, y)), randn(rng, 3)) end diff --git a/test/transform/lineartransform.jl b/test/transform/lineartransform.jl index ea280d85d..5b22a4d2d 100644 --- a/test/transform/lineartransform.jl +++ b/test/transform/lineartransform.jl @@ -42,5 +42,5 @@ @test_throws DimensionMismatch map(t, ColVecs(randn(rng, Din + 1, Dout))) @test repr(t) == "Linear transform (size(A) = ($Dout, $Din))" - test_ADs(x -> transform(SEKernel(), LinearTransform(x)), randn(rng, 3, 3)) + test_ADs(x -> SEKernel() ∘ LinearTransform(x), randn(rng, 3, 3)) end diff --git a/test/transform/periodic_transform.jl b/test/transform/periodic_transform.jl index 1f7119ac3..324cba907 100644 --- a/test/transform/periodic_transform.jl +++ b/test/transform/periodic_transform.jl @@ -5,10 +5,10 @@ x = collect(range(0.0, 3.0 / f; length=1_000)) # Construct in the usual way. - k_eq_periodic = transform(PeriodicKernel(; r=[sqrt(0.25)]), f) + k_eq_periodic = PeriodicKernel(; r=[sqrt(0.25)]) ∘ ScaleTransform(f) # Construct using the peridic transform. - k_eq_transform = transform(SqExponentialKernel(), PeriodicTransform(f)) + k_eq_transform = SqExponentialKernel() ∘ PeriodicTransform(f) @test kernelmatrix(k_eq_periodic, x) ≈ kernelmatrix(k_eq_transform, x) # TODO - add interface_tests once #159 is merged. diff --git a/test/transform/scaletransform.jl b/test/transform/scaletransform.jl index d416e3eb6..c0445b8a4 100644 --- a/test/transform/scaletransform.jl +++ b/test/transform/scaletransform.jl @@ -19,5 +19,5 @@ @test t.s == [s2] @test isequal(ScaleTransform(s), ScaleTransform(s)) @test repr(t) == "Scale Transform (s = $(s2))" - test_ADs(x -> transform(SEKernel(), exp(x[1])), randn(rng, 1)) + test_ADs(x -> SEKernel() ∘ ScaleTransform(exp(x[1])), randn(rng, 1)) end diff --git a/test/transform/selecttransform.jl b/test/transform/selecttransform.jl index 966fecf7c..c9888443c 100644 --- a/test/transform/selecttransform.jl +++ b/test/transform/selecttransform.jl @@ -44,7 +44,7 @@ @test repr(t) == "Select Transform (dims: $(select2))" @test repr(ts) == "Select Transform (dims: $(select_symbols2))" - test_ADs(() -> transform(SEKernel(), SelectTransform([1, 2]))) + test_ADs(() -> SEKernel() ∘ SelectTransform([1, 2])) X = randn(rng, (4, 3)) A = AxisArray(X; row=[:a, :b, :c, :d], col=[:x, :y, :z]) @@ -53,10 +53,10 @@ Z = randn(rng, (2, 3)) C = AxisArray(Z; row=[:e, :f], col=[:x, :y, :z]) - tx_row = transform(SEKernel(), SelectTransform([1, 2, 4])) - ta_row = transform(SEKernel(), SelectTransform([:a, :b, :d])) - tx_col = transform(SEKernel(), SelectTransform([1, 3])) - ta_col = transform(SEKernel(), SelectTransform([:x, :z])) + tx_row = SEKernel() ∘ SelectTransform([1, 2, 4]) + ta_row = SEKernel() ∘ SelectTransform([:a, :b, :d]) + tx_col = SEKernel() ∘ SelectTransform([1, 3]) + ta_col = SEKernel() ∘ SelectTransform([:x, :z]) @test kernelmatrix(tx_row, X; obsdim=2) ≈ kernelmatrix(ta_row, A; obsdim=2) @test kernelmatrix(tx_col, X; obsdim=1) ≈ kernelmatrix(ta_col, A; obsdim=1) diff --git a/test/transform/transform.jl b/test/transform/transform.jl index 83d63ffe7..bd3ff567d 100644 --- a/test/transform/transform.jl +++ b/test/transform/transform.jl @@ -8,5 +8,5 @@ @test IdentityTransform()(x) == x @test map(IdentityTransform(), x) == x end - test_ADs(() -> transform(SEKernel(), IdentityTransform())) + test_ADs(() -> SEKernel() ∘ IdentityTransform()) end