From e20ac6c56bae5fbd147e49d03948587bf865880d Mon Sep 17 00:00:00 2001
From: Metod Jazbec
Date: Tue, 21 Feb 2023 15:22:00 +0100
Subject: [PATCH] Functional laplace memory investigation (#2)

* memory investigation start

* add for loop over output dimensions

* improve memory footprint

* minor

---------

Co-authored-by: Metod Jazbec
---
 examples/calibration_gp_example.py |  1 +
 laplace/baselaplace.py             | 33 +++++++++++++++++-------------
 laplace/curvature/backpack.py      |  9 ++++----
 3 files changed, 25 insertions(+), 18 deletions(-)

diff --git a/examples/calibration_gp_example.py b/examples/calibration_gp_example.py
index 9d3ae101..aebfe4bf 100644
--- a/examples/calibration_gp_example.py
+++ b/examples/calibration_gp_example.py
@@ -43,6 +43,7 @@
                  prior_precision=prior_precision)
     la.fit(train_loader)
 
+    print(f'Predicting Laplace-GP for m={m}')
     probs_laplace = predict(test_loader, la, laplace=True, la_type='gp')
     acc_laplace, ece_laplace, nll_laplace = get_metrics(probs_laplace, targets)
     print(f'[Laplace-GP, m={m}] Acc.: {acc_laplace:.1%}; ECE: {ece_laplace:.1%}; NLL: {nll_laplace:.3}')
diff --git a/laplace/baselaplace.py b/laplace/baselaplace.py
index 6d003789..ed5830f3 100644
--- a/laplace/baselaplace.py
+++ b/laplace/baselaplace.py
@@ -1183,15 +1183,15 @@ def _check_prior_precision(prior_precision):
 
     def _init_K_MM(self):
         if self.diagonal_kernel:
-            self.K_MM = [torch.zeros(size=(self.M, self.M), device=self._device) for _ in range(self.n_outputs)]
+            self.K_MM = [torch.empty(size=(self.M, self.M), device=self._device) for _ in range(self.n_outputs)]
         else:
-            self.K_MM = torch.zeros(size=(self.M * self.n_outputs, self.M * self.n_outputs), device=self._device)
+            self.K_MM = torch.empty(size=(self.M * self.n_outputs, self.M * self.n_outputs), device=self._device)
 
     def _init_Sigma_inv(self):
         if self.diagonal_kernel:
-            self.Sigma_inv = [torch.zeros(size=(self.M, self.M), device=self._device) for _ in range(self.n_outputs)]
+            self.Sigma_inv = [torch.empty(size=(self.M, self.M), device=self._device) for _ in range(self.n_outputs)]
         else:
-            self.Sigma_inv = torch.zeros(size=(self.M * self.n_outputs, self.M * self.n_outputs), device=self._device)
+            self.Sigma_inv = torch.empty(size=(self.M * self.n_outputs, self.M * self.n_outputs), device=self._device)
 
     def _curv_closure(self, X, y):
         return self.backend.gp_quantities(X, y)
@@ -1347,12 +1347,12 @@ def gp_posterior(self, X_star):
         """
 
         Js, f_mu = self._jacobians(X_star)
-        f_var = self._gp_posterior_variance(Js, X_star)
+        f_var = self._gp_posterior_variance(Js)
        if self.diagonal_kernel:
             f_var = torch.diag_embed(f_var)
         return f_mu.detach(), f_var.detach()
 
-    def _gp_posterior_variance(self, Js_star, X_star):
+    def _gp_posterior_variance(self, Js_star):
         """
         GP posterior variance: \\( k_{**} - K_{*M} (K_{MM}+ L_{MM}^{-1})^{-1} K_{M*}\\)
 
@@ -1363,12 +1363,13 @@ def _gp_posterior_variance(self, Js_star, X_star):
         X_star : torch.Tensor
             test data points \\(X \in \mathbb{R}^{N_{test} \\times C} \\)
         """
-        K_star = self.gp_kernel_prior_variance * self._kernel_star(Js_star, X_star)
+        K_star = self.gp_kernel_prior_variance * self._kernel_star(Js_star)
 
         K_M_star = []
         for X_batch, _ in self.train_loader:
             K_M_star_batch = self.gp_kernel_prior_variance * self._kernel_batch_star(Js_star, X_batch.to(self._device))
             K_M_star.append(K_M_star_batch)
+            del X_batch
 
         f_var = K_star - self._build_K_star_M(K_M_star)
         return f_var
@@ -1495,20 +1496,21 @@ def _kernel_batch(self, jacobians, batch):
         jacobians_2, _ = self._jacobians(batch)
         P = jacobians.shape[-1]  # nr model params
         if self.diagonal_kernel:
-            kernel = torch.einsum('bcp,ecp->bec', jacobians, jacobians_2)
+            kernel = torch.empty((jacobians.shape[0], jacobians_2.shape[0], self.n_outputs), device=jacobians.device)
+            for c in range(self.n_outputs):
+                kernel[:, :, c] = torch.einsum('bp,ep->be', jacobians[:, c, :], jacobians_2[:, c, :])
         else:
             kernel = torch.einsum('ap,bp->ab', jacobians.reshape(-1, P), jacobians_2.reshape(-1, P))
         del jacobians_2
         return kernel
 
-    def _kernel_star(self, jacobians, batch):
+    def _kernel_star(self, jacobians):
         """
         Compute K_star_star kernel matrix.
 
         Parameters
         ----------
         jacobians : torch.Tensor (b, C, P)
-        batch : torch.Tensor (b, C)
 
         Returns
         -------
@@ -1516,11 +1518,12 @@ def _kernel_star(self, jacobians, batch):
         kernel : torch.tensor
             K_star with shape (b, C, C)
         """
-        jacobians_2, _ = self._jacobians(batch)
         if self.diagonal_kernel:
-            kernel = torch.einsum('bcp,bcp->bc', jacobians, jacobians_2)
+            kernel = torch.empty((jacobians.shape[0], self.n_outputs), device=jacobians.device)
+            for c in range(self.n_outputs):
+                kernel[:, c] = torch.norm(jacobians[:, c, :], dim=1) ** 2
         else:
-            kernel = torch.einsum('bcp,bep->bce', jacobians, jacobians_2)
+            kernel = torch.einsum('bcp,bep->bce', jacobians, jacobians)
         return kernel
 
     def _kernel_batch_star(self, jacobians, batch):
@@ -1539,7 +1542,9 @@ def _kernel_batch_star(self, jacobians, batch):
         """
         jacobians_2, _ = self._jacobians(batch)
         if self.diagonal_kernel:
-            kernel = torch.einsum('bcp,ecp->bec', jacobians, jacobians_2)
+            kernel = torch.empty((jacobians.shape[0], jacobians_2.shape[0], self.n_outputs), device=jacobians.device)
+            for c in range(self.n_outputs):
+                kernel[:, :, c] = torch.einsum('bp,ep->be', jacobians[:, c, :], jacobians_2[:, c, :])
         else:
             kernel = torch.einsum('bcp,dep->bdce', jacobians, jacobians_2)
         return kernel
diff --git a/laplace/curvature/backpack.py b/laplace/curvature/backpack.py
index 387aa4d0..cc2e66c4 100644
--- a/laplace/curvature/backpack.py
+++ b/laplace/curvature/backpack.py
@@ -1,5 +1,5 @@
 import torch
-
+from torch.nn.utils import parameters_to_vector
 from backpack import backpack, extend, memory_cleanup
 from backpack.extensions import DiagGGNExact, DiagGGNMC, KFAC, KFLR, SumGradSquared, BatchGrad
 from backpack.context import CTX
@@ -33,7 +33,8 @@ def jacobians(self, x):
             output function `(batch, outputs)`
         """
         model = extend(self.model)
-        to_stack = []
+        P = len(parameters_to_vector(model.parameters()).detach())
+        Jks = torch.empty(model.output_size, x.shape[0], P, device=x.device)  # (C, b, P)
         for i in range(model.output_size):
             model.zero_grad()
             out = model(x)
@@ -49,7 +50,7 @@ def jacobians(self, x):
                 Jk = torch.cat(to_cat, dim=1)
                 if self.subnetwork_indices is not None:
                     Jk = Jk[:, self.subnetwork_indices]
-                to_stack.append(Jk)
+                Jks[i] = Jk
 
                 if i == 0:
                     f = out.detach()
@@ -57,7 +58,7 @@ def jacobians(self, x):
 
         CTX.remove_hooks()
         _cleanup(model)
         if model.output_size > 1:
-            return torch.stack(to_stack, dim=2).transpose(1, 2), f
+            return Jks.transpose(0, 1), f
         else:
             return Jk.unsqueeze(-1).transpose(1, 2), f
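Note on the diagonal-kernel hunks in laplace/baselaplace.py: the single einsum('bcp,ecp->bec', ...) over all output dimensions is replaced by a preallocated result tensor that is filled one output dimension at a time, so at any moment only one (b, e) block is being contracted. A minimal standalone sketch of that pattern (the function name and signature are illustrative, not part of the library):

import torch

def diagonal_kernel_blockwise(jacs_1, jacs_2):
    """Per-class Jacobian kernel, filled one output dimension at a time.

    jacs_1: (b, C, P), jacs_2: (e, C, P); returns (b, e, C).
    Mirrors the loop introduced in _kernel_batch / _kernel_batch_star.
    """
    b, C, _ = jacs_1.shape
    e = jacs_2.shape[0]
    kernel = torch.empty((b, e, C), device=jacs_1.device, dtype=jacs_1.dtype)
    for c in range(C):
        # same contraction as einsum('bp,ep->be', ...) restricted to class c
        kernel[:, :, c] = jacs_1[:, c, :] @ jacs_2[:, c, :].T
    return kernel

The switch from torch.zeros to torch.empty for K_MM and Sigma_inv follows the same logic: the buffers are presumably written in full before they are read, so zero-initialisation is unnecessary work.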
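On the BackPACK backend change: jacobians() now writes each per-output Jacobian into a preallocated (C, b, P) buffer instead of appending to a Python list and stacking at the end, so the list entries and the stacked copy are never held at the same time. A rough sketch of the two collection patterns (helper names are illustrative; the patch stacks along a different dim and transposes, but the memory pattern is the same):

import torch

def collect_by_stack(slices):
    # old pattern: to_stack.append(Jk) in a loop, then one big torch.stack
    return torch.stack(list(slices), dim=0)

def collect_preallocated(slices, C, b, P, device=None):
    # new pattern: write each (b, P) slice into a preallocated (C, b, P) buffer
    out = torch.empty(C, b, P, device=device)
    for i, s in enumerate(slices):
        out[i] = s
    return out

One thing that may be worth double-checking in the patched jacobians(): when subnetwork_indices is not None, the sliced Jk has len(subnetwork_indices) columns rather than P, so the preallocated width would have to match for the Jks[i] = Jk assignment to go through.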