Skip to content

Commit

Permalink
fix MaternKernel AD, but remove differentiation wrt \nu (#425)
Browse files Browse the repository at this point in the history
* reactivate test_ADs of MaternKernel, but do not test differentiation through \nu argument
* make clear in docstring that MaternKernel does not support derivative w.r.t. order \nu
* work-around for Zygote AD that will return nothing/NaN gradients w.r.t. \nu but fixes the gradient w.r.t. x
  • Loading branch information
st-- authored Apr 13, 2022
1 parent 873aa8d commit 8e805ef
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 8 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "KernelFunctions"
uuid = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
version = "0.10.36"
version = "0.10.37"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
11 changes: 9 additions & 2 deletions src/basekernels/matern.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ By default, ``d`` is the Euclidean metric ``d(x, x') = \\|x - x'\\|_2``.
A Gaussian process with a Matérn kernel is ``\\lceil \\nu \\rceil - 1``-times
differentiable in the mean-square sense.
!!! note
Differentiation with respect to the order ν is not currently supported.
See also: [`Matern12Kernel`](@ref), [`Matern32Kernel`](@ref), [`Matern52Kernel`](@ref)
"""
struct MaternKernel{Tν<:Real,M} <: SimpleKernel
Expand All @@ -33,8 +37,11 @@ MaternKernel(; nu::Real=1.5, ν::Real=nu, metric=Euclidean()) = MaternKernel(ν,

@functor MaternKernel

@inline function kappa::MaternKernel, d::Real)
result = _matern(only.ν), d)
@inline _get_ν(k::MaternKernel) = only(k.ν)
ChainRulesCore.@non_differentiable _get_ν(k) # work-around; should be "NotImplemented" rather than NoTangent

@inline function kappa(k::MaternKernel, d::Real)
result = _matern(_get_ν(k), d)
return ifelse(iszero(d), one(result), result)
end

Expand Down
2 changes: 1 addition & 1 deletion src/matrix/kernelkroneckermat.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ where `D` is given by `dims`.
!!! warning
Require `Kronecker.jl` and for `iskroncompatible(κ)` to return `true`.
Requires `Kronecker.jl` and for `iskroncompatible(κ)` to return `true`.
"""
function kernelkronmat::Kernel, X::AbstractVector{<:Real}, dims::Int)
checkkroncompatible(κ)
Expand Down
7 changes: 3 additions & 4 deletions test/basekernels/matern.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,25 +4,24 @@
v1 = rand(rng, 3)
v2 = rand(rng, 3)
@testset "MaternKernel" begin
ν = 2.0
ν = 2.1
k = MaternKernel(; ν=ν)
matern(x, ν) = 2^(1 - ν) / gamma(ν) * (sqrt(2ν) * x)^ν * besselk(ν, sqrt(2ν) * x)
@test MaternKernel(; nu=ν).ν == [ν]
@test kappa(k, x) matern(x, ν)
@test kappa(k, 0.0) == 1.0
@test kappa(MaternKernel(; ν=ν), x) == kappa(k, x)
@test metric(MaternKernel()) == Euclidean()
@test metric(MaternKernel(; ν=2.0)) == Euclidean()
@test repr(k) == "Matern Kernel (ν = $(ν), metric = Euclidean(0.0))"
# test_ADs(x->MaternKernel(nu=first(x)),[ν])
@test_broken "All fails (because of logabsgamma for ForwardDiff and ReverseDiff and because of nu for Zygote)"

k2 = MaternKernel(; ν=ν, metric=WeightedEuclidean(ones(3)))
@test metric(k2) isa WeightedEuclidean
@test k2(v1, v2) k(v1, v2)

# Standardised tests.
TestUtils.test_interface(k, Float64)
test_ADs(() -> MaternKernel(; nu=ν))

test_params(k, ([ν],))
end
@testset "Matern32Kernel" begin
Expand Down

2 comments on commit 8e805ef

@st--
Copy link
Member Author

@st-- st-- commented on 8e805ef Apr 13, 2022

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/58440

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.10.37 -m "<description of version>" 8e805ef25d745ab391e4ad424e690507e02cd042
git push origin v0.10.37

Please sign in to comment.