Skip to content

Commit 49221ed

Browse files
update mnist
1 parent c4457b2 commit 49221ed

File tree

1 file changed

+29
-31
lines changed

1 file changed

+29
-31
lines changed

examples/examples_laplace/laplace_ae_mnist.py

Lines changed: 29 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -7,19 +7,12 @@
77
import numpy as np
88
import torch
99
from torch.utils.data import DataLoader, TensorDataset
10-
10+
from tqdm import tqdm
1111
from ae_models import AE_mnist
1212

1313
from laplace import Laplace
1414

15-
plt.rc('text', usetex=True)
16-
plt.rc('text.latex', preamble=r'\usepackage{amssymb} \usepackage{amsmath} \usepackage{marvosym}')
17-
plt.rc('font', family='serif')
18-
plt.rcParams.update({'font.size': 12})
19-
20-
N = 50000
21-
N_test = 300
22-
n_epochs = 30
15+
n_epochs = 50
2316
batch_size = 128 # full batch
2417
true_sigma_noise = 0.3
2518

@@ -29,74 +22,79 @@
2922

3023
mnist = MNIST('../data/', download=True)
3124
X_train = mnist.train_data.reshape(-1, 784).numpy() / 255.0
32-
y_train = mnist.train_labels.numpy()
25+
y_train = torch.from_numpy(mnist.train_labels.numpy())
3326

3427
X_test = mnist.test_data.reshape(-1, 784).numpy() / 255.0
35-
y_test = mnist.test_labels.numpy()
28+
y_test = torch.from_numpy(mnist.test_labels.numpy())
3629

3730
X_train = torch.from_numpy(X_train).float()
3831
X_test = torch.from_numpy(X_test).float()
3932

4033
train_loader = DataLoader(TensorDataset(X_train), batch_size=batch_size)
4134

42-
model = AE_mnist(latent_size=2)
35+
device = "cuda:0" if torch.cuda.is_available() else "cpu"
36+
model = AE_mnist(latent_size=2, device="gpu" if torch.cuda.is_available() else "cpu")
4337
model.fit(X_train, n_epochs=n_epochs, learning_rate=1e-3, batch_size=batch_size, verbose=True, labels=y_train)
4438

4539
# Visualize Latent Space
46-
X_test_fold = X_test.view(X_test.shape[0], -1)
40+
X_test_fold = X_test.view(X_test.shape[0], -1).to(device)
4741
z_test = model.encoder(X_test_fold)
4842
z_test = z_test.detach()
4943

5044
# Laplace Approximation
51-
la = Laplace(model.decoder, 'regression', subset_of_weights='all', hessian_structure='full')
45+
la = Laplace(model.decoder, 'regression', subset_of_weights='last_layer', hessian_structure='diag')
5246

5347
# Getting Z representations for X_train
54-
X_fold = X_train.view(X_train.shape[0], -1)
48+
X_fold = X_train.view(X_train.shape[0], -1).to(device)
5549
model.eval()
5650
with torch.inference_mode():
5751
z = model.encoder(X_fold)
5852
x_rec = model.decoder(z)
5953
z_loader = DataLoader(TensorDataset(z, x_rec), batch_size=batch_size)
6054
# Fitting
6155
la.fit(z_loader)
62-
log_prior, log_sigma = torch.ones(1, requires_grad=True), torch.ones(1, requires_grad=True)
56+
log_prior, log_sigma = torch.ones(1, requires_grad=True, device="cuda:0"), torch.ones(1, requires_grad=True, device="cuda:0")
6357
hyper_optimizer = torch.optim.Adam([log_prior, log_sigma], lr=1e-2)
6458
for i in range(n_epochs):
6559
hyper_optimizer.zero_grad()
6660
neg_marglik = - la.log_marginal_likelihood(log_prior.exp(), log_sigma.exp())
6761
neg_marglik.backward()
6862
hyper_optimizer.step()
6963

70-
plt.figure()
71-
fig, ax = plt.subplots(dpi=150)
72-
73-
X_test_fold = X_test.view(X_test.shape[0], -1)
64+
X_test_fold = X_test.view(X_test.shape[0], -1).to(device)
7465
z_test = model.encoder(X_test_fold)
7566
z_test = z_test.detach()
7667

7768
# GRID FOR PROBABILITY MAP
7869
n_points_axis = 50
79-
zx_grid = np.linspace(z_test[:,0].min().detach().numpy() - 1.5, z_test[:,0].max().detach().numpy() + 1.5, n_points_axis)
80-
zy_grid = np.linspace(z_test[:,1].min().detach().numpy() - 1.5, z_test[:,1].max().detach().numpy() + 1.5, n_points_axis)
70+
zx_grid = np.linspace(z_test[:,0].min().cpu().detach().numpy() - 1.5, z_test[:,0].max().cpu().detach().numpy() + 1.5, n_points_axis)
71+
zy_grid = np.linspace(z_test[:,1].min().cpu().detach().numpy() - 1.5, z_test[:,1].max().cpu().detach().numpy() + 1.5, n_points_axis)
8172

8273
xg_mesh, yg_mesh = np.meshgrid(zx_grid, zy_grid)
8374
xg = xg_mesh.reshape(n_points_axis ** 2, 1)
8475
yg = yg_mesh.reshape(n_points_axis ** 2, 1)
8576
Z_grid_test = np.hstack((xg, yg))
86-
Z_grid_test = torch.from_numpy(Z_grid_test).float().detach()
77+
Z_grid_test = torch.from_numpy(Z_grid_test).float().detach().to(device)
8778

88-
f_mu, f_var = la(Z_grid_test)
89-
f_mu = f_mu.squeeze().detach().cpu().numpy()
90-
f_sigma = f_var.squeeze().sqrt().cpu().numpy()
79+
all_f_mu, all_f_sigma = [], []
80+
for i in tqdm(range(Z_grid_test.shape[0])):
81+
f_mu, f_var = la(Z_grid_test[i:i+1,:])
82+
f_mu = f_mu.squeeze().detach().cpu().numpy()
83+
f_sigma = f_var.squeeze().sqrt().cpu().numpy()
9184

92-
sigma_vector = f_sigma[:,0,0]
85+
all_f_mu.append(f_mu)
86+
all_f_sigma.append(f_sigma)
9387

94-
plt.plot(z_test[:, 0], z_test[:, 1], 'wx', ms=5.0, alpha=1.0)
88+
f_mu = np.stack(all_f_mu)
89+
f_sigma = np.stack(all_f_sigma)
9590

91+
sigma_vector = f_sigma[:,np.arange(f_sigma.shape[1]), np.arange(f_sigma.shape[1])].mean(axis=1)
92+
93+
plt.figure()
94+
plt.plot(z_test[:, 0].cpu(), z_test[:, 1].cpu(), 'wx', ms=5.0, alpha=1.0)
9695
precision_grid = np.reshape(sigma_vector, (n_points_axis,n_points_axis))
9796
plt.contourf(xg_mesh, yg_mesh, precision_grid, cmap='viridis_r')
9897
plt.colorbar()
99-
100-
plt.title(f'Data N={N}')
98+
plt.savefig("mnist.png")
10199
plt.show()
102-
100+

0 commit comments

Comments
 (0)