From 6b392e750bf38b14553168685e4c583d0df81463 Mon Sep 17 00:00:00 2001
From: Agustinus Kristiadi
Date: Tue, 8 Aug 2023 11:03:56 -0400
Subject: [PATCH 1/7] Implement running NLL metric, based on torchmetrics

---
 laplace/utils/__init__.py |  3 ++-
 laplace/utils/metrics.py  | 40 +++++++++++++++++++++++++++++++++++
 setup.cfg                 |  1 +
 tests/test_metrics.py     | 44 +++++++++++++++++++++++++++++++++++++++
 4 files changed, 87 insertions(+), 1 deletion(-)
 create mode 100644 laplace/utils/metrics.py
 create mode 100644 tests/test_metrics.py

diff --git a/laplace/utils/__init__.py b/laplace/utils/__init__.py
index 691350fe..cfa31e51 100644
--- a/laplace/utils/__init__.py
+++ b/laplace/utils/__init__.py
@@ -3,6 +3,7 @@
 from laplace.utils.matrix import Kron, KronDecomposed
 from laplace.utils.swag import fit_diagonal_swag_var
 from laplace.utils.subnetmask import SubnetMask, RandomSubnetMask, LargestMagnitudeSubnetMask, LargestVarianceDiagLaplaceSubnetMask, LargestVarianceSWAGSubnetMask, ParamNameSubnetMask, ModuleNameSubnetMask, LastLayerSubnetMask
+from laplace.utils.metrics import RunningNLLMetric
 
 
 __all__ = ['get_nll', 'validate', 'parameters_per_layer', 'invsqrt_precision', 'kron',
@@ -11,4 +12,4 @@
            'Kron', 'KronDecomposed',
            'fit_diagonal_swag_var',
            'SubnetMask', 'RandomSubnetMask', 'LargestMagnitudeSubnetMask', 'LargestVarianceDiagLaplaceSubnetMask',
-           'LargestVarianceSWAGSubnetMask', 'ParamNameSubnetMask', 'ModuleNameSubnetMask', 'LastLayerSubnetMask']
+           'LargestVarianceSWAGSubnetMask', 'ParamNameSubnetMask', 'ModuleNameSubnetMask', 'LastLayerSubnetMask', 'RunningNLLMetric']

diff --git a/laplace/utils/metrics.py b/laplace/utils/metrics.py
new file mode 100644
index 00000000..444212de
--- /dev/null
+++ b/laplace/utils/metrics.py
@@ -0,0 +1,40 @@
+import torch
+from torch.nn import functional as F
+from torchmetrics import Metric
+
+
+class RunningNLLMetric(Metric):
+    """
+    NLL metric that tracks the mean negative log-likelihood in a running fashion.
+
+    Parameters
+    ----------
+    ignore_index: int, default = -100
+        which class label to ignore when computing the NLL loss
+    """
+    def __init__(self, ignore_index=-100):
+        super().__init__()
+        self.add_state('nll_sum', default=torch.tensor(0.), dist_reduce_fx='sum')
+        self.add_state('n_valid_labels', default=torch.tensor(0.), dist_reduce_fx='sum')
+        self.ignore_index = ignore_index
+
+    def update(self, probs: torch.Tensor, targets: torch.Tensor) -> None:
+        """
+        Parameters
+        ----------
+        probs: torch.Tensor
+            probability tensor of shape (..., n_classes)
+
+        targets: torch.Tensor
+            integer tensor of shape (...)
+        """
+        probs = probs.view(-1, probs.shape[-1])
+        targets = targets.view(-1)
+
+        self.nll_sum += F.nll_loss(
+            probs.log(), targets, ignore_index=self.ignore_index, reduction='sum'
+        ).item()
+        self.n_valid_labels += (targets != self.ignore_index).sum().item()
+
+    def compute(self):
+        return self.nll_sum / self.n_valid_labels

diff --git a/setup.cfg b/setup.cfg
index 47d2dc10..b11b6694 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -37,6 +37,7 @@ install_requires =
     torchaudio
    backpack-for-pytorch
    asdl
+    torchmetrics
 
 # Require a specific Python version, e.g. Python 2.7 or >= 3.4
 python_requires = >=3.8

diff --git a/tests/test_metrics.py b/tests/test_metrics.py
new file mode 100644
index 00000000..0a7e86f8
--- /dev/null
+++ b/tests/test_metrics.py
@@ -0,0 +1,44 @@
+import torch
+from torch.nn import functional as F
+from laplace.utils import RunningNLLMetric
+import math
+
+
+def test_running_nll_metric():
+    metric = RunningNLLMetric()
+    all_probs, all_targets = [], []
+
+    for _ in range(10):
+        probs = torch.softmax(torch.randn(3, 5, 10), dim=-1)
+        targets = torch.randint(10, size=(3, 5))
+        metric.update(probs, targets)
+        all_probs.append(probs)
+        all_targets.append(targets)
+
+    all_probs, all_targets = torch.cat(all_probs, 0), torch.cat(all_targets, 0)
+
+    nll_running = metric.compute().item()
+    nll_offline = F.nll_loss(all_probs.log().flatten(end_dim=-2), all_targets.flatten()).item()
+
+    assert math.isclose(nll_running, nll_offline)
+
+
+def test_running_nll_metric_ignore_idx():
+    ignore_idx = -1232
+    metric_orig = RunningNLLMetric()
+    metric_ignore = RunningNLLMetric(ignore_index=ignore_idx)
+
+    for _ in range(10):
+        probs = torch.softmax(torch.randn(3, 5, 10), dim=-1)
+        targets_orig = torch.randint(10, size=(3, 5))
+        targets_ignore = targets_orig.clone()
+        metric_orig.update(probs, targets_orig)
+
+        mask = torch.FloatTensor(*targets_ignore.shape).uniform_() > 0.8  # ~80% zeros
+        targets_ignore[mask] = ignore_idx  # ~20% changed to ignore_idx
+        metric_ignore.update(probs, targets_ignore)
+
+    nll_orig = metric_orig.compute().item()
+    nll_ignore = metric_ignore.compute().item()
+
+    assert nll_orig > nll_ignore

From 934bb0e46c5ffcf26e6a4ce2fe1ec78bcb2ea20d Mon Sep 17 00:00:00 2001
From: Agustinus Kristiadi
Date: Tue, 8 Aug 2023 11:37:41 -0400
Subject: [PATCH 2/7] Support cross-val grid search with running metrics

---
 laplace/baselaplace.py | 20 +++++++-----------
 laplace/utils/utils.py | 48 +++++++++++++++++++++++++++++-------------
 tests/test_utils.py    | 36 +++++++++++++++++++++++++++----
 3 files changed, 73 insertions(+), 31 deletions(-)

diff --git a/laplace/baselaplace.py b/laplace/baselaplace.py
index 57987107..38974276 100644
--- a/laplace/baselaplace.py
+++ b/laplace/baselaplace.py
@@ -5,6 +5,7 @@
 from torch.distributions import MultivariateNormal
 import tqdm
 from collections import UserDict
+import torchmetrics as tm
 
 from laplace.utils import (parameters_per_layer, invsqrt_precision, get_nll, validate,
                            Kron, normal_samples)
@@ -221,8 +222,9 @@ def optimize_prior_precision_base(self, pred_type, method='marglik', n_steps=100
             initial prior precision before the first optimization step.
         val_loader : torch.data.utils.DataLoader, default=None
             DataLoader for the validation set; each iterate is a training batch (X, y).
-        loss : callable, default=get_nll
-            loss function to use for CV.
+        loss : callable or torchmetrics.Metric, default=get_nll
+            loss function to use for CV. If a callable, predictions are collected and the loss is computed offline (memory-intensive).
+            If a torchmetrics.Metric, the loss is computed in a running fashion (memory-efficient).
cv_loss_with_var: bool, default=False if true, `loss` takes three arguments `loss(output_mean, output_var, target)`, otherwise, `loss` takes two arguments `loss(output_mean, target)` @@ -275,24 +277,18 @@ def optimize_prior_precision_base(self, pred_type, method='marglik', n_steps=100 def _gridsearch(self, loss, interval, val_loader, pred_type, link_approx='probit', n_samples=100, loss_with_var=False, progress_bar=False): + assert callable(loss) or isinstance(loss, tm.Metric) + results = list() prior_precs = list() pbar = tqdm.tqdm(interval) if progress_bar else interval for prior_prec in pbar: self.prior_precision = prior_prec try: - out_dist, targets = validate( - self, val_loader, pred_type=pred_type, + result = validate( + self, val_loader, loss, pred_type=pred_type, link_approx=link_approx, n_samples=n_samples ) - if self.likelihood == 'regression': - out_mean, out_var = out_dist - if loss_with_var: - result = loss(out_mean, out_var, targets).item() - else: - result = loss(out_mean, targets).item() - else: - result = loss(out_dist, targets).item() except RuntimeError: result = np.inf results.append(result) diff --git a/laplace/utils/utils.py b/laplace/utils/utils.py index 6e133848..80c6da3c 100644 --- a/laplace/utils/utils.py +++ b/laplace/utils/utils.py @@ -6,6 +6,7 @@ from torch.nn.utils import parameters_to_vector from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d from torch.distributions.multivariate_normal import _precision_to_scale_tril +from torchmetrics import Metric __all__ = ['get_nll', 'validate', 'parameters_per_layer', 'invsqrt_precision', 'kron', @@ -17,10 +18,15 @@ def get_nll(out_dist, targets): @torch.no_grad() -def validate(laplace, val_loader, pred_type='glm', link_approx='probit', n_samples=100): +def validate(laplace, val_loader, loss, pred_type='glm', link_approx='probit', n_samples=100) -> float: laplace.model.eval() - output_means, output_vars = list(), list() - targets = list() + assert callable(loss) or isinstance(loss, Metric) + is_offline = loss + + if is_offline: + output_means, output_vars = list(), list() + targets = list() + for X, y in val_loader: X, y = X.to(laplace._device), y.to(laplace._device) out = laplace( @@ -29,17 +35,29 @@ def validate(laplace, val_loader, pred_type='glm', link_approx='probit', n_sampl n_samples=n_samples) if type(out) == tuple: - output_means.append(out[0]) - output_vars.append(out[1]) + if is_offline: + output_means.append(out[0]) + output_vars.append(out[1]) + targets.append(y) + else: + loss.update(*out, y) else: - output_means.append(out) - - targets.append(y) - - if len(output_vars) == 0: - return torch.cat(output_means, dim=0), torch.cat(targets, dim=0) - return ((torch.cat(output_means, dim=0), torch.cat(output_vars, dim=0)), - torch.cat(targets, dim=0)) + if is_offline: + output_means.append(out) + targets.append(y) + else: + loss.update(out, y) + + if is_offline: + if len(output_vars) == 0: + preds, targets = torch.cat(output_means, dim=0), torch.cat(targets, dim=0) + return loss(preds, targets).item() + + means, variances = torch.cat(output_means, dim=0), torch.cat(output_vars, dim=0) + targets = torch.cat(targets, dim=0) + return loss(means, variances, targets).item() + else: + return loss.compute().item() def parameters_per_layer(model): @@ -236,9 +254,9 @@ def normal_samples(mean, var, n_samples, generator=None): """ assert mean.ndim == 2, 'Invalid input shape of mean, should be 2-dimensional.' 
_, output_dim = mean.shape
-    randn_samples = torch.randn((output_dim, n_samples), device=mean.device,
+    randn_samples = torch.randn((output_dim, n_samples), device=mean.device, dtype=mean.dtype,
                                 generator=generator)
-    
+
     if mean.shape == var.shape:
         # diagonal covariance
         scaled_samples = var.sqrt().unsqueeze(-1) * randn_samples.unsqueeze(0)

diff --git a/tests/test_utils.py b/tests/test_utils.py
index 22f01541..d6f9165c 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -1,5 +1,8 @@
 import torch
-from laplace.utils import invsqrt_precision, diagonal_add_scalar, symeig, normal_samples
+from torch.utils.data import TensorDataset, DataLoader
+from laplace import Laplace
+from laplace.utils import invsqrt_precision, diagonal_add_scalar, symeig, normal_samples, validate, get_nll, RunningNLLMetric
+import math
 
 
 def test_sqrt_precision():
@@ -23,7 +26,7 @@ def test_symeig_custom():
     l2, W2 = symeig(M)
     assert torch.allclose(l1, l2)
     assert torch.allclose(W1, W2)
-    
+
 
 def test_symeig_custom_low_rank():
     X = torch.randn(1000, 10)
     M = X @ X.T
     l1, W1 = symeig(M)
 
     # test clamping to zeros
     assert torch.all(l2 >= 0.0)
-    
+
 
 def test_diagonal_normal_samples():
     mean = torch.randn(10, 2)
     var = torch.exp(torch.randn(10, 2))
 
     generator.set_state(gen_state)
     same_samples = normal_samples(mean, var, n_samples=100, generator=generator)
     assert torch.allclose(samples, same_samples)
-    
+
 
 def test_multivariate_normal_samples():
     mean = torch.randn(10, 2)
     rndns = torch.randn(10, 2, 10) / 100
 
     generator.set_state(gen_state)
     same_samples = normal_samples(mean, var, n_samples=100, generator=generator)
     assert torch.allclose(samples, same_samples)
+
+
+def test_validate():
+    X = torch.randn(50, 10)
+    y = torch.randint(3, size=(50,))
+    dataloader = DataLoader(TensorDataset(X, y), batch_size=10)
+
+    model = torch.nn.Sequential(torch.nn.Linear(10, 20), torch.nn.ReLU(), torch.nn.Linear(20, 3))
+    la = Laplace(model, 'classification', 'all')
+    la.fit(dataloader)
+
+    res = validate(
+        la, dataloader, get_nll, pred_type='nn', link_approx='mc', n_samples=10
+    )
+    assert not math.isnan(res)
+    assert isinstance(res, float)
+    assert res > 0
+
+    res = validate(
+        la, dataloader, RunningNLLMetric(), pred_type='nn', link_approx='mc', n_samples=10
+    )
+    assert not math.isnan(res)
+    assert isinstance(res, float)
+    assert res > 0
+

From 26f4af471abbd63322b3d2e7b29795c89594de6c Mon Sep 17 00:00:00 2001
From: Agustinus Kristiadi
Date: Tue, 8 Aug 2023 16:48:28 -0400
Subject: [PATCH 3/7] Support HF datasets in validate

---
 laplace/utils/utils.py | 4 +++-
 1 file changed, 3 insertions(+), 1 deletion(-)

diff --git a/laplace/utils/utils.py b/laplace/utils/utils.py
index 80c6da3c..fef9dc64 100644
--- a/laplace/utils/utils.py
+++ b/laplace/utils/utils.py
@@ -7,6 +7,7 @@
 from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d
 from torch.distributions.multivariate_normal import _precision_to_scale_tril
 from torchmetrics import Metric
+from collections import UserDict
 
 __all__ = ['get_nll', 'validate', 'parameters_per_layer', 'invsqrt_precision', 'kron',
@@ -27,7 +28,8 @@ def validate(laplace, val_loader, loss, pred_type='glm', link_approx='probit', n
         output_means, output_vars = list(), list()
         targets = list()
 
-    for X, y in val_loader:
+    for data in val_loader:
+        X, y = (data['input_ids'], data['labels']) if isinstance(data, UserDict) else data
         X, y = X.to(laplace._device), y.to(laplace._device)
         out = laplace(
             X, pred_type=pred_type,

From 
d72acc8df1556cbcb3f8fbe7b318d23501800578 Mon Sep 17 00:00:00 2001 From: Agustinus Kristiadi Date: Wed, 9 Aug 2023 11:06:23 -0400 Subject: [PATCH 4/7] Fix progress bar --- laplace/baselaplace.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/laplace/baselaplace.py b/laplace/baselaplace.py index 17daae58..09fbcfaf 100644 --- a/laplace/baselaplace.py +++ b/laplace/baselaplace.py @@ -269,7 +269,8 @@ def optimize_prior_precision_base(self, pred_type, method='marglik', n_steps=100 ) self.prior_precision = self._gridsearch( loss, interval, val_loader, pred_type=pred_type, - link_approx=link_approx, n_samples=n_samples, loss_with_var=cv_loss_with_var + link_approx=link_approx, n_samples=n_samples, loss_with_var=cv_loss_with_var, + progress_bar=progress_bar ) else: raise ValueError('For now only marglik and CV is implemented.') From 2365ca32f6449f509eaa14e1337263ad0a2a3624 Mon Sep 17 00:00:00 2001 From: Agustinus Kristiadi Date: Fri, 11 Aug 2023 15:53:39 -0400 Subject: [PATCH 5/7] Fix flag --- laplace/baselaplace.py | 7 ++++++- laplace/utils/utils.py | 5 +++-- 2 files changed, 9 insertions(+), 3 deletions(-) diff --git a/laplace/baselaplace.py b/laplace/baselaplace.py index 09fbcfaf..9103e14b 100644 --- a/laplace/baselaplace.py +++ b/laplace/baselaplace.py @@ -290,10 +290,15 @@ def _gridsearch(self, loss, interval, val_loader, pred_type, try: result = validate( self, val_loader, loss, pred_type=pred_type, - link_approx=link_approx, n_samples=n_samples + link_approx=link_approx, n_samples=n_samples, + loss_with_var=loss_with_var ) except RuntimeError: result = np.inf + + if progress_bar: + pbar.set_description(f'[prior_prec: {prior_prec:.3e}, loss: {result:.3f}]') + results.append(result) prior_precs.append(prior_prec) return prior_precs[np.argmin(results)] diff --git a/laplace/utils/utils.py b/laplace/utils/utils.py index fef9dc64..67c6ce0b 100644 --- a/laplace/utils/utils.py +++ b/laplace/utils/utils.py @@ -8,6 +8,7 @@ from torch.distributions.multivariate_normal import _precision_to_scale_tril from torchmetrics import Metric from collections import UserDict +import math __all__ = ['get_nll', 'validate', 'parameters_per_layer', 'invsqrt_precision', 'kron', @@ -19,10 +20,10 @@ def get_nll(out_dist, targets): @torch.no_grad() -def validate(laplace, val_loader, loss, pred_type='glm', link_approx='probit', n_samples=100) -> float: +def validate(laplace, val_loader, loss, pred_type='glm', link_approx='probit', n_samples=100, loss_with_var=False) -> float: laplace.model.eval() assert callable(loss) or isinstance(loss, Metric) - is_offline = loss + is_offline = not isinstance(loss, Metric) if is_offline: output_means, output_vars = list(), list() From 1568200ec2092ebfe7d8e28680f0e1cfc24eb0fd Mon Sep 17 00:00:00 2001 From: Agustinus Kristiadi Date: Wed, 16 Aug 2023 11:49:41 -0400 Subject: [PATCH 6/7] Remove item --- laplace/utils/metrics.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/laplace/utils/metrics.py b/laplace/utils/metrics.py index 444212de..6a29d3e7 100644 --- a/laplace/utils/metrics.py +++ b/laplace/utils/metrics.py @@ -33,8 +33,8 @@ def update(self, probs: torch.Tensor, targets: torch.Tensor) -> None: self.nll_sum += F.nll_loss( probs.log(), targets, ignore_index=self.ignore_index, reduction='sum' - ).item() - self.n_valid_labels += (targets != self.ignore_index).sum().item() + ) + self.n_valid_labels += (targets != self.ignore_index).sum() def compute(self): return self.nll_sum / self.n_valid_labels From 0c663e999cd9433d87ba5e03b910651a05893dce 
Mon Sep 17 00:00:00 2001 From: Agustinus Kristiadi Date: Wed, 16 Aug 2023 14:21:07 -0400 Subject: [PATCH 7/7] Fix test running metrics --- tests/test_metrics.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/tests/test_metrics.py b/tests/test_metrics.py index 0a7e86f8..2edb72e0 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -20,25 +20,28 @@ def test_running_nll_metric(): nll_running = metric.compute().item() nll_offline = F.nll_loss(all_probs.log().flatten(end_dim=-2), all_targets.flatten()).item() - assert math.isclose(nll_running, nll_offline) + assert math.isclose(nll_running, nll_offline, rel_tol=1e-7) def test_running_nll_metric_ignore_idx(): ignore_idx = -1232 - metric_orig = RunningNLLMetric() - metric_ignore = RunningNLLMetric(ignore_index=ignore_idx) + metric = RunningNLLMetric(ignore_index=ignore_idx) + all_probs, all_targets = [], [] for _ in range(10): probs = torch.softmax(torch.randn(3, 5, 10), dim=-1) - targets_orig = torch.randint(10, size=(3, 5)) - targets_ignore = targets_orig.clone() - metric_orig.update(probs, targets_orig) + targets = torch.randint(10, size=(3, 5)) + mask = torch.FloatTensor(*targets.shape).uniform_() > 0.5 # ~50% zeros + targets[mask] = ignore_idx # ~50% changed to ignore_idx + all_probs.append(probs) + all_targets.append(targets) + metric.update(probs, targets) + + all_probs, all_targets = torch.cat(all_probs, 0), torch.cat(all_targets, 0) - mask = torch.FloatTensor(*targets_ignore.shape).uniform_() > 0.8 # ~80% zeros - targets_ignore[mask] = ignore_idx # ~80% changed to ignore_idx - metric_ignore.update(probs, targets_ignore) + nll_running = metric.compute().item() + nll_offline = F.nll_loss(all_probs.log().flatten(end_dim=-2), all_targets.flatten(), ignore_index=ignore_idx).item() - nll_orig = metric_orig.compute().item() - nll_ignore = metric_ignore.compute().item() + print(nll_running, nll_offline) - assert nll_orig > nll_ignore + assert math.isclose(nll_running, nll_offline, rel_tol=1e-7)
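
---
Usage notes (hedged sketches, not part of the committed diffs above).

The snippet below illustrates how RunningNLLMetric from PATCH 1/7 accumulates
the NLL over mini-batches; the toy shapes and random data are illustrative
assumptions, not taken from the patches.

    import torch
    from laplace.utils import RunningNLLMetric

    metric = RunningNLLMetric(ignore_index=-100)
    for _ in range(5):                                     # stream over mini-batches
        probs = torch.softmax(torch.randn(8, 10), dim=-1)  # (batch, n_classes) probabilities
        targets = torch.randint(10, size=(8,))             # integer labels in [0, 10)
        metric.update(probs, targets)                      # adds summed NLL and valid-label count
    mean_nll = metric.compute().item()                     # summed NLL / number of valid labels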
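
A sketch of calling the reworked validate() from PATCH 2/7 with either loss
form; it assumes `la` is an already fitted Laplace object and `val_loader` a
validation DataLoader, as in test_validate above.

    from laplace.utils import validate, get_nll, RunningNLLMetric

    # Offline: predictions for the whole loader are concatenated in memory,
    # then the callable is applied once.
    nll_offline = validate(la, val_loader, get_nll, pred_type='glm', link_approx='probit')

    # Running: loss.update() is called per batch and loss.compute() at the end,
    # so only the two scalar metric states are kept in memory.
    nll_running = validate(la, val_loader, RunningNLLMetric(), pred_type='glm', link_approx='probit')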
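
End to end, the running metric is meant to feed the cross-validated
prior-precision search. The call below assumes the public
optimize_prior_precision wrapper forwards val_loader, loss, and progress_bar
to optimize_prior_precision_base as in PATCH 2/7 and 4/7; the wrapper itself
is not shown in these diffs.

    from laplace.utils import RunningNLLMetric

    la.optimize_prior_precision(
        method='CV',              # grid search over prior precisions via _gridsearch
        val_loader=val_loader,
        loss=RunningNLLMetric(),  # running loss: no in-memory concatenation of predictions
        pred_type='glm',
        link_approx='probit',
        progress_bar=True,        # candidate loss reported via pbar.set_description (PATCH 5/7)
    )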