From 589daffaf2f76b519847ca36d344f9e029c2ab93 Mon Sep 17 00:00:00 2001 From: David Widmann Date: Tue, 13 Apr 2021 08:36:13 +0200 Subject: [PATCH] Extension of #269: Use `\circ` and `compose` and deprecate `transform` (#276) Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: kaandocal <> --- Project.toml | 4 +- docs/create_kernel_plots.jl | 6 +-- docs/src/kernels.md | 4 +- docs/src/transform.md | 4 +- src/KernelFunctions.jl | 5 ++- src/basekernels/gabor.jl | 42 +++++++----------- src/deprecations.jl | 4 ++ src/kernels/transformedkernel.jl | 65 ++++++++++++++-------------- test/basekernels/gabor.jl | 6 +-- test/deprecations.jl | 20 +++++++++ test/kernels/transformedkernel.jl | 35 ++++++++------- test/runtests.jl | 2 + test/transform/ardtransform.jl | 2 +- test/transform/chaintransform.jl | 2 +- test/transform/functiontransform.jl | 2 +- test/transform/lineartransform.jl | 2 +- test/transform/periodic_transform.jl | 4 +- test/transform/scaletransform.jl | 2 +- test/transform/selecttransform.jl | 10 ++--- test/transform/transform.jl | 2 +- 20 files changed, 122 insertions(+), 101 deletions(-) create mode 100644 src/deprecations.jl create mode 100644 test/deprecations.jl diff --git a/Project.toml b/Project.toml index 215863d23..f238cb318 100644 --- a/Project.toml +++ b/Project.toml @@ -1,10 +1,11 @@ name = "KernelFunctions" uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392" -version = "0.9.1" +version = "0.9.2" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" +CompositionsBase = "a33af91c-f02d-484b-be07-31d278c5ca2b" Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7" Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" @@ -20,6 +21,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] ChainRulesCore = "0.9" Compat = "3.7" +CompositionsBase = "0.1" Distances = "0.10" Functors = "0.1" Requires = "1.0.1" diff --git a/docs/create_kernel_plots.jl b/docs/create_kernel_plots.jl index 52b1532f5..08db1bd1e 100644 --- a/docs/create_kernel_plots.jl +++ b/docs/create_kernel_plots.jl @@ -13,7 +13,7 @@ n_grid = 101 fill(x₀, n_grid, 1) xrange = reshape(collect(range(-3, 3; length=n_grid)), :, 1) -k = transform(SqExponentialKernel(), 1.0) +k = SqExponentialKernel() ∘ ScaleTransform(1.0) K1 = kernelmatrix(k, xrange; obsdim=1) p = heatmap( K1; @@ -35,7 +35,7 @@ p = heatmap( ) savefig(joinpath(@__DIR__, "src", "assets", "heatmap_matern.png")) -k = transform(PolynomialKernel(; c=0.0, d=2.0), LinearTransform(randn(3, 1))) +k = PolynomialKernel(; c=0.0, d=2.0) ∘ LinearTransform(randn(3, 1)) K3 = kernelmatrix(k, xrange; obsdim=1) p = heatmap( K3; @@ -47,7 +47,7 @@ p = heatmap( savefig(joinpath(@__DIR__, "src", "assets", "heatmap_poly.png")) k = - 0.5 * SqExponentialKernel() * transform(LinearKernel(), 0.5) + + 0.5 * SqExponentialKernel() * (LinearKernel() ∘ ScaleTransform(0.5)) + 0.4 * (@kernel Matern32Kernel() FunctionTransform(x -> sin.(x))) K4 = kernelmatrix(k, xrange; obsdim=1) p = heatmap( diff --git a/docs/src/kernels.md b/docs/src/kernels.md index ee769eb3c..229ac404d 100644 --- a/docs/src/kernels.md +++ b/docs/src/kernels.md @@ -118,9 +118,7 @@ of kernels together. ```@docs TransformedKernel -transform(::Kernel, ::Transform) -transform(::Kernel, ::Real) -transform(::Kernel, ::AbstractVector) +∘(::Kernel, ::Transform) ScaledKernel KernelSum KernelProduct diff --git a/docs/src/transform.md b/docs/src/transform.md index 5bd858dff..9ed878a84 100644 --- a/docs/src/transform.md +++ b/docs/src/transform.md @@ -18,8 +18,8 @@ LowRankTransform(rand(10, 5)) ∘ ScaleTransform(2.0) A transformation `t` can be applied to a single input `x` with `t(x)` and to multiple inputs `xs` with `map(t, xs)`. -Kernels can be coupled with input transformations with -[`transform`](@ref). It falls back to creating a [`TransformedKernel`](@ref) but allows more +Kernels can be coupled with input transformations with [`∘`](@ref) or its alias `compose`. It falls +back to creating a [`TransformedKernel`](@ref) but allows more optimized implementations for specific kernels and transformations. ## List of Input Transforms diff --git a/src/KernelFunctions.jl b/src/KernelFunctions.jl index 2cdfd71c5..9b79eaf83 100644 --- a/src/KernelFunctions.jl +++ b/src/KernelFunctions.jl @@ -39,11 +39,12 @@ export MOInput export IndependentMOKernel, LatentFactorMOKernel # Reexports -export tensor, ⊗ +export tensor, ⊗, compose using Compat using ChainRulesCore: ChainRulesCore, Composite, Zero, One, DoesNotExist, NO_FIELDS using ChainRulesCore: @thunk, InplaceableThunk +using CompositionsBase using Requires using Distances, LinearAlgebra using Functors @@ -106,6 +107,8 @@ include("zygoterules.jl") include("test_utils.jl") +include("deprecations.jl") + function __init__() @require Kronecker = "2c470bb0-bcc8-11e8-3dad-c9649493f05e" begin include(joinpath("matrix", "kernelkroneckermat.jl")) diff --git a/src/basekernels/gabor.jl b/src/basekernels/gabor.jl index 311afcf1b..2d048c779 100644 --- a/src/basekernels/gabor.jl +++ b/src/basekernels/gabor.jl @@ -14,8 +14,11 @@ k(x, x'; l, p) = \\exp\\bigg(- \\cos\\bigg(\\pi\\sum_{i=1}^d \\frac{x_i - x'_i}{ """ struct GaborKernel{K<:Kernel} <: Kernel kernel::K + function GaborKernel(; ell=nothing, p=nothing) - k = _gabor(; ell=ell, p=p) + ell_transform = _lengthscale_transform(ell) + p_transform = _lengthscale_transform(p) + k = (SqExponentialKernel() ∘ ell_transform) * (CosineKernel() ∘ p_transform) return new{typeof(k)}(k) end end @@ -24,38 +27,23 @@ end (κ::GaborKernel)(x, y) = κ.kernel(x, y) -function _gabor(; ell=nothing, p=nothing) - if ell === nothing - if p === nothing - return SqExponentialKernel() * CosineKernel() - else - return SqExponentialKernel() * transform(CosineKernel(), 1 ./ p) - end - elseif p === nothing - return transform(SqExponentialKernel(), 1 ./ ell) * CosineKernel() - else - return transform(SqExponentialKernel(), 1 ./ ell) * - transform(CosineKernel(), 1 ./ p) - end -end +_lengthscale_transform(::Nothing) = IdentityTransform() +_lengthscale_transform(x::Real) = ScaleTransform(inv(x)) +_lengthscale_transform(x::AbstractVector) = ARDTransform(map(inv, x)) + +_lengthscale(::IdentityTransform) = 1 +_lengthscale(t::ScaleTransform) = inv(first(t.s)) +_lengthscale(t::ARDTransform) = map(inv, t.v) function Base.getproperty(k::GaborKernel, v::Symbol) if v == :kernel return getfield(k, v) elseif v == :ell - kernel1 = k.kernel.kernels[1] - if kernel1 isa TransformedKernel - return 1 ./ kernel1.transform.s[1] - else - return 1.0 - end + ell_transform = k.kernel.kernels[1].transform + return _lengthscale(ell_transform) elseif v == :p - kernel2 = k.kernel.kernels[2] - if kernel2 isa TransformedKernel - return 1 ./ kernel2.transform.s[1] - else - return 1.0 - end + p_transform = k.kernel.kernels[2].transform + return _lengthscale(p_transform) else error("Invalid Property") end diff --git a/src/deprecations.jl b/src/deprecations.jl new file mode 100644 index 000000000..c3d7174b9 --- /dev/null +++ b/src/deprecations.jl @@ -0,0 +1,4 @@ +@deprecate transform(k::Kernel, t::Transform) k ∘ t +@deprecate transform(k::TransformedKernel, t::Transform) k.kernel ∘ t ∘ k.transform +@deprecate transform(k::Kernel, ρ::Real) k ∘ ScaleTransform(ρ) +@deprecate transform(k::Kernel, ρ::AbstractVector) k ∘ ARDTransform(ρ) diff --git a/src/kernels/transformedkernel.jl b/src/kernels/transformedkernel.jl index 6cf693dca..0afd157fb 100644 --- a/src/kernels/transformedkernel.jl +++ b/src/kernels/transformedkernel.jl @@ -3,17 +3,11 @@ Kernel derived from `k` for which inputs are transformed via a [`Transform`](@ref) `t`. -It is preferred to create kernels with input transformations with [`transform`](@ref) -instead of `TransformedKernel` directly since [`transform`](@ref) allows optimized -implementations for specific kernels and transformations. +The preferred way to create kernels with input transformations is to use the composition +operator [`∘`](@ref) or its alias `compose` instead of `TransformedKernel` directly since +this allows optimized implementations for specific kernels and transformations. -# Definition - -For inputs ``x, x'``, the transformed kernel ``\\widetilde{k}`` derived from kernel ``k`` by -input transformation ``t`` is defined as -```math -\\widetilde{k}(x, x'; k, t) = k\\big(t(x), t(x')\\big). -``` +See also: [`∘`](@ref) """ struct TransformedKernel{Tk<:Kernel,Tr<:Transform} <: Kernel kernel::Tk @@ -42,30 +36,37 @@ end _scale(t::ScaleTransform, metric, x, y) = evaluate(metric, t(x), t(y)) """ - transform(k::Kernel, t::Transform) + kernel ∘ transform + ∘(kernel, transform) + compose(kernel, transform) -Create a [`TransformedKernel`](@ref) for kernel `k` and transform `t`. -""" -transform(k::Kernel, t::Transform) = TransformedKernel(k, t) -function transform(k::TransformedKernel, t::Transform) - return TransformedKernel(k.kernel, t ∘ k.transform) -end +Compose a `kernel` with a transformation `transform` of its inputs. -""" - transform(k::Kernel, ρ::Real) +The prefix forms support chains of multiple transformations: +`∘(kernel, transform1, transform2) = kernel ∘ transform1 ∘ transform2`. -Create a [`TransformedKernel`](@ref) for kernel `k` and inverse lengthscale `ρ`. -""" -transform(k::Kernel, ρ::Real) = transform(k, ScaleTransform(ρ)) +# Definition -""" - transform(k::Kernel, ρ::AbstractVector) +For inputs ``x, x'``, the transformed kernel ``\\widetilde{k}`` derived from kernel ``k`` by +input transformation ``t`` is defined as +```math +\\widetilde{k}(x, x'; k, t) = k\\big(t(x), t(x')\\big). +``` -Create a [`TransformedKernel`](@ref) for kernel `k` and inverse lengthscales `ρ`. -""" -transform(k::Kernel, ρ::AbstractVector) = transform(k, ARDTransform(ρ)) +# Examples + +```jldoctest +julia> (SqExponentialKernel() ∘ ScaleTransform(0.5))(0, 2) == exp(-0.5) +true -kernel(κ) = κ.kernel +julia> ∘(ExponentialKernel(), ScaleTransform(2), ScaleTransform(0.5))(1, 2) == exp(-1) +true +``` + +See also: [`TransformedKernel`](@ref) +""" +Base.:∘(k::Kernel, t::Transform) = TransformedKernel(k, t) +Base.:∘(k::TransformedKernel, t::Transform) = TransformedKernel(k.kernel, k.transform ∘ t) Base.show(io::IO, κ::TransformedKernel) = printshifted(io, κ, 0) @@ -87,13 +88,13 @@ function kernelmatrix_diag!( end function kernelmatrix!(K::AbstractMatrix, κ::TransformedKernel, x::AbstractVector) - return kernelmatrix!(K, kernel(κ), _map(κ.transform, x)) + return kernelmatrix!(K, κ.kernel, _map(κ.transform, x)) end function kernelmatrix!( K::AbstractMatrix, κ::TransformedKernel, x::AbstractVector, y::AbstractVector ) - return kernelmatrix!(K, kernel(κ), _map(κ.transform, x), _map(κ.transform, y)) + return kernelmatrix!(K, κ.kernel, _map(κ.transform, x), _map(κ.transform, y)) end function kernelmatrix_diag(κ::TransformedKernel, x::AbstractVector) @@ -105,9 +106,9 @@ function kernelmatrix_diag(κ::TransformedKernel, x::AbstractVector, y::Abstract end function kernelmatrix(κ::TransformedKernel, x::AbstractVector) - return kernelmatrix(kernel(κ), _map(κ.transform, x)) + return kernelmatrix(κ.kernel, _map(κ.transform, x)) end function kernelmatrix(κ::TransformedKernel, x::AbstractVector, y::AbstractVector) - return kernelmatrix(kernel(κ), _map(κ.transform, x), _map(κ.transform, y)) + return kernelmatrix(κ.kernel, _map(κ.transform, x), _map(κ.transform, y)) end diff --git a/test/basekernels/gabor.jl b/test/basekernels/gabor.jl index 472327d87..51a315ab5 100644 --- a/test/basekernels/gabor.jl +++ b/test/basekernels/gabor.jl @@ -10,14 +10,14 @@ k_manual = exp(-sqeuclidean(v1, v2) / (2 * k.ell^2)) * cospi(euclidean(v1, v2) / k.p) @test k(v1, v2) ≈ k_manual atol = 1e-5 - lhs_manual = transform(SqExponentialKernel(), 1 / k.ell)(v1, v2) - rhs_manual = transform(CosineKernel(), 1 / k.p)(v1, v2) + lhs_manual = (SqExponentialKernel() ∘ ScaleTransform(1 / k.ell))(v1, v2) + rhs_manual = (CosineKernel() ∘ ScaleTransform(1 / k.p))(v1, v2) @test k(v1, v2) ≈ lhs_manual * rhs_manual atol = 1e-5 k = GaborKernel() @test k.ell ≈ 1.0 atol = 1e-5 @test k.p ≈ 1.0 atol = 1e-5 - @test repr(k) == "Gabor Kernel (ell = 1.0, p = 1.0)" + @test repr(k) == "Gabor Kernel (ell = 1, p = 1)" test_interface(k, Vector{Float64}) diff --git a/test/deprecations.jl b/test/deprecations.jl new file mode 100644 index 000000000..b8cc0bced --- /dev/null +++ b/test/deprecations.jl @@ -0,0 +1,20 @@ +@testset "deprecations.jl" begin + p = rand() + v = rand(3) + M = rand(3, 3) + v1 = rand(3) + v2 = rand(3) + kernel = SqExponentialKernel() + + k1 = @test_deprecated transform(kernel, LinearTransform(M)) + @test k1(v1, v2) == (kernel ∘ LinearTransform(M))(v1, v2) + + k2 = @test_deprecated transform(kernel ∘ ScaleTransform(p), ARDTransform(v)) + @test k2(v1, v2) == (kernel ∘ ARDTransform(v) ∘ ScaleTransform(p))(v1, v2) + + k3 = @test_deprecated transform(kernel, p) + @test k3(v1, v2) == (kernel ∘ ScaleTransform(p))(v1, v2) + + k4 = @test_deprecated transform(kernel, v) + @test k4(v1, v2) == (kernel ∘ ARDTransform(v))(v1, v2) +end diff --git a/test/kernels/transformedkernel.jl b/test/kernels/transformedkernel.jl index f499e561d..3bbb06a62 100644 --- a/test/kernels/transformedkernel.jl +++ b/test/kernels/transformedkernel.jl @@ -5,26 +5,27 @@ v2 = rand(rng, 3) s = rand(rng) - s2 = rand(rng) v = rand(rng, 3) + P = rand(rng, 3, 2) k = SqExponentialKernel() kt = TransformedKernel(k, ScaleTransform(s)) ktard = TransformedKernel(k, ARDTransform(v)) - @test kt(v1, v2) == transform(k, ScaleTransform(s))(v1, v2) - @test kt(v1, v2) == transform(k, s)(v1, v2) + @test kt(v1, v2) == (k ∘ ScaleTransform(s))(v1, v2) @test kt(v1, v2) ≈ k(s * v1, s * v2) atol = 1e-5 - @test ktard(v1, v2) ≈ transform(k, ARDTransform(v))(v1, v2) atol = 1e-5 - @test ktard(v1, v2) == transform(k, v)(v1, v2) + @test ktard(v1, v2) == (k ∘ ARDTransform(v))(v1, v2) @test ktard(v1, v2) == k(v .* v1, v .* v2) - @test transform(kt, s2)(v1, v2) ≈ kt(s2 * v1, s2 * v2) - @test KernelFunctions.kernel(kt) == k + @test (k ∘ LinearTransform(P') ∘ ScaleTransform(s))(v1, v2) == + ((k ∘ LinearTransform(P')) ∘ ScaleTransform(s))(v1, v2) == + (k ∘ (LinearTransform(P') ∘ ScaleTransform(s)))(v1, v2) + @test repr(kt) == repr(k) * "\n\t- " * repr(ScaleTransform(s)) TestUtils.test_interface(k, Float64) - test_ADs(x -> transform(SqExponentialKernel(), x[1]), rand(1))# ADs = [:ForwardDiff, :ReverseDiff]) + test_ADs(x -> SqExponentialKernel() ∘ ScaleTransform(x[1]), rand(1)) + # Test implicit gradients @testset "Implicit gradients" begin - k = transform(SqExponentialKernel(), 2.0) + k = SqExponentialKernel() ∘ ScaleTransform(2.0) ps = Flux.params(k) X = rand(10, 1) x = vec(X) @@ -46,12 +47,14 @@ @test g1[first(ps)] ≈ g3[first(ps)] end - P = rand(3, 2) - c = Chain(Dense(3, 2)) + @testset "Parameters" begin + k = ConstantKernel(; c=rand(rng)) + c = Chain(Dense(3, 2)) - test_params(transform(k, s), (k, [s])) - test_params(transform(k, v), (k, v)) - test_params(transform(k, LinearTransform(P)), (k, P)) - test_params(transform(k, LinearTransform(P) ∘ ScaleTransform(s)), (k, [s], P)) - test_params(transform(k, FunctionTransform(c)), (k, c)) + test_params(k ∘ ScaleTransform(s), (k, [s])) + test_params(k ∘ ARDTransform(v), (k, v)) + test_params(k ∘ LinearTransform(P), (k, P)) + test_params(k ∘ LinearTransform(P) ∘ ScaleTransform(s), (k, [s], P)) + test_params(k ∘ FunctionTransform(c), (k, c)) + end end diff --git a/test/runtests.jl b/test/runtests.jl index d092b3901..9bd1cbfe9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -148,6 +148,8 @@ include("test_utils.jl") include("chainrules.jl") include("zygoterules.jl") + include("deprecations.jl") + @testset "doctests" begin DocMeta.setdocmeta!( KernelFunctions, diff --git a/test/transform/ardtransform.jl b/test/transform/ardtransform.jl index 4c0df4deb..26c9969c4 100644 --- a/test/transform/ardtransform.jl +++ b/test/transform/ardtransform.jl @@ -42,5 +42,5 @@ @test_throws DimensionMismatch map(t, ColVecs(randn(rng, D + 1, 3))) @test repr(t) == "ARD Transform (dims: $D)" - test_ADs(x -> transform(SEKernel(), exp.(x)), randn(rng, 3)) + test_ADs(x -> SEKernel() ∘ ARDTransform(exp.(x)), randn(rng, 3)) end diff --git a/test/transform/chaintransform.jl b/test/transform/chaintransform.jl index dc41cc7bb..6e3d8a44f 100644 --- a/test/transform/chaintransform.jl +++ b/test/transform/chaintransform.jl @@ -23,7 +23,7 @@ # Verify printing works as expected. @test repr(tp ∘ tf) == "Chain of 2 transforms:\n\t - $(tf) |> $(tp)" test_ADs( - x -> transform(SEKernel(), ScaleTransform(exp(x[1])) ∘ ARDTransform(exp.(x[2:4]))), + x -> SEKernel() ∘ (ScaleTransform(exp(x[1])) ∘ ARDTransform(exp.(x[2:4]))), randn(rng, 4), ) end diff --git a/test/transform/functiontransform.jl b/test/transform/functiontransform.jl index 1dd36b5b5..7d9581ecc 100644 --- a/test/transform/functiontransform.jl +++ b/test/transform/functiontransform.jl @@ -28,5 +28,5 @@ @test repr(FunctionTransform(sin)) == "Function Transform: $(sin)" f(a, x) = sin.(a .* x) - test_ADs(x -> transform(SEKernel(), FunctionTransform(y -> f(x, y))), randn(rng, 3)) + test_ADs(x -> SEKernel() ∘ FunctionTransform(y -> f(x, y)), randn(rng, 3)) end diff --git a/test/transform/lineartransform.jl b/test/transform/lineartransform.jl index ea280d85d..5b22a4d2d 100644 --- a/test/transform/lineartransform.jl +++ b/test/transform/lineartransform.jl @@ -42,5 +42,5 @@ @test_throws DimensionMismatch map(t, ColVecs(randn(rng, Din + 1, Dout))) @test repr(t) == "Linear transform (size(A) = ($Dout, $Din))" - test_ADs(x -> transform(SEKernel(), LinearTransform(x)), randn(rng, 3, 3)) + test_ADs(x -> SEKernel() ∘ LinearTransform(x), randn(rng, 3, 3)) end diff --git a/test/transform/periodic_transform.jl b/test/transform/periodic_transform.jl index 1f7119ac3..324cba907 100644 --- a/test/transform/periodic_transform.jl +++ b/test/transform/periodic_transform.jl @@ -5,10 +5,10 @@ x = collect(range(0.0, 3.0 / f; length=1_000)) # Construct in the usual way. - k_eq_periodic = transform(PeriodicKernel(; r=[sqrt(0.25)]), f) + k_eq_periodic = PeriodicKernel(; r=[sqrt(0.25)]) ∘ ScaleTransform(f) # Construct using the peridic transform. - k_eq_transform = transform(SqExponentialKernel(), PeriodicTransform(f)) + k_eq_transform = SqExponentialKernel() ∘ PeriodicTransform(f) @test kernelmatrix(k_eq_periodic, x) ≈ kernelmatrix(k_eq_transform, x) # TODO - add interface_tests once #159 is merged. diff --git a/test/transform/scaletransform.jl b/test/transform/scaletransform.jl index d416e3eb6..c0445b8a4 100644 --- a/test/transform/scaletransform.jl +++ b/test/transform/scaletransform.jl @@ -19,5 +19,5 @@ @test t.s == [s2] @test isequal(ScaleTransform(s), ScaleTransform(s)) @test repr(t) == "Scale Transform (s = $(s2))" - test_ADs(x -> transform(SEKernel(), exp(x[1])), randn(rng, 1)) + test_ADs(x -> SEKernel() ∘ ScaleTransform(exp(x[1])), randn(rng, 1)) end diff --git a/test/transform/selecttransform.jl b/test/transform/selecttransform.jl index 966fecf7c..c9888443c 100644 --- a/test/transform/selecttransform.jl +++ b/test/transform/selecttransform.jl @@ -44,7 +44,7 @@ @test repr(t) == "Select Transform (dims: $(select2))" @test repr(ts) == "Select Transform (dims: $(select_symbols2))" - test_ADs(() -> transform(SEKernel(), SelectTransform([1, 2]))) + test_ADs(() -> SEKernel() ∘ SelectTransform([1, 2])) X = randn(rng, (4, 3)) A = AxisArray(X; row=[:a, :b, :c, :d], col=[:x, :y, :z]) @@ -53,10 +53,10 @@ Z = randn(rng, (2, 3)) C = AxisArray(Z; row=[:e, :f], col=[:x, :y, :z]) - tx_row = transform(SEKernel(), SelectTransform([1, 2, 4])) - ta_row = transform(SEKernel(), SelectTransform([:a, :b, :d])) - tx_col = transform(SEKernel(), SelectTransform([1, 3])) - ta_col = transform(SEKernel(), SelectTransform([:x, :z])) + tx_row = SEKernel() ∘ SelectTransform([1, 2, 4]) + ta_row = SEKernel() ∘ SelectTransform([:a, :b, :d]) + tx_col = SEKernel() ∘ SelectTransform([1, 3]) + ta_col = SEKernel() ∘ SelectTransform([:x, :z]) @test kernelmatrix(tx_row, X; obsdim=2) ≈ kernelmatrix(ta_row, A; obsdim=2) @test kernelmatrix(tx_col, X; obsdim=1) ≈ kernelmatrix(ta_col, A; obsdim=1) diff --git a/test/transform/transform.jl b/test/transform/transform.jl index 83d63ffe7..bd3ff567d 100644 --- a/test/transform/transform.jl +++ b/test/transform/transform.jl @@ -8,5 +8,5 @@ @test IdentityTransform()(x) == x @test map(IdentityTransform(), x) == x end - test_ADs(() -> transform(SEKernel(), IdentityTransform())) + test_ADs(() -> SEKernel() ∘ IdentityTransform()) end