use only() instead of first() (#403)
* use only() instead of first() for parameters stored as 1-element "vectors" for the benefit of Flux

* fix one test that should not have passed as written

* add missing scalar Sinus constructor
st-- authored Dec 20, 2021
1 parent 2d17212 commit 05fe340
Showing 18 changed files with 48 additions and 43 deletions.
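
Background for the diff below: kernel and transform parameters are stored as 1-element vectors so that Flux can treat them as trainable arrays. first() returns the leading element of a collection of any length, whereas only() returns the single element and throws unless the collection has exactly one. A minimal sketch in plain Julia (not part of the diff) of the difference:

# Why only() is preferred over first() for parameters kept as 1-element vectors.
c = [2.5]          # parameter stored as a 1-element vector for Flux's sake
first(c)           # 2.5 — works, but would also accept a longer vector
only(c)            # 2.5 — additionally asserts that length(c) == 1

bad = [2.5, 3.0]
first(bad)         # 2.5 — silently ignores the second element
# only(bad)        # throws an ArgumentError: collection must have exactly one element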
4 changes: 2 additions & 2 deletions src/basekernels/constant.jl
@@ -73,8 +73,8 @@ end

@functor ConstantKernel

-kappa(κ::ConstantKernel, x::Real) = first(κ.c) * one(x)
+kappa(κ::ConstantKernel, x::Real) = only(κ.c) * one(x)

metric(::ConstantKernel) = Delta()

-Base.show(io::IO, κ::ConstantKernel) = print(io, "Constant Kernel (c = ", first(κ.c), ")")
+Base.show(io::IO, κ::ConstantKernel) = print(io, "Constant Kernel (c = ", only(κ.c), ")")
4 changes: 2 additions & 2 deletions src/basekernels/exponential.jl
@@ -137,14 +137,14 @@ end

@functor GammaExponentialKernel

-kappa(κ::GammaExponentialKernel, d::Real) = exp(-d^first(κ.γ))
+kappa(κ::GammaExponentialKernel, d::Real) = exp(-d^only(κ.γ))

metric(k::GammaExponentialKernel) = k.metric

iskroncompatible(::GammaExponentialKernel) = true

function Base.show(io::IO, κ::GammaExponentialKernel)
    return print(
-        io, "Gamma Exponential Kernel (γ = ", first(κ.γ), ", metric = ", κ.metric, ")"
+        io, "Gamma Exponential Kernel (γ = ", only(κ.γ), ", metric = ", κ.metric, ")"
    )
end
6 changes: 3 additions & 3 deletions src/basekernels/fbm.jl
@@ -28,16 +28,16 @@ function (κ::FBMKernel)(x::AbstractVector{<:Real}, y::AbstractVector{<:Real})
    modX = sum(abs2, x)
    modY = sum(abs2, y)
    modXY = sqeuclidean(x, y)
-    h = first(κ.h)
+    h = only(κ.h)
    return (modX^h + modY^h - modXY^h) / 2
end

function (κ::FBMKernel)(x::Real, y::Real)
-    return (abs2(x)^first(κ.h) + abs2(y)^first(κ.h) - abs2(x - y)^first(κ.h)) / 2
+    return (abs2(x)^only(κ.h) + abs2(y)^only(κ.h) - abs2(x - y)^only(κ.h)) / 2
end

function Base.show(io::IO, κ::FBMKernel)
-    return print(io, "Fractional Brownian Motion Kernel (h = ", first(κ.h), ")")
+    return print(io, "Fractional Brownian Motion Kernel (h = ", only(κ.h), ")")
end

_fbm(modX, modY, modXY, h) = (modX^h + modY^h - modXY^h) / 2
4 changes: 2 additions & 2 deletions src/basekernels/matern.jl
@@ -34,7 +34,7 @@ MaternKernel(; nu::Real=1.5, ν::Real=nu, metric=Euclidean()) = MaternKernel(ν,
@functor MaternKernel

@inline function kappa(κ::MaternKernel, d::Real)
-    result = _matern(first(κ.ν), d)
+    result = _matern(only(κ.ν), d)
    return ifelse(iszero(d), one(result), result)
end

@@ -46,7 +46,7 @@ end
metric(k::MaternKernel) = k.metric

function Base.show(io::IO, κ::MaternKernel)
-    return print(io, "Matern Kernel (ν = ", first(κ.ν), ", metric = ", κ.metric, ")")
+    return print(io, "Matern Kernel (ν = ", only(κ.ν), ", metric = ", κ.metric, ")")
end

## Matern12Kernel = ExponentialKernel aliased in exponential.jl
10 changes: 5 additions & 5 deletions src/basekernels/polynomial.jl
@@ -26,11 +26,11 @@ LinearKernel(; c::Real=0.0) = LinearKernel(c)

@functor LinearKernel

-kappa(κ::LinearKernel, xᵀy::Real) = xᵀy + first(κ.c)
+kappa(κ::LinearKernel, xᵀy::Real) = xᵀy + only(κ.c)

metric(::LinearKernel) = DotProduct()

-Base.show(io::IO, κ::LinearKernel) = print(io, "Linear Kernel (c = ", first(κ.c), ")")
+Base.show(io::IO, κ::LinearKernel) = print(io, "Linear Kernel (c = ", only(κ.c), ")")

"""
    PolynomialKernel(; degree::Int=2, c::Real=0.0)
@@ -53,7 +53,7 @@ struct PolynomialKernel{Tc<:Real} <: SimpleKernel

    function PolynomialKernel{Tc}(degree::Int, c::Vector{Tc}) where {Tc}
        @check_args(PolynomialKernel, degree, degree >= one(degree), "degree ≥ 1")
-        @check_args(PolynomialKernel, c, first(c) >= zero(Tc), "c ≥ 0")
+        @check_args(PolynomialKernel, c, only(c) >= zero(Tc), "c ≥ 0")
        return new{Tc}(degree, c)
    end
end
@@ -68,10 +68,10 @@ function Functors.functor(::Type{<:PolynomialKernel}, x)
    return (c=x.c,), reconstruct_polynomialkernel
end

-kappa(κ::PolynomialKernel, xᵀy::Real) = (xᵀy + first(κ.c))^κ.degree
+kappa(κ::PolynomialKernel, xᵀy::Real) = (xᵀy + only(κ.c))^κ.degree

metric(::PolynomialKernel) = DotProduct()

function Base.show(io::IO, κ::PolynomialKernel)
-    return print(io, "Polynomial Kernel (c = ", first(κ.c), ", degree = ", κ.degree, ")")
+    return print(io, "Polynomial Kernel (c = ", only(κ.c), ", degree = ", κ.degree, ")")
end
16 changes: 8 additions & 8 deletions src/basekernels/rational.jl
@@ -32,13 +32,13 @@ end
@functor RationalKernel

function kappa(κ::RationalKernel, d::Real)
-    return (one(d) + d / first(κ.α))^(-first(κ.α))
+    return (one(d) + d / only(κ.α))^(-only(κ.α))
end

metric(k::RationalKernel) = k.metric

function Base.show(io::IO, κ::RationalKernel)
-    return print(io, "Rational Kernel (α = ", first(κ.α), ", metric = ", κ.metric, ")")
+    return print(io, "Rational Kernel (α = ", only(κ.α), ", metric = ", κ.metric, ")")
end

"""
@@ -72,18 +72,18 @@ end
@functor RationalQuadraticKernel

function kappa(κ::RationalQuadraticKernel, d::Real)
-    return (one(d) + d^2 / (2 * first(κ.α)))^(-first(κ.α))
+    return (one(d) + d^2 / (2 * only(κ.α)))^(-only(κ.α))
end
function kappa(κ::RationalQuadraticKernel{<:Real,<:Euclidean}, d²::Real)
-    return (one(d²) + d² / (2 * first(κ.α)))^(-first(κ.α))
+    return (one(d²) + d² / (2 * only(κ.α)))^(-only(κ.α))
end

metric(k::RationalQuadraticKernel) = k.metric
metric(::RationalQuadraticKernel{<:Real,<:Euclidean}) = SqEuclidean()

function Base.show(io::IO, κ::RationalQuadraticKernel)
    return print(
-        io, "Rational Quadratic Kernel (α = ", first(κ.α), ", metric = ", κ.metric, ")"
+        io, "Rational Quadratic Kernel (α = ", only(κ.α), ", metric = ", κ.metric, ")"
    )
end

@@ -122,7 +122,7 @@ end
@functor GammaRationalKernel

function kappa(κ::GammaRationalKernel, d::Real)
-    return (one(d) + d^first(κ.γ) / first(κ.α))^(-first(κ.α))
+    return (one(d) + d^only(κ.γ) / only(κ.α))^(-only(κ.α))
end

metric(k::GammaRationalKernel) = k.metric
@@ -131,9 +131,9 @@ function Base.show(io::IO, κ::GammaRationalKernel)
    return print(
        io,
        "Gamma Rational Kernel (α = ",
-        first(κ.α),
+        only(κ.α),
        ", γ = ",
-        first(κ.γ),
+        only(κ.γ),
        ", metric = ",
        κ.metric,
        ")",
4 changes: 3 additions & 1 deletion src/distances/sinus.jl
@@ -2,10 +2,12 @@ struct Sinus{T} <: Distances.UnionSemiMetric
    r::Vector{T}
end

+Sinus(r::Real) = Sinus([r])
+
Distances.parameters(d::Sinus) = d.r
@inline Distances.eval_op(::Sinus, a::Real, b::Real, p::Real) = abs2(sinpi(a - b) / p)
@inline (dist::Sinus)(a::AbstractArray, b::AbstractArray) = Distances._evaluate(dist, a, b)
-@inline (dist::Sinus)(a::Number, b::Number) = abs2(sinpi(a - b) / first(dist.r))
+@inline (dist::Sinus)(a::Number, b::Number) = abs2(sinpi(a - b) / only(dist.r))

Distances.result_type(::Sinus{T}, Ta::Type, Tb::Type) where {T} = promote_type(T, Ta, Tb)
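
For context (not part of the diff): a short usage sketch of the scalar constructor added above, assuming KernelFunctions is loaded and the wrapped value is the period of the Sinus semimetric.

# Sinus(0.5) wraps the scalar period into the vector field, equivalent to Sinus([0.5]).
d = KernelFunctions.Sinus(0.5)
d(0.25, 0.0)          # abs2(sinpi(0.25 - 0.0) / 0.5) ≈ 2.0
d([0.25], [0.0])      # one-dimensional vector inputs give the same value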
4 changes: 2 additions & 2 deletions src/kernels/scaledkernel.jl
@@ -23,7 +23,7 @@ end

@functor ScaledKernel

-(k::ScaledKernel)(x, y) = first(k.σ²) * k.kernel(x, y)
+(k::ScaledKernel)(x, y) = only(k.σ²) * k.kernel(x, y)

function kernelmatrix(κ::ScaledKernel, x::AbstractVector, y::AbstractVector)
    return κ.σ² .* kernelmatrix(κ.kernel, x, y)
@@ -75,5 +75,5 @@ Base.show(io::IO, κ::ScaledKernel) = printshifted(io, κ, 0)

function printshifted(io::IO, κ::ScaledKernel, shift::Int)
    printshifted(io, κ.kernel, shift)
-    return print(io, "\n" * ("\t"^(shift + 1)) * "- σ² = $(first(κ.σ²))")
+    return print(io, "\n" * ("\t"^(shift + 1)) * "- σ² = $(only(κ.σ²))")
end
4 changes: 2 additions & 2 deletions src/kernels/transformedkernel.jl
@@ -28,10 +28,10 @@ function (k::TransformedKernel{<:SimpleKernel,<:ScaleTransform})(
end

function _scale(t::ScaleTransform, metric::Euclidean, x, y)
-    return first(t.s) * evaluate(metric, x, y)
+    return only(t.s) * evaluate(metric, x, y)
end
function _scale(t::ScaleTransform, metric::Union{SqEuclidean,DotProduct}, x, y)
-    return first(t.s)^2 * evaluate(metric, x, y)
+    return only(t.s)^2 * evaluate(metric, x, y)
end
_scale(t::ScaleTransform, metric, x, y) = evaluate(metric, t(x), t(y))
2 changes: 1 addition & 1 deletion src/transform/ardtransform.jl
@@ -32,7 +32,7 @@ end

dim(t::ARDTransform) = length(t.v)

-(t::ARDTransform)(x::Real) = first(t.v) * x
+(t::ARDTransform)(x::Real) = only(t.v) * x
(t::ARDTransform)(x) = t.v .* x

_map(t::ARDTransform, x::AbstractVector{<:Real}) = t.v' .* x
8 changes: 4 additions & 4 deletions src/transform/periodic_transform.jl
@@ -25,16 +25,16 @@ PeriodicTransform(f::Real) = PeriodicTransform([f])

dim(t::PeriodicTransform) = 2

-(t::PeriodicTransform)(x::Real) = [sinpi(2 * first(t.f) * x), cospi(2 * first(t.f) * x)]
+(t::PeriodicTransform)(x::Real) = [sinpi(2 * only(t.f) * x), cospi(2 * only(t.f) * x)]

function _map(t::PeriodicTransform, x::AbstractVector{<:Real})
-    return RowVecs(hcat(sinpi.((2 * first(t.f)) .* x), cospi.((2 * first(t.f)) .* x)))
+    return RowVecs(hcat(sinpi.((2 * only(t.f)) .* x), cospi.((2 * only(t.f)) .* x)))
end

function Base.isequal(t1::PeriodicTransform, t2::PeriodicTransform)
-    return isequal(first(t1.f), first(t2.f))
+    return isequal(only(t1.f), only(t2.f))
end

function Base.show(io::IO, t::PeriodicTransform)
-    return print(io, "Periodic Transform with frequency $(first(t.f))")
+    return print(io, "Periodic Transform with frequency $(only(t.f))")
end
12 changes: 6 additions & 6 deletions src/transform/scaletransform.jl
@@ -24,12 +24,12 @@ end

set!(t::ScaleTransform, ρ::Real) = t.s .= [ρ]

-(t::ScaleTransform)(x) = first(t.s) * x
+(t::ScaleTransform)(x) = only(t.s) * x

-_map(t::ScaleTransform, x::AbstractVector{<:Real}) = first(t.s) .* x
-_map(t::ScaleTransform, x::ColVecs) = ColVecs(first(t.s) .* x.X)
-_map(t::ScaleTransform, x::RowVecs) = RowVecs(first(t.s) .* x.X)
+_map(t::ScaleTransform, x::AbstractVector{<:Real}) = only(t.s) .* x
+_map(t::ScaleTransform, x::ColVecs) = ColVecs(only(t.s) .* x.X)
+_map(t::ScaleTransform, x::RowVecs) = RowVecs(only(t.s) .* x.X)

-Base.isequal(t::ScaleTransform, t2::ScaleTransform) = isequal(first(t.s), first(t2.s))
+Base.isequal(t::ScaleTransform, t2::ScaleTransform) = isequal(only(t.s), only(t2.s))

-Base.show(io::IO, t::ScaleTransform) = print(io, "Scale Transform (s = ", first(t.s), ")")
+Base.show(io::IO, t::ScaleTransform) = print(io, "Scale Transform (s = ", only(t.s), ")")
1 change: 1 addition & 0 deletions test/Project.toml
@@ -1,5 +1,6 @@
[deps]
AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
+Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
Distances = "b4f34e82-e78d-54a5-968a-f98e89d6e8f7"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
2 changes: 1 addition & 1 deletion test/basekernels/constant.jl
@@ -36,6 +36,6 @@

# Standardised tests.
TestUtils.test_interface(k, Float64)
-test_ADs(c -> ConstantKernel(; c=first(c)), [c])
+test_ADs(c -> ConstantKernel(; c=only(c)), [c])
end
end
2 changes: 1 addition & 1 deletion test/basekernels/exponential.jl
@@ -56,7 +56,7 @@
@test metric(k2) isa WeightedEuclidean
@test k2(v1, v2) ≈ k(v1, v2)

-test_ADs(γ -> GammaExponentialKernel(; gamma=first(γ)), [1 + 0.5 * rand()])
+test_ADs(γ -> GammaExponentialKernel(; gamma=only(γ)), [1 + 0.5 * rand()])
test_params(k, ([γ],))
TestUtils.test_interface(GammaExponentialKernel(; γ=1.36))
3 changes: 2 additions & 1 deletion test/distances/sinus.jl
@@ -5,5 +5,6 @@
d = KernelFunctions.Sinus(p)
@test Distances.parameters(d) == p
@test evaluate(d, A, B) == sum(abs2.(sinpi.(A - B) ./ p))
-@test d(3.0, 2.0) == abs2(sinpi(3.0 - 2.0) / first(p))
+d1 = KernelFunctions.Sinus(first(p))
+@test d1(3.0, 2.0) == abs2(sinpi(3.0 - 2.0) / first(p))
end
1 change: 1 addition & 0 deletions test/runtests.jl
@@ -14,6 +14,7 @@ using Zygote: Zygote
using ForwardDiff: ForwardDiff
using ReverseDiff: ReverseDiff
using FiniteDifferences: FiniteDifferences
+using Compat: only

using KernelFunctions: SimpleKernel, metric, kappa, ColVecs, RowVecs, TestUtils
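
A side note on the new import (my reading, not stated in the commit): Base.only was added in Julia 1.4, so pulling only from Compat presumably keeps the test suite running on older Julia versions. A sketch:

# Compat supplies `only` where Base does not yet define it;
# on recent Julia the imported name resolves to the Base function.
using Compat: only
@assert only([42]) == 42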
4 changes: 2 additions & 2 deletions test/test_utils.jl
@@ -45,7 +45,7 @@ const FDM = FiniteDifferences.central_fdm(5, 1)
gradient(f, s::Symbol, args) = gradient(f, Val(s), args)

function gradient(f, ::Val{:Zygote}, args)
-    g = first(Zygote.gradient(f, args))
+    g = only(Zygote.gradient(f, args))
    if isnothing(g)
        if args isa AbstractArray{<:Real}
            return zeros(size(args)) # To respect the same output as other ADs
@@ -66,7 +66,7 @@ function gradient(f, ::Val{:ReverseDiff}, args)
end

function gradient(f, ::Val{:FiniteDiff}, args)
-    return first(FiniteDifferences.grad(FDM, f, args))
+    return only(FiniteDifferences.grad(FDM, f, args))
end

function compare_gradient(f, ::Val{:FiniteDiff}, args)
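
Why only() fits here (a sketch, based on the documented behaviour of the AD packages): Zygote.gradient and FiniteDifferences.grad both return a tuple with one entry per differentiated argument, so with a single argument only() unpacks the 1-tuple and errors loudly if the tuple ever has a different length.

using Zygote: Zygote
using FiniteDifferences: FiniteDifferences

f(x) = sum(abs2, x)
only(Zygote.gradient(f, [1.0, 2.0]))                  # [2.0, 4.0]
fdm = FiniteDifferences.central_fdm(5, 1)
only(FiniteDifferences.grad(fdm, f, [1.0, 2.0]))      # ≈ [2.0, 4.0]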
