|
7 | 7 | import numpy as np
|
8 | 8 | import torch
|
9 | 9 | from torch.utils.data import DataLoader, TensorDataset
|
10 |
| - |
| 10 | +from tqdm import tqdm |
11 | 11 | from ae_models import AE_mnist
|
12 | 12 |
|
13 | 13 | from laplace import Laplace
|
14 | 14 |
|
15 |
| -plt.rc('text', usetex=True) |
16 |
| -plt.rc('text.latex', preamble=r'\usepackage{amssymb} \usepackage{amsmath} \usepackage{marvosym}') |
17 |
| -plt.rc('font', family='serif') |
18 |
| -plt.rcParams.update({'font.size': 12}) |
19 |
| - |
20 |
| -N = 50000 |
21 |
| -N_test = 300 |
22 |
| -n_epochs = 30 |
| 15 | +n_epochs = 50 |
23 | 16 | batch_size = 128 # full batch
|
24 | 17 | true_sigma_noise = 0.3
|
25 | 18 |
|
|
29 | 22 |
|
30 | 23 | mnist = MNIST('../data/', download=True)
|
31 | 24 | X_train = mnist.train_data.reshape(-1, 784).numpy() / 255.0
|
32 |
| -y_train = mnist.train_labels.numpy() |
| 25 | +y_train = torch.from_numpy(mnist.train_labels.numpy()) |
33 | 26 |
|
34 | 27 | X_test = mnist.test_data.reshape(-1, 784).numpy() / 255.0
|
35 |
| -y_test = mnist.test_labels.numpy() |
| 28 | +y_test = torch.from_numpy(mnist.test_labels.numpy()) |
36 | 29 |
|
37 | 30 | X_train = torch.from_numpy(X_train).float()
|
38 | 31 | X_test = torch.from_numpy(X_test).float()
|
39 | 32 |
|
40 | 33 | train_loader = DataLoader(TensorDataset(X_train), batch_size=batch_size)
|
41 | 34 |
|
42 |
| -model = AE_mnist(latent_size=2) |
| 35 | +device = "cuda:0" if torch.cuda.is_available() else "cpu" |
| 36 | +model = AE_mnist(latent_size=2, device="gpu" if torch.cuda.is_available() else "cpu") |
43 | 37 | model.fit(X_train, n_epochs=n_epochs, learning_rate=1e-3, batch_size=batch_size, verbose=True, labels=y_train)
|
44 | 38 |
|
45 | 39 | # Visualize Latent Space
|
46 |
| -X_test_fold = X_test.view(X_test.shape[0], -1) |
| 40 | +X_test_fold = X_test.view(X_test.shape[0], -1).to(device) |
47 | 41 | z_test = model.encoder(X_test_fold)
|
48 | 42 | z_test = z_test.detach()
|
49 | 43 |
|
50 | 44 | # Laplace Approximation
|
51 |
| -la = Laplace(model.decoder, 'regression', subset_of_weights='all', hessian_structure='full') |
| 45 | +la = Laplace(model.decoder, 'regression', subset_of_weights='last_layer', hessian_structure='diag') |
52 | 46 |
|
53 | 47 | # Getting Z representations for X_train
|
54 |
| -X_fold = X_train.view(X_train.shape[0], -1) |
| 48 | +X_fold = X_train.view(X_train.shape[0], -1).to(device) |
55 | 49 | model.eval()
|
56 | 50 | with torch.inference_mode():
|
57 | 51 | z = model.encoder(X_fold)
|
58 | 52 | x_rec = model.decoder(z)
|
59 | 53 | z_loader = DataLoader(TensorDataset(z, x_rec), batch_size=batch_size)
|
60 | 54 | # Fitting
|
61 | 55 | la.fit(z_loader)
|
62 |
| -log_prior, log_sigma = torch.ones(1, requires_grad=True), torch.ones(1, requires_grad=True) |
| 56 | +log_prior, log_sigma = torch.ones(1, requires_grad=True, device="cuda:0"), torch.ones(1, requires_grad=True, device="cuda:0") |
63 | 57 | hyper_optimizer = torch.optim.Adam([log_prior, log_sigma], lr=1e-2)
|
64 | 58 | for i in range(n_epochs):
|
65 | 59 | hyper_optimizer.zero_grad()
|
66 | 60 | neg_marglik = - la.log_marginal_likelihood(log_prior.exp(), log_sigma.exp())
|
67 | 61 | neg_marglik.backward()
|
68 | 62 | hyper_optimizer.step()
|
69 | 63 |
|
70 |
| -plt.figure() |
71 |
| -fig, ax = plt.subplots(dpi=150) |
72 |
| - |
73 |
| -X_test_fold = X_test.view(X_test.shape[0], -1) |
| 64 | +X_test_fold = X_test.view(X_test.shape[0], -1).to(device) |
74 | 65 | z_test = model.encoder(X_test_fold)
|
75 | 66 | z_test = z_test.detach()
|
76 | 67 |
|
77 | 68 | # GRID FOR PROBABILITY MAP
|
78 | 69 | n_points_axis = 50
|
79 |
| -zx_grid = np.linspace(z_test[:,0].min().detach().numpy() - 1.5, z_test[:,0].max().detach().numpy() + 1.5, n_points_axis) |
80 |
| -zy_grid = np.linspace(z_test[:,1].min().detach().numpy() - 1.5, z_test[:,1].max().detach().numpy() + 1.5, n_points_axis) |
| 70 | +zx_grid = np.linspace(z_test[:,0].min().cpu().detach().numpy() - 1.5, z_test[:,0].max().cpu().detach().numpy() + 1.5, n_points_axis) |
| 71 | +zy_grid = np.linspace(z_test[:,1].min().cpu().detach().numpy() - 1.5, z_test[:,1].max().cpu().detach().numpy() + 1.5, n_points_axis) |
81 | 72 |
|
82 | 73 | xg_mesh, yg_mesh = np.meshgrid(zx_grid, zy_grid)
|
83 | 74 | xg = xg_mesh.reshape(n_points_axis ** 2, 1)
|
84 | 75 | yg = yg_mesh.reshape(n_points_axis ** 2, 1)
|
85 | 76 | Z_grid_test = np.hstack((xg, yg))
|
86 |
| -Z_grid_test = torch.from_numpy(Z_grid_test).float().detach() |
| 77 | +Z_grid_test = torch.from_numpy(Z_grid_test).float().detach().to(device) |
87 | 78 |
|
88 |
| -f_mu, f_var = la(Z_grid_test) |
89 |
| -f_mu = f_mu.squeeze().detach().cpu().numpy() |
90 |
| -f_sigma = f_var.squeeze().sqrt().cpu().numpy() |
| 79 | +all_f_mu, all_f_sigma = [], [] |
| 80 | +for i in tqdm(range(Z_grid_test.shape[0])): |
| 81 | + f_mu, f_var = la(Z_grid_test[i:i+1,:]) |
| 82 | + f_mu = f_mu.squeeze().detach().cpu().numpy() |
| 83 | + f_sigma = f_var.squeeze().sqrt().cpu().numpy() |
91 | 84 |
|
92 |
| -sigma_vector = f_sigma[:,0,0] |
| 85 | + all_f_mu.append(f_mu) |
| 86 | + all_f_sigma.append(f_sigma) |
93 | 87 |
|
94 |
| -plt.plot(z_test[:, 0], z_test[:, 1], 'wx', ms=5.0, alpha=1.0) |
| 88 | +f_mu = np.stack(all_f_mu) |
| 89 | +f_sigma = np.stack(all_f_sigma) |
95 | 90 |
|
| 91 | +sigma_vector = f_sigma[:,np.arange(f_sigma.shape[1]), np.arange(f_sigma.shape[1])].mean(axis=1) |
| 92 | + |
| 93 | +plt.figure() |
| 94 | +plt.plot(z_test[:, 0].cpu(), z_test[:, 1].cpu(), 'wx', ms=5.0, alpha=1.0) |
96 | 95 | precision_grid = np.reshape(sigma_vector, (n_points_axis,n_points_axis))
|
97 | 96 | plt.contourf(xg_mesh, yg_mesh, precision_grid, cmap='viridis_r')
|
98 | 97 | plt.colorbar()
|
99 |
| - |
100 |
| -plt.title(f'Data N={N}') |
| 98 | +plt.savefig("mnist.png") |
101 | 99 | plt.show()
|
102 |
| - |
| 100 | + |
0 commit comments