Skip to content

Commit

Permalink
Figure out how to avoid bad gradients
Browse files Browse the repository at this point in the history
  • Loading branch information
willtebbutt committed Sep 27, 2024
1 parent 5bae510 commit 84f1d3f
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 9 deletions.
24 changes: 24 additions & 0 deletions src/gp/lti_sde.jl
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,8 @@ function stationary_distribution(k::SimpleKernel, ::ArrayStorage{T}) where {T<:R
return Gaussian(collect(x.m), collect(x.P))
end

safe_to_product(::Kernel) = false

# Matern-1/2

function to_sde(::Matern12Kernel, ::SArrayStorage{T}) where {T<:Real}
Expand All @@ -203,6 +205,8 @@ function stationary_distribution(::Matern12Kernel, ::SArrayStorage{T}) where {T<
)
end

safe_to_product(::Matern12Kernel) = true

# Matern - 3/2

function to_sde(::Matern32Kernel, ::SArrayStorage{T}) where {T<:Real}
Expand All @@ -220,6 +224,8 @@ function stationary_distribution(::Matern32Kernel, ::SArrayStorage{T}) where {T<
)
end

safe_to_product(::Matern32Kernel) = true

# Matern - 5/2

function to_sde(::Matern52Kernel, ::SArrayStorage{T}) where {T<:Real}
Expand All @@ -237,6 +243,8 @@ function stationary_distribution(::Matern52Kernel, ::SArrayStorage{T}) where {T<
return Gaussian(m, P)
end

safe_to_product(::Matern52Kernel) = true

# Cosine

function to_sde(::CosineKernel, ::SArrayStorage{T}) where {T}
Expand All @@ -252,6 +260,8 @@ function stationary_distribution(::CosineKernel, ::SArrayStorage{T}) where {T<:R
return Gaussian(m, P)
end

safe_to_product(::CosineKernel) = true

# ApproxPeriodicKernel

# The periodic kernel is approximated by a sum of cosine kernels with different frequencies.
Expand Down Expand Up @@ -309,6 +319,8 @@ function stationary_distribution(kernel::ApproxPeriodicKernel{N}, storage::Array
return Gaussian(m, P)
end

safe_to_product(::ApproxPeriodicKernel) = true

# Constant

function TemporalGPs.to_sde(::ConstantKernel, ::SArrayStorage{T}) where {T<:Real}
Expand All @@ -322,6 +334,9 @@ function TemporalGPs.stationary_distribution(k::ConstantKernel, ::SArrayStorage{
return TemporalGPs.Gaussian(SVector{1, T}(0), SMatrix{1, 1, T}(T(only(k.c))))
end

safe_to_product(::ConstantKernel) = true


# Scaled

function to_sde(k::ScaledKernel, storage::StorageType{T}) where {T<:Real}
Expand All @@ -334,6 +349,8 @@ function stationary_distribution(k::ScaledKernel, storage::StorageType)
return stationary_distribution(k.kernel, storage)
end

safe_to_product(k::ScaledKernel) = safe_to_product(k.kernel)

function lgssm_components(k::ScaledKernel, ts::AbstractVector, storage_type::StorageType)
As, as, Qs, emission_proj, x0 = lgssm_components(k.kernel, ts, storage_type)
σ = sqrt(convert(eltype(storage_type), only(k.σ²)))
Expand Down Expand Up @@ -361,6 +378,8 @@ function stationary_distribution(
return stationary_distribution(k.kernel, storage)
end

safe_to_product(::TransformedKernel{<:Kernel, <:ScaleTransform}) = false

function lgssm_components(
k::TransformedKernel{<:Kernel, <:ScaleTransform},
ts::AbstractVector,
Expand All @@ -377,7 +396,12 @@ apply_stretch(a, ts::RegularSpacing) = RegularSpacing(a * ts.t0, a * ts.Δt, ts.

# Product

safe_to_product(k::KernelProduct) = all(safe_to_product, k.kernels)

function lgssm_components(k::KernelProduct, ts::AbstractVector, storage::StorageType)

safe_to_product(k) || throw(ArgumentError("Not all kernels in k are safe to product."))

sde_kernels = to_sde.(k.kernels, Ref(storage))
F_kernels = getindex.(sde_kernels, 1)
F = foldl(_kron_add, F_kernels)
Expand Down
11 changes: 11 additions & 0 deletions test/front_matter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -54,5 +54,16 @@ ENV["TESTING"] = "TRUE"
# ENV["GROUP"] = "test gp"
const GROUP = get(ENV, "GROUP", "all")

# Some test-local type piracy. ConstantKernel doesn't have a default constructor, so
# Mooncake's testing functionality doesn't work with it properly. To resolve this, I just
# add a default-style constructor here.
@eval function KernelFunctions.ConstantKernel{P}(c::Vector{P}) where {P<:Real}
$(Expr(:new, :(ConstantKernel{P}), :c))
end

@eval function PeriodicKernel{P}(c::Vector{P}) where {P<:Real}
$(Expr(:new, :(PeriodicKernel{P}), :c))
end

include("test_util.jl")
include(joinpath("models", "model_test_utils.jl"))
16 changes: 7 additions & 9 deletions test/gp/lti_sde.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@ using Test

# Everything is tested once the LGSSM is constructed, so it is sufficient just to ensure
# that Zygote can handle construction.
function _construction_tester(f_naive::GP, storage::StorageType, σ², t::AbstractVector)
function _logpdf_tester(f_naive::GP, y, storage::StorageType, σ², t::AbstractVector)
f = to_sde(f_naive, storage)
fx = f(t, σ²...)
return build_lgssm(fx)
return logpdf(f(t, σ²...), y)
end

println("lti_sde:")
Expand Down Expand Up @@ -112,15 +111,14 @@ println("lti_sde:")

# Product kernels
(
name="prod-Matern12Kernel-Matern32Kernel",
val=1.5 * Matern12Kernel() ScaleTransform(0.1) * Matern32Kernel()
ScaleTransform(1.1),
name="prod-Matern52Kernel-Matern32Kernel",
val=(1.5 * Matern52Kernel() * Matern32Kernel()) ScaleTransform(0.01),
),
(
name="prod-Matern32Kernel-Matern52Kernel-ConstantKernel",
val=3.0 * Matern32Kernel() * Matern52Kernel() * ConstantKernel(),
),
# THIS IS KNOWN NOT TO WORK!
# This is known not to work at all (not a gradient problem).
# (
# name="prod-(Matern32Kernel + ConstantKernel) * Matern52Kernel",
# val=(Matern32Kernel() + ConstantKernel()) * Matern52Kernel(),
Expand Down Expand Up @@ -203,8 +201,8 @@ println("lti_sde:")
end

test_rule(
rng, _construction_tester, f_naive, storage.val, σ².val, t.val;
is_primitive=false, interface_only=true,
rng, _logpdf_tester, f_naive, y, storage.val, σ².val, t.val;
is_primitive=false,
)
end
end
Expand Down

0 comments on commit 84f1d3f

Please sign in to comment.