Commit e20ac6c

Functional laplace memory investigation (aleximmer#2)

* memory investigation start

* add for loop over output dimensions

* improve memory footprint

* minor

---------

Co-authored-by: Metod Jazbec <mjazbec@ivi-cn011.ivi.local>
metodj and Metod Jazbec authored Feb 21, 2023
1 parent 005c3c7 commit e20ac6c
Showing 3 changed files with 25 additions and 18 deletions.
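Note: taken together, the changes below reduce the peak memory of the functional (GP) Laplace predictive path: kernel buffers are preallocated with torch.empty and filled in place, the multi-output einsum contractions are replaced by an explicit for loop over output dimensions, _kernel_star no longer recomputes the test Jacobians, and intermediates such as jacobians_2 and X_batch are released with del once consumed. The central rewrite pattern, with an equivalence check on dummy Jacobians (all shapes below are illustrative, not taken from the library):

    import torch

    b, e, C, P = 4, 6, 3, 50                      # test batch, train batch, outputs, params
    J1, J2 = torch.randn(b, C, P), torch.randn(e, C, P)

    # One-shot contraction over all output dimensions at once (old code path).
    K_all = torch.einsum('bcp,ecp->bec', J1, J2)

    # Per-output loop into a preallocated buffer (new code path): only one
    # (b, e)-sized slice is computed at a time.
    K_loop = torch.empty(b, e, C)
    for c in range(C):
        K_loop[:, :, c] = torch.einsum('bp,ep->be', J1[:, c, :], J2[:, c, :])

    assert torch.allclose(K_all, K_loop, atol=1e-5)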
1 change: 1 addition & 0 deletions examples/calibration_gp_example.py
@@ -43,6 +43,7 @@
 prior_precision=prior_precision)
 la.fit(train_loader)
 
+print(f'Predicting Laplace-GP for m={m}')
 probs_laplace = predict(test_loader, la, laplace=True, la_type='gp')
 acc_laplace, ece_laplace, nll_laplace = get_metrics(probs_laplace, targets)
 print(f'[Laplace-GP, m={m}] Acc.: {acc_laplace:.1%}; ECE: {ece_laplace:.1%}; NLL: {nll_laplace:.3}')
33 changes: 19 additions & 14 deletions laplace/baselaplace.py
@@ -1183,15 +1183,15 @@ def _check_prior_precision(prior_precision):
 
     def _init_K_MM(self):
         if self.diagonal_kernel:
-            self.K_MM = [torch.zeros(size=(self.M, self.M), device=self._device) for _ in range(self.n_outputs)]
+            self.K_MM = [torch.empty(size=(self.M, self.M), device=self._device) for _ in range(self.n_outputs)]
         else:
-            self.K_MM = torch.zeros(size=(self.M * self.n_outputs, self.M * self.n_outputs), device=self._device)
+            self.K_MM = torch.empty(size=(self.M * self.n_outputs, self.M * self.n_outputs), device=self._device)
 
     def _init_Sigma_inv(self):
         if self.diagonal_kernel:
-            self.Sigma_inv = [torch.zeros(size=(self.M, self.M), device=self._device) for _ in range(self.n_outputs)]
+            self.Sigma_inv = [torch.empty(size=(self.M, self.M), device=self._device) for _ in range(self.n_outputs)]
         else:
-            self.Sigma_inv = torch.zeros(size=(self.M * self.n_outputs, self.M * self.n_outputs), device=self._device)
+            self.Sigma_inv = torch.empty(size=(self.M * self.n_outputs, self.M * self.n_outputs), device=self._device)
 
     def _curv_closure(self, X, y):
         return self.backend.gp_quantities(X, y)
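Note: the switch from torch.zeros to torch.empty above only skips the up-front zero fill of these buffers; it is safe on the assumption that every entry is assigned during fitting before it is read. A minimal sketch of that pattern (sizes and the fill loop are made up for illustration):

    import torch

    M, C = 200, 10
    # torch.zeros would write all (M*C)**2 zeros immediately; torch.empty only reserves memory.
    K = torch.empty(M * C, M * C)
    for c in range(C):
        # stand-in for the kernel blocks assigned while iterating the train loader
        K[c * M:(c + 1) * M] = torch.randn(M, M * C)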
@@ -1347,12 +1347,12 @@ def gp_posterior(self, X_star):
         """
         Js, f_mu = self._jacobians(X_star)
-        f_var = self._gp_posterior_variance(Js, X_star)
+        f_var = self._gp_posterior_variance(Js)
         if self.diagonal_kernel:
             f_var = torch.diag_embed(f_var)
         return f_mu.detach(), f_var.detach()
 
-    def _gp_posterior_variance(self, Js_star, X_star):
+    def _gp_posterior_variance(self, Js_star):
         """
         GP posterior variance: \\( k_{**} - K_{*M} (K_{MM}+ L_{MM}^{-1})^{-1} K_{M*}\\)
@@ -1363,12 +1363,13 @@ def _gp_posterior_variance(self, Js_star, X_star):
         X_star : torch.Tensor
             test data points \\(X \in \mathbb{R}^{N_{test} \\times C} \\)
         """
-        K_star = self.gp_kernel_prior_variance * self._kernel_star(Js_star, X_star)
+        K_star = self.gp_kernel_prior_variance * self._kernel_star(Js_star)
 
         K_M_star = []
         for X_batch, _ in self.train_loader:
             K_M_star_batch = self.gp_kernel_prior_variance * self._kernel_batch_star(Js_star, X_batch.to(self._device))
             K_M_star.append(K_M_star_batch)
+            del X_batch
 
         f_var = K_star - self._build_K_star_M(K_M_star)
         return f_var
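Note: for orientation, the docstring formula \\( k_{**} - K_{*M} (K_{MM}+ L_{MM}^{-1})^{-1} K_{M*}\\) spelled out with dummy tensors for the non-diagonal case; every name and shape below is illustrative, not an attribute of the class:

    import torch

    M, C, b = 100, 3, 4                           # train points, outputs, test points
    K_MM = torch.eye(M * C)                       # joint train kernel, (MC, MC)
    L_MM_inv = torch.eye(M * C)                   # stands in for L_{MM}^{-1}
    K_star_M = torch.randn(b, C, M * C)           # K_{*M}, one (C, MC) block per test point
    k_ss = torch.randn(b, C, C)                   # k_{**} blocks

    inner = torch.inverse(K_MM + L_MM_inv)
    # f_var[i] = k_**[i] - K_{*M}[i] @ inner @ K_{M*}[i]
    f_var = k_ss - torch.einsum('bcm,mn,bdn->bcd', K_star_M, inner, K_star_M)
    print(f_var.shape)                            # torch.Size([4, 3, 3])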
@@ -1495,32 +1496,34 @@ def _kernel_batch(self, jacobians, batch):
         jacobians_2, _ = self._jacobians(batch)
         P = jacobians.shape[-1] # nr model params
         if self.diagonal_kernel:
-            kernel = torch.einsum('bcp,ecp->bec', jacobians, jacobians_2)
+            kernel = torch.empty((jacobians.shape[0], jacobians_2.shape[0], self.n_outputs), device=jacobians.device)
+            for c in range(self.n_outputs):
+                kernel[:, :, c] = torch.einsum('bp,ep->be', jacobians[:, c, :], jacobians_2[:, c, :])
         else:
             kernel = torch.einsum('ap,bp->ab', jacobians.reshape(-1, P), jacobians_2.reshape(-1, P))
+        del jacobians_2
         return kernel
 
-    def _kernel_star(self, jacobians, batch):
+    def _kernel_star(self, jacobians):
         """
         Compute K_star_star kernel matrix.
 
         Parameters
         ----------
         jacobians : torch.Tensor (b, C, P)
-        batch : torch.Tensor (b, C)
 
         Returns
         -------
         kernel : torch.tensor
             K_star with shape (b, C, C)
         """
-        jacobians_2, _ = self._jacobians(batch)
         if self.diagonal_kernel:
-            kernel = torch.einsum('bcp,bcp->bc', jacobians, jacobians_2)
+            kernel = torch.empty((jacobians.shape[0], self.n_outputs), device=jacobians.device)
+            for c in range(self.n_outputs):
+                kernel[:, c] = torch.norm(jacobians[:, c, :], dim=1) ** 2
         else:
-            kernel = torch.einsum('bcp,bep->bce', jacobians, jacobians_2)
+            kernel = torch.einsum('bcp,bep->bce', jacobians, jacobians)
         return kernel
 
     def _kernel_batch_star(self, jacobians, batch):
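Note: K_star_star only ever contracts the test Jacobians with themselves, so the second _jacobians(batch) pass in _kernel_star is dropped; in the diagonal case each entry reduces to the squared norm of a per-output Jacobian. A quick check of that identity with dummy shapes:

    import torch

    b, C, P = 4, 3, 50
    J = torch.randn(b, C, P)

    k_einsum = torch.einsum('bcp,bcp->bc', J, J)   # old contraction of J with itself
    k_norms = torch.norm(J, dim=2) ** 2            # new per-output squared norms
    assert torch.allclose(k_einsum, k_norms, atol=1e-5)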
@@ -1539,7 +1542,9 @@ def _kernel_batch_star(self, jacobians, batch):
         """
         jacobians_2, _ = self._jacobians(batch)
         if self.diagonal_kernel:
-            kernel = torch.einsum('bcp,ecp->bec', jacobians, jacobians_2)
+            kernel = torch.empty((jacobians.shape[0], jacobians_2.shape[0], self.n_outputs), device=jacobians.device)
+            for c in range(self.n_outputs):
+                kernel[:, :, c] = torch.einsum('bp,ep->be', jacobians[:, c, :], jacobians_2[:, c, :])
         else:
             kernel = torch.einsum('bcp,dep->bdce', jacobians, jacobians_2)
         return kernel
9 changes: 5 additions & 4 deletions laplace/curvature/backpack.py
@@ -1,5 +1,5 @@
 import torch
-
+from torch.nn.utils import parameters_to_vector
 from backpack import backpack, extend, memory_cleanup
 from backpack.extensions import DiagGGNExact, DiagGGNMC, KFAC, KFLR, SumGradSquared, BatchGrad
 from backpack.context import CTX
@@ -33,7 +33,8 @@ def jacobians(self, x):
             output function `(batch, outputs)`
         """
         model = extend(self.model)
-        to_stack = []
+        P = len(parameters_to_vector(model.parameters()).detach())
+        Jks = torch.empty(model.output_size, x.shape[0], P, device=x.device) # (C, b, P)
         for i in range(model.output_size):
             model.zero_grad()
             out = model(x)
@@ -49,15 +50,15 @@ def jacobians(self, x):
             Jk = torch.cat(to_cat, dim=1)
             if self.subnetwork_indices is not None:
                 Jk = Jk[:, self.subnetwork_indices]
-            to_stack.append(Jk)
+            Jks[i] = Jk
             if i == 0:
                 f = out.detach()
 
         model.zero_grad()
         CTX.remove_hooks()
         _cleanup(model)
         if model.output_size > 1:
-            return torch.stack(to_stack, dim=2).transpose(1, 2), f
+            return Jks.transpose(0, 1), f
         else:
             return Jk.unsqueeze(-1).transpose(1, 2), f
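Note: the jacobians change in this file follows the same idea. Collecting the per-output Jacobians in a Python list and calling torch.stack briefly keeps two full copies alive (the list entries plus the stacked result); writing each one into a preallocated (C, b, P) buffer avoids the duplicate. A rough sketch of the two patterns (sizes are illustrative):

    import torch

    C, b, P = 10, 32, 10_000                       # outputs, batch size, number of params

    # Old pattern: list + stack, peak memory roughly 2 * C * b * P floats.
    to_stack = [torch.randn(b, P) for _ in range(C)]
    J_old = torch.stack(to_stack, dim=2).transpose(1, 2)    # (b, C, P)

    # New pattern: fill a preallocated buffer, one output dimension at a time.
    Jks = torch.empty(C, b, P)
    for i in range(C):
        Jks[i] = torch.randn(b, P)                 # stands in for the per-output BatchGrad result
    J_new = Jks.transpose(0, 1)                    # (b, C, P)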
