Commit 23dfd7a

Authored by Copilot, yebai, github-actions[bot], and avik-pal
Replace Flux with Lux in deep kernel learning example (#435)
* Initial plan
* Replace Flux with Lux in deep kernel learning example
* Improve Lux implementation with proper parameter handling
* Apply suggestions from code review
* Fix Literate.jl parsing issue in deep kernel learning example
* Clean up comments in script.jl
* docs: use more of Lux official API for training and inference (#438)
* Update examples/2-deep-kernel-learning/script.jl

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: yebai <3279477+yebai@users.noreply.github.com>
Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Avik Pal <avik.pal.2017@gmail.com>
1 parent bae60f8 commit 23dfd7a
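For readers skimming the diff, the core of #438 is the switch to Lux's explicit training API. Below is a minimal, self-contained sketch of that pattern, not part of the commit: the toy data and `mse_objective` are illustrative stand-ins, but the `Lux.setup`, `Training.TrainState`, and `Training.single_train_step!` calls mirror the ones the diff introduces.

```julia
# Minimal sketch of the Lux training pattern this commit adopts
# (assumes Lux 1.x, Optimisers.jl, and Zygote via AutoZygote()).
# The model matches the example; `mse_objective` and the toy data
# are illustrative stand-ins, not part of the commit.
using Lux, Optimisers, Random, Zygote

Random.seed!(42)                 # for reproducibility, as in the script
rng = Random.default_rng()

model = Chain(Dense(1 => 20), Dense(20 => 30), Dense(30 => 5))
ps, st = Lux.setup(rng, model)   # parameters and state live outside the model

# Lux objective contract: (model, ps, st, data) -> (loss, st, stats)
function mse_objective(model, ps, st, (x, y))
    ŷ, st = model(x, ps, st)
    return sum(abs2, ŷ .- y), st, (;)
end

x, y = rand(rng, Float32, 1, 16), rand(rng, Float32, 5, 16)
tstate_final = let tstate = Training.TrainState(model, ps, st, Optimisers.Adam(0.005))
    for _ in 1:100
        # single_train_step! returns (grads, loss, stats, updated train state)
        _, loss, _, tstate = Training.single_train_step!(
            AutoZygote(), mse_objective, (x, y), tstate
        )
    end
    tstate
end
```

The objective's `(loss, state, stats)` return value and the threaded `TrainState` are the two contracts the rewritten `script.jl` relies on.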

File tree: 2 files changed (+57, −32 lines)

examples/2-deep-kernel-learning/Project.toml
examples/2-deep-kernel-learning/script.jl
examples/2-deep-kernel-learning/Project.toml

Lines changed: 7 additions & 4 deletions
@@ -1,21 +1,24 @@
 [deps]
 AbstractGPs = "99985d1d-32ba-4be9-9821-2ec096f28918"
 Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
-Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
 KernelFunctions = "ec8451be-7e33-11e9-00cf-bbf324bd1392"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Literate = "98b081ad-f1c9-55d3-8b20-4c87d4299306"
+Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
 MLDataUtils = "cc2ba9b6-d476-5e6d-8eaf-a92d5412d41d"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
 Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
 
 [compat]
 AbstractGPs = "0.3,0.4,0.5"
 Distributions = "0.25"
-Flux = "0.12, 0.13, 0.14"
 KernelFunctions = "0.10"
 Literate = "2"
+Lux = "1"
 MLDataUtils = "0.5"
+Optimisers = "0.4"
 Plots = "1"
-Zygote = "0.6, 0.7"
-julia = "1.3"
+Zygote = "0.7"
+julia = "1.10"
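Two compat changes stand out: Zygote is now pinned to 0.7 only, and the Julia floor rises from 1.3 to 1.10 (Lux 1.x itself targets 1.10+). To reproduce the example against these pins, the standard Pkg workflow below should work; the relative path is an assumption based on the file tree above.

```julia
# Instantiate the example's own environment. The path (relative to the
# repository root) is an assumption based on the file tree above.
using Pkg
Pkg.activate("examples/2-deep-kernel-learning")
Pkg.instantiate()   # resolves Lux 1.x, Optimisers 0.4, Zygote 0.7, ...
```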

examples/2-deep-kernel-learning/script.jl

Lines changed: 50 additions & 28 deletions
@@ -1,10 +1,10 @@
-# # Deep Kernel Learning with Flux
+# # Deep Kernel Learning with Lux
 
 ## Background
 
 # This example trains a GP whose inputs are passed through a neural network.
 # This kind of model has been considered previously [^Calandra] [^Wilson], although it has been shown that some care is needed to avoid substantial overfitting [^Ober].
-# In this example we make use of the `FunctionTransform` from [KernelFunctions.jl](github.com/JuliaGaussianProcesses/KernelFunctions.jl/) to put a simple Multi-Layer Perceptron built using Flux.jl inside a standard kernel.
+# In this example we make use of the `FunctionTransform` from [KernelFunctions.jl](github.com/JuliaGaussianProcesses/KernelFunctions.jl/) to put a simple Multi-Layer Perceptron built using Lux.jl inside a standard kernel.
 
 # [^Calandra]: Calandra, R., Peters, J., Rasmussen, C. E., & Deisenroth, M. P. (2016, July). [Manifold Gaussian processes for regression.](https://ieeexplore.ieee.org/abstract/document/7727626) In 2016 International Joint Conference on Neural Networks (IJCNN) (pp. 3338-3345). IEEE.
 
@@ -17,35 +17,46 @@
 # the different hyper-parameters
 using AbstractGPs
 using Distributions
-using Flux
 using KernelFunctions
 using LinearAlgebra
+using Lux
+using Optimisers
 using Plots
+using Random
+using Zygote
 default(; legendfontsize=15.0, linewidth=3.0);
 
+Random.seed!(42) # for reproducibility
+
 # ## Data creation
 # We create a simple 1D Problem with very different variations
 
 xmin, xmax = (-3, 3) # Limits
 N = 150
 noise_std = 0.01
 x_train_vec = rand(Uniform(xmin, xmax), N) # Training dataset
-x_train = collect(eachrow(x_train_vec)) # vector-of-vectors for Flux compatibility
+x_train = collect(eachrow(x_train_vec)) # vector-of-vectors for neural network compatibility
 target_f(x) = sinc(abs(x)^abs(x)) # We use sinc with a highly varying value
 y_train = target_f.(x_train_vec) + randn(N) * noise_std
 x_test_vec = range(xmin, xmax; length=200) # Testing dataset
-x_test = collect(eachrow(x_test_vec)) # vector-of-vectors for Flux compatibility
+x_test = collect(eachrow(x_test_vec)) # vector-of-vectors for neural network compatibility
 
 plot(xmin:0.01:xmax, target_f; label="ground truth")
 scatter!(x_train_vec, y_train; label="training data")
 
 # ## Model definition
 # We create a neural net with 2 layers and 10 units each.
 # The data is passed through the NN before being used in the kernel.
-neuralnet = Chain(Dense(1, 20), Dense(20, 30), Dense(30, 5))
+neuralnet = Chain(Dense(1 => 20), Dense(20 => 30), Dense(30 => 5))
+
+# Initialize the neural network parameters
+rng = Random.default_rng()
+ps, st = Lux.setup(rng, neuralnet)
+
+smodel = StatefulLuxLayer(neuralnet, ps, st)
 
 # We use the Squared Exponential Kernel:
-k = SqExponentialKernel() ∘ FunctionTransform(neuralnet)
+k = SqExponentialKernel() ∘ FunctionTransform(smodel)
 
 # We now define our model:
 gpprior = GP(k) # GP Prior
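Why `StatefulLuxLayer` is needed here: `FunctionTransform` expects a plain callable `x -> features`, while a raw Lux model must be invoked as `model(x, ps, st)`. The wrapper closes over `ps` and `st`, turning the network into an ordinary function. The following check is illustrative only (not part of the diff) and assumes KernelFunctions' convention `SqExponentialKernel()(x, y) = exp(-‖x − y‖² / 2)`:

```julia
# Illustrative check of what the composed kernel computes (not part of
# the diff): FunctionTransform maps both inputs through `smodel`, then
# SqExponentialKernel compares the resulting 5-D feature embeddings.
x1, x2 = randn(1), randn(1)      # 1-D inputs as length-1 vectors
z1, z2 = smodel(x1), smodel(x2)  # StatefulLuxLayer acts like a plain function
k(x1, x2) ≈ exp(-sum(abs2, z1 .- z2) / 2)  # true by construction
```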
@@ -58,9 +69,6 @@ loss(y) = -logpdf(fx, y)
 
 @info "Initial loss = $(loss(y_train))"
 
-# Flux will automatically extract all the parameters of the kernel
-ps = Flux.params(k)
-
 # We show the initial prediction with the untrained model
 p_init = plot(; title="Loss = $(round(loss(y_train); sigdigits=6))")
 plot!(vcat(x_test...), target_f; label="true f")
@@ -70,28 +78,42 @@ plot!(vcat(x_test...), mean.(pred_init); ribbon=std.(pred_init), label="Predicti
 
 # ## Training
 nmax = 200
-opt = Flux.Adam(0.1)
+
+# Create a wrapper function that updates the kernel with current parameters
+function update_kernel_and_loss(model, ps, st, data)
+    smodel = StatefulLuxLayer(model, ps, st)
+    k_updated = SqExponentialKernel() ∘ FunctionTransform(smodel)
+    fx_updated = AbstractGPs.FiniteGP(GP(k_updated), x_train, noise_std^2)
+    return -logpdf(fx_updated, y_train), smodel.st, (;)
+end
 
 anim = Animation()
-for i in 1:nmax
-    grads = gradient(ps) do
-        loss(y_train)
-    end
-    Flux.Optimise.update!(opt, ps, grads)
-
-    if i % 10 == 0
-        L = loss(y_train)
-        @info "iteration $i/$nmax: loss = $L"
-
-        p = plot(; title="Loss[$i/$nmax] = $(round(L; sigdigits=6))")
-        plot!(vcat(x_test...), target_f; label="true f")
-        scatter!(vcat(x_train...), y_train; label="data")
-        pred = marginals(posterior(fx, y_train)(x_test))
-        plot!(vcat(x_test...), mean.(pred); ribbon=std.(pred), label="Prediction")
-        frame(anim)
-        display(p)
+let tstate = Training.TrainState(neuralnet, ps, st, Optimisers.Adam(0.005))
+    for i in 1:nmax
+        _, loss_val, _, tstate = Training.single_train_step!(
+            AutoZygote(), update_kernel_and_loss, (), tstate
+        )
+
+        if i % 10 == 0
+            k =
+                SqExponentialKernel() ∘ FunctionTransform(
+                    StatefulLuxLayer(neuralnet, tstate.parameters, tstate.states)
+                )
+            fx = AbstractGPs.FiniteGP(GP(k), x_train, noise_std^2)
+
+            @info "iteration $i/$nmax: loss = $loss_val"
+
+            p = plot(; title="Loss[$i/$nmax] = $(round(loss_val; sigdigits=6))")
+            plot!(vcat(x_test...), target_f; label="true f")
+            scatter!(vcat(x_train...), y_train; label="data")
+            pred = marginals(posterior(fx, y_train)(x_test))
+            plot!(vcat(x_test...), mean.(pred); ribbon=std.(pred), label="Prediction")
+            frame(anim)
+            display(p)
+        end
     end
 end
+
 gif(anim, "train-dkl.gif"; fps=3)
 nothing #hide
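A caveat for reuse: the `let` block above discards its final `TrainState` when it exits. Assuming the block is changed to return `tstate` (bound below to the hypothetical name `tstate_final`), post-training prediction follows the same recipe as the logging branch:

```julia
# Post-training inference sketch. Assumes the training `let` block is
# modified to return its final state, e.g. `tstate_final = let ... end`.
smodel_trained = StatefulLuxLayer(neuralnet, tstate_final.parameters, tstate_final.states)
k_trained = SqExponentialKernel() ∘ FunctionTransform(smodel_trained)
fx_trained = AbstractGPs.FiniteGP(GP(k_trained), x_train, noise_std^2)
post = posterior(fx_trained, y_train)
pred = marginals(post(x_test))      # Normal marginals at the test inputs
μ, σ = mean.(pred), std.(pred)      # predictive mean and standard deviation
```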
