Skip to content

Commit

Permalink
Rename TensorProduct and implement TensorCore.tensor (#232)
Browse files Browse the repository at this point in the history
  • Loading branch information
devmotion authored Jan 20, 2021
1 parent 0ff8761 commit 5fd5157
Show file tree
Hide file tree
Showing 17 changed files with 251 additions and 243 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "KernelFunctions"
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.8.18"
version = "0.8.19"

[deps]
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Expand All @@ -13,6 +13,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
TensorCore = "62fd8b95-f654-4bbd-a8a5-9c27f68ccd50"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

Expand All @@ -24,5 +25,6 @@ Requires = "1.0.1"
SpecialFunctions = "0.8, 0.9, 0.10, 1"
StatsBase = "0.32, 0.33"
StatsFuns = "0.8, 0.9"
TensorCore = "0.1"
ZygoteRules = "0.2"
julia = "1.3"
2 changes: 1 addition & 1 deletion docs/src/kernels.md
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ transform(::Kernel, ::AbstractVector)
ScaledKernel
KernelSum
KernelProduct
TensorProduct
KernelTensorProduct
```

## Multi-output Kernels
Expand Down
10 changes: 7 additions & 3 deletions src/KernelFunctions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,8 @@ export LinearKernel, PolynomialKernel
export RationalQuadraticKernel, GammaRationalQuadraticKernel
export GaborKernel, PiecewisePolynomialKernel
export PeriodicKernel, NeuralNetworkKernel
export KernelSum, KernelProduct
export KernelSum, KernelProduct, KernelTensorProduct
export TransformedKernel, ScaledKernel
export TensorProduct

export Transform,
SelectTransform,
Expand All @@ -52,6 +51,9 @@ export ColVecs, RowVecs
export MOInput
export IndependentMOKernel, LatentFactorMOKernel

# Reexports
export tensor,

using Compat
using Requires
using Distances, LinearAlgebra
Expand All @@ -61,6 +63,7 @@ using ZygoteRules: @adjoint, pullback
using StatsFuns: logtwo
using InteractiveUtils: subtypes
using StatsBase
using TensorCore

abstract type Kernel end
abstract type SimpleKernel <: Kernel end
Expand Down Expand Up @@ -100,7 +103,8 @@ include(joinpath("kernels", "scaledkernel.jl"))
include(joinpath("matrix", "kernelmatrix.jl"))
include(joinpath("kernels", "kernelsum.jl"))
include(joinpath("kernels", "kernelproduct.jl"))
include(joinpath("kernels", "tensorproduct.jl"))
include(joinpath("kernels", "kerneltensorproduct.jl"))
include(joinpath("kernels", "overloads.jl"))
include(joinpath("approximations", "nystrom.jl"))
include("generic.jl")

Expand Down
2 changes: 1 addition & 1 deletion src/basekernels/sm.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ function spectral_mixture_product_kernel(
if !(size(αs) == size(γs) == size(ωs))
throw(DimensionMismatch("The dimensions of αs, γs, ans ωs do not match"))
end
return TensorProduct(
return KernelTensorProduct(
spectral_mixture_kernel(h, α, reshape(γ, 1, :), reshape(ω, 1, :)) for
(α, γ, ω) in zip(eachrow(αs), eachrow(γs), eachrow(ωs))
)
Expand Down
5 changes: 5 additions & 0 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,8 @@
@deprecate PiecewisePolynomialKernel{V}(A::AbstractMatrix{<:Real}) where {V} transform(
PiecewisePolynomialKernel{V}(size(A, 1)), LinearTransform(cholesky(A).U)
)

@deprecate TensorProduct(kernels) KernelTensorProduct(kernels)
@deprecate TensorProduct(kernel::Kernel, kernels::Kernel...) KernelTensorProduct(
kernel, kernels...
)
25 changes: 0 additions & 25 deletions src/kernels/kernelproduct.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,31 +41,6 @@ end

@functor KernelProduct

Base.:*(k1::Kernel, k2::Kernel) = KernelProduct(k1, k2)

function Base.:*(
k1::KernelProduct{<:AbstractVector{<:Kernel}},
k2::KernelProduct{<:AbstractVector{<:Kernel}},
)
return KernelProduct(vcat(k1.kernels, k2.kernels))
end

function Base.:*(k1::KernelProduct, k2::KernelProduct)
return KernelProduct(k1.kernels..., k2.kernels...)
end

function Base.:*(k::Kernel, ks::KernelProduct{<:AbstractVector{<:Kernel}})
return KernelProduct(vcat(k, ks.kernels))
end

Base.:*(k::Kernel, kp::KernelProduct) = KernelProduct(k, kp.kernels...)

function Base.:*(ks::KernelProduct{<:AbstractVector{<:Kernel}}, k::Kernel)
return KernelProduct(vcat(ks.kernels, k))
end

Base.:*(kp::KernelProduct, k::Kernel) = KernelProduct(kp.kernels..., k)

Base.length(k::KernelProduct) = length(k.kernels)

::KernelProduct)(x, y) = prod(k(x, y) for k in κ.kernels)
Expand Down
22 changes: 0 additions & 22 deletions src/kernels/kernelsum.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,28 +41,6 @@ end

@functor KernelSum

Base.:+(k1::Kernel, k2::Kernel) = KernelSum(k1, k2)

function Base.:+(
k1::KernelSum{<:AbstractVector{<:Kernel}}, k2::KernelSum{<:AbstractVector{<:Kernel}}
)
return KernelSum(vcat(k1.kernels, k2.kernels))
end

Base.:+(k1::KernelSum, k2::KernelSum) = KernelSum(k1.kernels..., k2.kernels...)

function Base.:+(k::Kernel, ks::KernelSum{<:AbstractVector{<:Kernel}})
return KernelSum(vcat(k, ks.kernels))
end

Base.:+(k::Kernel, ks::KernelSum) = KernelSum(k, ks.kernels...)

function Base.:+(ks::KernelSum{<:AbstractVector{<:Kernel}}, k::Kernel)
return KernelSum(vcat(ks.kernels, k))
end

Base.:+(ks::KernelSum, k::Kernel) = KernelSum(ks.kernels..., k)

Base.length(k::KernelSum) = length(k.kernels)

::KernelSum)(x, y) = sum(k(x, y) for k in κ.kernels)
Expand Down
144 changes: 144 additions & 0 deletions src/kernels/kerneltensorproduct.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
"""
KernelTensorProduct
Tensor product of kernels.
# Definition
For inputs ``x = (x_1, \\ldots, x_n)`` and ``x' = (x'_1, \\ldots, x'_n)``, the tensor
product of kernels ``k_1, \\ldots, k_n`` is defined as
```math
k(x, x'; k_1, \\ldots, k_n) = \\Big(\\bigotimes_{i=1}^n k_i\\Big)(x, x') = \\prod_{i=1}^n k_i(x_i, x'_i).
```
# Construction
The simplest way to specify a `KernelTensorProduct` is to use the overloaded `tensor`
operator or its alias `⊗` (can be typed by `\\otimes<tab>`).
```jldoctest tensorproduct
julia> k1 = SqExponentialKernel(); k2 = LinearKernel(); X = rand(5, 2);
julia> kernelmatrix(k1 ⊗ k2, RowVecs(X)) == kernelmatrix(k1, X[:, 1]) .* kernelmatrix(k2, X[:, 2])
true
```
You can also specify a `KernelTensorProduct` by providing kernels as individual arguments
or as an iterable data structure such as a `Tuple` or a `Vector`. Using a tuple or
individual arguments guarantees that `KernelTensorProduct` is concretely typed but might
lead to large compilation times if the number of kernels is large.
```jldoctest tensorproduct
julia> KernelTensorProduct(k1, k2) == k1 ⊗ k2
true
julia> KernelTensorProduct((k1, k2)) == k1 ⊗ k2
true
julia> KernelTensorProduct([k1, k2]) == k1 ⊗ k2
true
```
"""
struct KernelTensorProduct{K} <: Kernel
kernels::K
end

function KernelTensorProduct(kernel::Kernel, kernels::Kernel...)
return KernelTensorProduct((kernel, kernels...))
end

@functor KernelTensorProduct

Base.length(kernel::KernelTensorProduct) = length(kernel.kernels)

function (kernel::KernelTensorProduct)(x, y)
if !(length(x) == length(y) == length(kernel))
throw(DimensionMismatch("number of kernels and number of features
are not consistent"))
end
return prod(k(xi, yi) for (k, xi, yi) in zip(kernel.kernels, x, y))
end

function validate_domain(k::KernelTensorProduct, x::AbstractVector)
return dim(x) == length(k) ||
error("number of kernels and groups of features are not consistent")
end

# Utility for slicing up inputs.
slices(x::AbstractVector{<:Real}) = (x,)
slices(x::ColVecs) = eachrow(x.X)
slices(x::RowVecs) = eachcol(x.X)

function kernelmatrix!(K::AbstractMatrix, k::KernelTensorProduct, x::AbstractVector)
validate_inplace_dims(K, x)
validate_domain(k, x)

kernels_and_inputs = zip(k.kernels, slices(x))
kernelmatrix!(K, first(kernels_and_inputs)...)
for (k, xi) in Iterators.drop(kernels_and_inputs, 1)
K .*= kernelmatrix(k, xi)
end

return K
end

function kernelmatrix!(
K::AbstractMatrix, k::KernelTensorProduct, x::AbstractVector, y::AbstractVector
)
validate_inplace_dims(K, x, y)
validate_domain(k, x)

kernels_and_inputs = zip(k.kernels, slices(x), slices(y))
kernelmatrix!(K, first(kernels_and_inputs)...)
for (k, xi, yi) in Iterators.drop(kernels_and_inputs, 1)
K .*= kernelmatrix(k, xi, yi)
end

return K
end

function kerneldiagmatrix!(K::AbstractVector, k::KernelTensorProduct, x::AbstractVector)
validate_inplace_dims(K, x)
validate_domain(k, x)

kernels_and_inputs = zip(k.kernels, slices(x))
kerneldiagmatrix!(K, first(kernels_and_inputs)...)
for (k, xi) in Iterators.drop(kernels_and_inputs, 1)
K .*= kerneldiagmatrix(k, xi)
end

return K
end

function kernelmatrix(k::KernelTensorProduct, x::AbstractVector)
validate_domain(k, x)
return mapreduce(kernelmatrix, hadamard, k.kernels, slices(x))
end

function kernelmatrix(k::KernelTensorProduct, x::AbstractVector, y::AbstractVector)
validate_domain(k, x)
return mapreduce(kernelmatrix, hadamard, k.kernels, slices(x), slices(y))
end

function kerneldiagmatrix(k::KernelTensorProduct, x::AbstractVector)
validate_domain(k, x)
return mapreduce(kerneldiagmatrix, hadamard, k.kernels, slices(x))
end

Base.show(io::IO, kernel::KernelTensorProduct) = printshifted(io, kernel, 0)

function Base.:(==)(x::KernelTensorProduct, y::KernelTensorProduct)
return (
length(x.kernels) == length(y.kernels) &&
all(kx == ky for (kx, ky) in zip(x.kernels, y.kernels))
)
end

function printshifted(io::IO, kernel::KernelTensorProduct, shift::Int)
print(io, "Tensor product of ", length(kernel), " kernels:")
for k in kernel.kernels
print(io, "\n")
for _ in 1:(shift + 1)
print(io, "\t")
end
printshifted(io, k, shift + 2)
end
end
22 changes: 22 additions & 0 deletions src/kernels/overloads.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
for (M, op, T) in (
(:Base, :+, :KernelSum),
(:Base, :*, :KernelProduct),
(:TensorCore, :tensor, :KernelTensorProduct),
)
@eval begin
$M.$op(k1::Kernel, k2::Kernel) = $T(k1, k2)

$M.$op(k1::$T, k2::$T) = $T(k1.kernels..., k2.kernels...)
function $M.$op(
k1::$T{<:AbstractVector{<:Kernel}}, k2::$T{<:AbstractVector{<:Kernel}}
)
return $T(vcat(k1.kernels, k2.kernels))
end

$M.$op(k::Kernel, ks::$T) = $T(k, ks.kernels...)
$M.$op(k::Kernel, ks::$T{<:AbstractVector{<:Kernel}}) = $T(vcat(k, ks.kernels))

$M.$op(ks::$T, k::Kernel) = $T(ks.kernels..., k)
$M.$op(ks::$T{<:AbstractVector{<:Kernel}}, k::Kernel) = $T(vcat(ks.kernels, k))
end
end
Loading

2 comments on commit 5fd5157

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/28365

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.8.19 -m "<description of version>" 5fd5157c54130a20b89355151b69d5da7f2ae3d6
git push origin v0.8.19

Please sign in to comment.