8
8
import torch
9
9
from torch .utils .data import DataLoader , TensorDataset
10
10
from tqdm import tqdm
11
+ import time
11
12
12
13
from laplace import Laplace
13
14
from data import get_data
@@ -40,32 +41,35 @@ def test_lae(dataset, batch_size=1):
40
41
41
42
train_loader , val_loader = get_data (dataset , batch_size )
42
43
44
+ pred_type = "nn"
45
+
43
46
# forward eval la
44
- x , z , labels , mu_rec , sigma_rec = [], [], [], [], []
47
+ x , z_list , labels , mu_rec , sigma_rec = [], [], [], [], []
45
48
for i , (X , y ) in tqdm (enumerate (val_loader )):
46
- X = X .view (X .size (0 ), - 1 ).to (device )
47
- with torch .inference_mode ():
48
- z += [encoder (X )]
49
+ t0 = time .time ()
50
+ with torch .no_grad ():
51
+
52
+ X = X .view (X .size (0 ), - 1 ).to (device )
53
+ z = encoder (X )
54
+
55
+ mu , var = la (z , pred_type = pred_type )
49
56
50
- # pred_type : {glm, nn}
51
- # link_approx only relevant for classification
52
- pred_type = "glm"
53
- mu , var = la (z [- 1 ], pred_type = pred_type )
57
+ mu_rec += [mu .detach ()]
58
+ sigma_rec += [var .sqrt ()]
54
59
55
- x += [X . cpu () ]
60
+ x += [X ]
56
61
labels += [y ]
57
- mu_rec += [mu .detach ().cpu ()]
58
- sigma_rec += [var .sqrt ().cpu ()]
62
+ z_list += [z ]
59
63
60
64
# only show the first 50 points
61
65
# if i > 50:
62
66
# break
63
67
64
- x = torch .cat (x , dim = 0 ).numpy ()
68
+ x = torch .cat (x , dim = 0 ).cpu (). numpy ()
65
69
labels = torch .cat (labels , dim = 0 ).numpy ()
66
- z = torch .cat (z , dim = 0 ).cpu ().numpy ()
67
- mu_rec = torch .cat (mu_rec , dim = 0 ).numpy ()
68
- sigma_rec = torch .cat (sigma_rec , dim = 0 ).numpy ()
70
+ z = torch .cat (z_list , dim = 0 ).cpu ().numpy ()
71
+ mu_rec = torch .cat (mu_rec , dim = 0 ).cpu (). numpy ()
72
+ sigma_rec = torch .cat (sigma_rec , dim = 0 ).cpu (). numpy ()
69
73
70
74
###
71
75
# Grid for probability map
@@ -81,17 +85,21 @@ def test_lae(dataset, batch_size=1):
81
85
xg = xg_mesh .reshape (n_points_axis ** 2 , 1 )
82
86
yg = yg_mesh .reshape (n_points_axis ** 2 , 1 )
83
87
Z_grid_test = np .hstack ((xg , yg ))
84
- Z_grid_test = torch .from_numpy (Z_grid_test ).to (device )
88
+ Z_grid_test = torch .from_numpy (Z_grid_test )
89
+
90
+ z_grid_loader = DataLoader (TensorDataset (Z_grid_test ), batch_size = batch_size , pin_memory = True )
85
91
86
92
all_f_mu , all_f_sigma = [], []
87
- for i in tqdm (range (Z_grid_test .shape [0 ])):
88
- f_mu , f_var = la (Z_grid_test [i :i + 1 ,:], pred_type = pred_type )
93
+ for z_grid in tqdm (z_grid_loader ):
94
+
95
+ z_grid = z_grid [0 ].to (device )
96
+ f_mu , f_var = la (z_grid , pred_type = pred_type )
89
97
90
98
all_f_mu += [f_mu .squeeze ().detach ().cpu ()]
91
99
all_f_sigma += [f_var .squeeze ().sqrt ().cpu ()]
92
100
93
- f_mu = torch .stack (all_f_mu , dim = 0 )
94
- f_sigma = torch .stack (all_f_sigma , dim = 0 )
101
+ f_mu = torch .cat (all_f_mu , dim = 0 )
102
+ f_sigma = torch .cat (all_f_sigma , dim = 0 )
95
103
96
104
# get diagonal elements
97
105
idx = torch .arange (f_sigma .shape [1 ])
@@ -157,7 +165,7 @@ def train_lae(dataset="mnist", n_epochs=50, batch_size=32):
157
165
z = torch .cat (z , dim = 0 ).cpu ()
158
166
x = torch .cat (x , dim = 0 ).cpu ()
159
167
160
- z_loader = DataLoader (TensorDataset (z , x ), batch_size = batch_size )
168
+ z_loader = DataLoader (TensorDataset (z , x ), batch_size = batch_size , pin_memory = True )
161
169
162
170
# Laplace Approximation
163
171
la = Laplace (decoder , 'regression' , subset_of_weights = 'last_layer' , hessian_structure = 'diag' )
@@ -182,7 +190,7 @@ def train_lae(dataset="mnist", n_epochs=50, batch_size=32):
182
190
183
191
train = False
184
192
dataset = "mnist"
185
- batch_size = 1
193
+ batch_size = 128
186
194
187
195
# train or load laplace auto encoder
188
196
if train :
@@ -194,4 +202,5 @@ def train_lae(dataset="mnist", n_epochs=50, batch_size=32):
194
202
195
203
# evaluate laplace auto encoder
196
204
print ("==> evaluate lae" )
197
- test_lae (dataset )
205
+ test_lae (dataset , batch_size )
206
+
0 commit comments