Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extension of #269: Use \circ and compose and deprecate transform #276

Merged
merged 20 commits into from
Apr 13, 2021
Merged
2 changes: 1 addition & 1 deletion src/basekernels/gabor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ end

(κ::GaborKernel)(x, y) = κ.kernel(x, y)

_lengthscale_transform(::Nothing) = IdentitytTransform()
_lengthscale_transform(::Nothing) = IdentityTransform()
theogf marked this conversation as resolved.
Show resolved Hide resolved
_lengthscale_transform(x::Real) = ScaleTransform(inv(x))
_lengthscale_transform(x::AbstractVector) = ARDTransform(map(inv, x))

Expand Down
20 changes: 14 additions & 6 deletions test/deprecations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,19 @@
p = rand()
v = rand(3)
M = rand(3, 3)
v1 = rand(3)
v2 = rand(3)
kernel = SqExponentialKernel()
@test (@test_deprecated transform(kernel, LinearTransform(M))) ==
kernel ∘ LinearTransform(M)
@test (@test_deprecated transform(kernel ∘ ScaleTransform(p), ARDTransform(v))) ==
kernel ∘ ARDTransform(v) ∘ ScaleTransform(p)
@test (@test_deprecated transform(kernel, p)) == kernel ∘ ScaleTransform(p)
@test (@test_deprecated transform(kernel, v)) == kernel ∘ ARDTransform(v)

k1 = @test_deprecated transform(kernel, LinearTransform(M))
@test k1(v1, v2) == (kernel ∘ LinearTransform(M))(v1, v2)

k2 = @test_deprecated transform(kernel ∘ ScaleTransform(p), ARDTransform(v))
@test k2(v1, v2) == (kernel ∘ ARDTransform(v) ∘ ScaleTransform(p))(v1, v2)

k3 = @test_deprecated transform(kernel, p)
@test k3(v1, v2) == (kernel ∘ ScaleTransform(p))(v1, v2)

k4 = @test_deprecated transform(kernel, v)
@test k4(v1, v2) == (kernel ∘ ARDTransform(v))(v1, v2)
end
18 changes: 9 additions & 9 deletions test/kernels/transformedkernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,17 @@

s = rand(rng)
v = rand(rng, 3)
P = rand(rng, 3, 2)
k = SqExponentialKernel()
kt = TransformedKernel(k, ScaleTransform(s))
ktard = TransformedKernel(k, ARDTransform(v))
@test kt(v1, v2) == (k ∘ ScaleTransform(s))(v1, v2)
@test kt(v1, v2) ≈ k(s * v1, s * v2) atol = 1e-5
@test ktard(v1, v2) == (k ∘ ARDTransform(v))(v1, v2)
@test ktard(v1, v2) == k(v .* v1, v .* v2)
@test (k ∘ (LinearTransform(P') ∘ ScaleTransform(s)))(v1, v2) ==
((k ∘ LinearTransform(P')) ∘ ScaleTransform(s))(v1, v2)
devmotion marked this conversation as resolved.
Show resolved Hide resolved

@test repr(kt) == repr(k) * "\n\t- " * repr(ScaleTransform(s))

TestUtils.test_interface(k, Float64)
Expand Down Expand Up @@ -42,15 +46,11 @@
@test g1[first(ps)] ≈ g3[first(ps)]
end

P = rand(3, 2)
c = Chain(Dense(3, 2))

test_params(k ∘ ScaleTransform(s), (k, [s]))
test_params(k ∘ ARDTransform(v), (k, v))
test_params(k ∘ LinearTransform(P), (k, P))
test_params(k ∘ (LinearTransform(P) ∘ ScaleTransform(s)), (k, [s], P))
test_params(k ∘ FunctionTransform(c), (k, c))

@test (k ∘ (LinearTransform(P') ∘ ScaleTransform(s)))(v1, v2) ==
((k ∘ LinearTransform(P')) ∘ ScaleTransform(s))(v1, v2)
test_params((k ∘ ScaleTransform(s)), (k, [s]))
test_params((k ∘ ARDTransform(v)), (k, v))
test_params((k ∘ LinearTransform(P), (k, P))
test_params((k ∘ (LinearTransform(P) ∘ ScaleTransform(s))), (k, [s], P))
test_params((k ∘ FunctionTransform(c)), (k, c))
end