Example notebook: deep kernel learning #322

Closed · wants to merge 6 commits
12 changes: 12 additions & 0 deletions examples/deep-kernel-learning/Manifest.toml
@@ -9,6 +9,12 @@
git-tree-sha1 = "6f1d9bc1c08f9f4a8fa92e3ea3cb50153a1b40d4"
uuid = "621f4979-c628-5d54-868e-fcf4e3e8185c"
version = "1.1.0"

[[deps.AbstractGPs]]
deps = ["ChainRulesCore", "Distributions", "FillArrays", "KernelFunctions", "LinearAlgebra", "Random", "RecipesBase", "Reexport", "Statistics", "StatsBase"]
git-tree-sha1 = "d8b6584ff1d523dd1304671f2c8a557dad26e214"
uuid = "99985d1d-32ba-4be9-9821-2ec096f28918"
version = "0.3.6"

[[deps.AbstractTrees]]
git-tree-sha1 = "03e0550477d86222521d254b741d470ba17ea0b5"
uuid = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
@@ -796,6 +802,12 @@
uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7"
deps = ["Printf"]
uuid = "9abbd945-dff8-562f-b5e8-e1ebf5ef1b79"

[[deps.ProgressMeter]]
deps = ["Distributed", "Printf"]
git-tree-sha1 = "afadeba63d90ff223a6a48d2009434ecee2ec9e8"
uuid = "92933f4c-e287-5a05-a399-4b506db050ca"
version = "1.7.1"

[[deps.Qt5Base_jll]]
deps = ["Artifacts", "CompilerSupportLibraries_jll", "Fontconfig_jll", "Glib_jll", "JLLWrappers", "Libdl", "Libglvnd_jll", "OpenSSL_jll", "Pkg", "Xorg_libXext_jll", "Xorg_libxcb_jll", "Xorg_xcb_util_image_jll", "Xorg_xcb_util_keysyms_jll", "Xorg_xcb_util_renderutil_jll", "Xorg_xcb_util_wm_jll", "Zlib_jll", "xkbcommon_jll"]
git-tree-sha1 = "ad368663a5e20dbb8d6dc2fddeefe4dae0781ae8"
4 changes: 4 additions & 0 deletions examples/deep-kernel-learning/Project.toml
@@ -1,19 +1,23 @@
[deps]
AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
AbstractGPs = "0.3"
Distributions = "0.25"
Flux = "0.12"
KernelFunctions = "0.10"
Literate = "2"
MLDataUtils = "0.5"
Plots = "1"
ProgressMeter = "1"
Zygote = "0.6"
julia = "1.3"
125 changes: 67 additions & 58 deletions examples/deep-kernel-learning/script.jl
@@ -1,68 +1,77 @@
# # Deep Kernel Learning
#
# !!! warning
# This example is under construction

# Setup

# # Deep Kernel Learning with Flux
# ## Package loading
# We use a couple of packages for plotting and for optimizing
# the hyperparameters
using KernelFunctions
using MLDataUtils
using Zygote
using Flux
using Distributions, LinearAlgebra
using Plots
using ProgressMeter
using AbstractGPs
Member:
This also introduces a circular dependency and hence leads to the same problems as mentioned in #316 (comment). It might be better to move it to AbstractGPs or the JuliaGaussianProcesses webpage. We can always link to these examples from the documentation, so it would still be possible to make it discoverable from the KernelFunctions docs.

Member Author:
Now that each example has its own project environment & pinning, shouldn't this no longer be a concern? The example depends on KernelFunctions and AbstractGPs, and AbstractGPs depends on KernelFunctions, but neither KernelFunctions nor AbstractGPs depend on the example.

Member:
> AbstractGPs depends on KernelFunctions

This is exactly the problem and the circular dependency I was referring to. If we make breaking changes in KernelFunctions, the example will break: there is no version of AbstractGPs that is compatible with the new KernelFunctions version yet. So either we switch the example to an old version of KernelFunctions, which creates a disconnect between the version of KernelFunctions discussed in the documentation and the one used in the example until AbstractGPs is updated and the example is fixed, or we remove the example for a while. Both alternatives seem a bit annoying.

This problem is the main reason why I think it is unfortunate to use downstream packages in the documentation of KernelFunctions; these examples should rather go into AbstractGPs, where we would not have to deal with these problems.

Member Author:
A breaking change doesn't imply the example must break, just that it's not guaranteed to not break... In any case, I'm fine with moving this notebook over to AbstractGPs. Just wanted to save the current state out of #234.

Member:
Often the example itself won't break, but that is not the relevant point: AbstractGPs is simply not compatible with KernelFunctions before a breaking release of KernelFunctions, when it is made, and possibly for some time afterwards. A new release of AbstractGPs with updated compatibility bounds (and possibly some fixes) is needed, even if the example itself does not require any changes.
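
For reference, the pinning discussed in this thread works through the example's own environment. A minimal sketch of how a reader would reproduce it, assuming the commands are run from the repository root (the commands themselves are illustrative and not part of this PR's committed files):

```julia
using Pkg

# Activate the example's dedicated environment and install exactly the
# versions recorded in its Manifest.toml (e.g. AbstractGPs 0.3.6 and
# ProgressMeter 1.7.1 as added above), within the compat bounds in Project.toml.
Pkg.activate("examples/deep-kernel-learning")
Pkg.instantiate()
```

Because the Manifest is pinned, a later breaking release of KernelFunctions does not change what this environment installs; the trade-off raised above is that the example can then lag behind the KernelFunctions version documented elsewhere.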

default(; legendfontsize=15.0, linewidth=3.0);

Flux.@functor SqExponentialKernel
Flux.@functor KernelSum
Flux.@functor Matern32Kernel
Flux.@functor FunctionTransform

# set up a kernel with a neural network feature extractor:

neuralnet = Chain(Dense(1, 3), Dense(3, 2))
k = SqExponentialKernel() ∘ FunctionTransform(neuralnet)

# Generate data

# ## Data creation
# We create a simple 1D problem with strongly varying behaviour
xmin = -3;
xmax = 3;
x = range(xmin, xmax; length=100)
x_test = rand(Uniform(xmin, xmax), 200)
x, y = noisy_function(sinc, x; noise=0.1)
X = RowVecs(reshape(x, :, 1))
X_test = RowVecs(reshape(x_test, :, 1))
λ = [0.1]
#md nothing #hide

#

f(x, k, λ) = kernelmatrix(k, x, X) / (kernelmatrix(k, X) + exp(λ[1]) * I) * y
f(X, k, 1.0)

#

loss(k, λ) = (ŷ -> sum(y - ŷ) / length(y) + exp(λ[1]) * norm(ŷ))(f(X, k, λ))
loss(k, λ)

#

xmax = 3; # Limits
N = 150
noise = 0.01
x_train = collect(eachrow(rand(Uniform(xmin, xmax), N))) # Training dataset
target_f(x) = sinc(abs(x)^abs(x)) # We use sinc of a rapidly varying argument
target_f(x::AbstractArray) = target_f(first(x))
y_train = target_f.(x_train) + randn(N) * noise
x_test = collect(eachrow(range(xmin, xmax; length=200))) # Testing dataset
spectral_mixture_kernel()
# ## Model definition
# We create a neural net with three dense layers (20, 30 and 5 units)
# The data is passed through the NN before being used in the kernel
neuralnet = Chain(Dense(1, 20), Dense(20, 30), Dense(30, 5))
# We use two cases:
# - The Squared Exponential Kernel
k = transform(SqExponentialKernel(), FunctionTransform(neuralnet))

# We use AbstractGPs.jl to define our model
gpprior = GP(k) # GP Prior
fx = AbstractGPs.FiniteGP(gpprior, x_train, noise) # Prior on f
fp = posterior(fx, y_train) # Posterior of f

# This computes the negative log evidence of `y`,
# which is going to be used as the objective
loss(y) = -logpdf(fx, y)

@info "Init Loss = $(loss(y_train))"

# Flux will automatically extract all the parameters of the kernel
ps = Flux.params(k)
# push!(ps,λ)
opt = Flux.Momentum(1.0)
#md nothing #hide

#

plots = []
for i in 1:10
grads = Zygote.gradient(() -> loss(k, λ), ps)
# We show the initial prediction with the untrained model
p_init = Plots.plot(
vcat(x_test...), target_f; lab="true f", title="Loss = $(loss(y_train))"
)
Plots.scatter!(vcat(x_train...), y_train; lab="data")
pred = marginals(fp(x_test))
Plots.plot!(vcat(x_test...), mean.(pred); ribbon=std.(pred), lab="Prediction")
# ## Training
anim = Animation()
nmax = 1000
opt = Flux.ADAM(0.1)
@showprogress for i in 1:nmax
global grads = gradient(ps) do
loss(y_train)
end
Flux.Optimise.update!(opt, ps, grads)
p = Plots.scatter(x, y; lab="data", title="Loss = $(loss(k,λ))")
Plots.plot!(x, f(X, k, λ); lab="Prediction", lw=3.0)
push!(plots, p)
if i % 100 == 0
@info "$i/$nmax"
L = loss(y_train)
# @info "Loss = $L"
p = Plots.plot(
vcat(x_test...), target_f; lab="true f", title="Loss = $(loss(y_train))"
)
p = Plots.scatter!(vcat(x_train...), y_train; lab="data")
pred = marginals(posterior(fx, y_train)(x_test))
Plots.plot!(vcat(x_test...), mean.(pred); ribbon=std.(pred), lab="Prediction")
frame(anim)
display(p)
end
end

#

l = @layout grid(10, 1)
plot(plots...; layout=l, size=(300, 1500))
gif(anim; fps=5)
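
Pieced together from the additions above, the new script follows this outline: sample a noisy 1D data set, build a squared-exponential kernel on top of a Flux feature extractor, and train the network weights by minimising the negative log evidence of the GP. The following is a condensed sketch under those assumptions; the names and the plain training loop are illustrative, and the committed script additionally animates intermediate predictions:

```julia
using AbstractGPs, KernelFunctions, Flux, Distributions

# Noisy 1D data with a strongly varying target, as in the script
target_f(x) = sinc(abs(x)^abs(x))
N, noise = 150, 0.01
x_train = [rand(Uniform(-3.0, 3.0), 1) for _ in 1:N]      # inputs as 1-element vectors
y_train = target_f.(first.(x_train)) .+ noise .* randn(N)

# Deep kernel: squared-exponential kernel on features produced by a neural net
nn = Chain(Dense(1, 20), Dense(20, 30), Dense(30, 5))
k = SqExponentialKernel() ∘ FunctionTransform(nn)

# GP prior at the training inputs; the loss is the negative log evidence
fx = AbstractGPs.FiniteGP(GP(k), x_train, noise)
loss(y) = -logpdf(fx, y)

ps = Flux.params(nn)   # trainable parameters are the network weights inside the kernel
opt = Flux.ADAM(0.1)
for _ in 1:1000
    grads = Flux.gradient(() -> loss(y_train), ps)
    Flux.Optimise.update!(opt, ps, grads)
end

# Posterior after training, for plotting predictions on a test grid
fpost = posterior(fx, y_train)
```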