Skip to content

Commit d440848

Browse files
author
Frederik Rahbaek Warburg
committed
much much faster!
1 parent 343014b commit d440848

File tree

3 files changed

+37
-28
lines changed

3 files changed

+37
-28
lines changed

examples/examples_laplace/data.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,8 @@ def get_data(name, batch_size = 32):
1010
dataset = MNIST('', train=True, download=True, transform=transforms.ToTensor())
1111
mnist_train, mnist_val = random_split(dataset, [55000, 5000])
1212

13-
train_loader = DataLoader(mnist_train, batch_size=batch_size)
14-
val_loader = DataLoader(mnist_val, batch_size=batch_size)
13+
train_loader = DataLoader(mnist_train, batch_size=batch_size, pin_memory=True)
14+
val_loader = DataLoader(mnist_val, batch_size=batch_size, pin_memory=True)
1515

1616
elif name == "swissrole":
1717
N_train = 50000
@@ -27,8 +27,8 @@ def swiss_roll_2d(noise=0.2, n_samples=100):
2727
X_train, y_train = swiss_roll_2d(n_samples=N_train)
2828
X_val, y_test = swiss_roll_2d(n_samples=N_val)
2929

30-
train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=batch_size)
31-
val_loader = DataLoader(TensorDataset(X_val, y_test), batch_size=batch_size)
30+
train_loader = DataLoader(TensorDataset(X_train, y_train), batch_size=batch_size, pin_memory=True)
31+
val_loader = DataLoader(TensorDataset(X_val, y_test), batch_size=batch_size, pin_memory=True)
3232

3333
else:
3434
raise NotImplemplenetError

examples/examples_laplace/trainer_e2e.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from trainer_ae import train_ae
1+
from trainer_ae import train_ae, test_ae
22
from trainer_lae import train_lae, test_lae
33

44

examples/examples_laplace/trainer_lae.py

Lines changed: 32 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import torch
99
from torch.utils.data import DataLoader, TensorDataset
1010
from tqdm import tqdm
11+
import time
1112

1213
from laplace import Laplace
1314
from data import get_data
@@ -40,32 +41,35 @@ def test_lae(dataset, batch_size=1):
4041

4142
train_loader, val_loader = get_data(dataset, batch_size)
4243

44+
pred_type = "nn"
45+
4346
# forward eval la
44-
x, z, labels, mu_rec, sigma_rec = [], [], [], [], []
47+
x, z_list, labels, mu_rec, sigma_rec = [], [], [], [], []
4548
for i, (X, y) in tqdm(enumerate(val_loader)):
46-
X = X.view(X.size(0), -1).to(device)
47-
with torch.inference_mode():
48-
z += [encoder(X)]
49+
t0 = time.time()
50+
with torch.no_grad():
51+
52+
X = X.view(X.size(0), -1).to(device)
53+
z = encoder(X)
54+
55+
mu, var = la(z, pred_type = pred_type)
4956

50-
# pred_type : {glm, nn}
51-
# link_approx only relevant for classification
52-
pred_type = "glm"
53-
mu, var = la(z[-1], pred_type = pred_type)
57+
mu_rec += [mu.detach()]
58+
sigma_rec += [var.sqrt()]
5459

55-
x += [X.cpu()]
60+
x += [X]
5661
labels += [y]
57-
mu_rec += [mu.detach().cpu()]
58-
sigma_rec += [var.sqrt().cpu()]
62+
z_list += [z]
5963

6064
# only show the first 50 points
6165
# if i > 50:
6266
# break
6367

64-
x = torch.cat(x, dim=0).numpy()
68+
x = torch.cat(x, dim=0).cpu().numpy()
6569
labels = torch.cat(labels, dim=0).numpy()
66-
z = torch.cat(z, dim=0).cpu().numpy()
67-
mu_rec = torch.cat(mu_rec, dim=0).numpy()
68-
sigma_rec = torch.cat(sigma_rec, dim=0).numpy()
70+
z = torch.cat(z_list, dim=0).cpu().numpy()
71+
mu_rec = torch.cat(mu_rec, dim=0).cpu().numpy()
72+
sigma_rec = torch.cat(sigma_rec, dim=0).cpu().numpy()
6973

7074
###
7175
# Grid for probability map
@@ -81,17 +85,21 @@ def test_lae(dataset, batch_size=1):
8185
xg = xg_mesh.reshape(n_points_axis ** 2, 1)
8286
yg = yg_mesh.reshape(n_points_axis ** 2, 1)
8387
Z_grid_test = np.hstack((xg, yg))
84-
Z_grid_test = torch.from_numpy(Z_grid_test).to(device)
88+
Z_grid_test = torch.from_numpy(Z_grid_test)
89+
90+
z_grid_loader = DataLoader(TensorDataset(Z_grid_test), batch_size=batch_size, pin_memory=True)
8591

8692
all_f_mu, all_f_sigma = [], []
87-
for i in tqdm(range(Z_grid_test.shape[0])):
88-
f_mu, f_var = la(Z_grid_test[i:i+1,:], pred_type = pred_type)
93+
for z_grid in tqdm(z_grid_loader):
94+
95+
z_grid = z_grid[0].to(device)
96+
f_mu, f_var = la(z_grid, pred_type = pred_type)
8997

9098
all_f_mu += [f_mu.squeeze().detach().cpu()]
9199
all_f_sigma += [f_var.squeeze().sqrt().cpu()]
92100

93-
f_mu = torch.stack(all_f_mu, dim=0)
94-
f_sigma = torch.stack(all_f_sigma, dim=0)
101+
f_mu = torch.cat(all_f_mu, dim=0)
102+
f_sigma = torch.cat(all_f_sigma, dim=0)
95103

96104
# get diagonal elements
97105
idx = torch.arange(f_sigma.shape[1])
@@ -157,7 +165,7 @@ def train_lae(dataset="mnist", n_epochs=50, batch_size=32):
157165
z = torch.cat(z, dim=0).cpu()
158166
x = torch.cat(x, dim=0).cpu()
159167

160-
z_loader = DataLoader(TensorDataset(z, x), batch_size=batch_size)
168+
z_loader = DataLoader(TensorDataset(z, x), batch_size=batch_size, pin_memory=True)
161169

162170
# Laplace Approximation
163171
la = Laplace(decoder, 'regression', subset_of_weights='last_layer', hessian_structure='diag')
@@ -182,7 +190,7 @@ def train_lae(dataset="mnist", n_epochs=50, batch_size=32):
182190

183191
train = False
184192
dataset = "mnist"
185-
batch_size = 1
193+
batch_size = 128
186194

187195
# train or load laplace auto encoder
188196
if train:
@@ -194,4 +202,5 @@ def train_lae(dataset="mnist", n_epochs=50, batch_size=32):
194202

195203
# evaluate laplace auto encoder
196204
print("==> evaluate lae")
197-
test_lae(dataset)
205+
test_lae(dataset, batch_size)
206+

0 commit comments

Comments
 (0)