Merge pull request #135 from AlexImmer/running-metrics
Running metrics
runame authored Aug 16, 2023
2 parents c3004c5 + 0c663e9 commit 05bbdfb
Showing 7 changed files with 175 additions and 35 deletions.
30 changes: 16 additions & 14 deletions laplace/baselaplace.py
@@ -5,6 +5,7 @@
 from torch.distributions import MultivariateNormal
 import tqdm
 from collections import UserDict
+import torchmetrics as tm
 
 from laplace.utils import (parameters_per_layer, invsqrt_precision,
                            get_nll, validate, Kron, normal_samples)
@@ -223,8 +224,9 @@ def optimize_prior_precision_base(self, pred_type, method='marglik', n_steps=100
             initial prior precision before the first optimization step.
         val_loader : torch.data.utils.DataLoader, default=None
             DataLoader for the validation set; each iterate is a training batch (X, y).
-        loss : callable, default=get_nll
-            loss function to use for CV.
+        loss : callable or torchmetrics.Metric, default=get_nll
+            loss function to use for CV. If a callable, all predictions are collected and the loss is computed offline (memory-intensive).
+            If a torchmetrics.Metric, a running loss is computed batch by batch (memory-efficient).
         cv_loss_with_var: bool, default=False
             if true, `loss` takes three arguments `loss(output_mean, output_var, target)`,
             otherwise, `loss` takes two arguments `loss(output_mean, target)`
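
The two loss modes above trade memory for generality. A minimal sketch of how they are meant to be selected, mirroring the test added in tests/test_utils.py below; the toy model and data, and the method='CV' / pred_type='nn' keyword values, are assumptions rather than canonical usage:

import torch
from torch.utils.data import DataLoader, TensorDataset
from laplace import Laplace
from laplace.utils import get_nll, RunningNLLMetric

X, y = torch.randn(50, 10), torch.randint(3, size=(50,))
loader = DataLoader(TensorDataset(X, y), batch_size=10)
model = torch.nn.Sequential(torch.nn.Linear(10, 20), torch.nn.ReLU(), torch.nn.Linear(20, 3))

la = Laplace(model, 'classification', 'all')
la.fit(loader)

# Callable loss: every validation prediction is kept in memory, then scored once.
la.optimize_prior_precision(method='CV', val_loader=loader, loss=get_nll,
                            pred_type='nn', link_approx='mc', n_samples=10)

# torchmetrics.Metric loss: each batch is folded into running sums instead.
la.optimize_prior_precision(method='CV', val_loader=loader, loss=RunningNLLMetric(),
                            pred_type='nn', link_approx='mc', n_samples=10)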
@@ -267,7 +269,8 @@ def optimize_prior_precision_base(self, pred_type, method='marglik', n_steps=100
             )
             self.prior_precision = self._gridsearch(
                 loss, interval, val_loader, pred_type=pred_type,
-                link_approx=link_approx, n_samples=n_samples, loss_with_var=cv_loss_with_var
+                link_approx=link_approx, n_samples=n_samples, loss_with_var=cv_loss_with_var,
+                progress_bar=progress_bar
             )
         else:
             raise ValueError('For now only marglik and CV is implemented.')
@@ -277,26 +280,25 @@ def optimize_prior_precision_base(self, pred_type, method='marglik', n_steps=100
     def _gridsearch(self, loss, interval, val_loader, pred_type,
                     link_approx='probit', n_samples=100, loss_with_var=False,
                     progress_bar=False):
+        assert callable(loss) or isinstance(loss, tm.Metric)
+
         results = list()
         prior_precs = list()
         pbar = tqdm.tqdm(interval) if progress_bar else interval
         for prior_prec in pbar:
             self.prior_precision = prior_prec
             try:
-                out_dist, targets = validate(
-                    self, val_loader, pred_type=pred_type,
-                    link_approx=link_approx, n_samples=n_samples
+                result = validate(
+                    self, val_loader, loss, pred_type=pred_type,
+                    link_approx=link_approx, n_samples=n_samples,
+                    loss_with_var=loss_with_var
                 )
-                if self.likelihood == 'regression':
-                    out_mean, out_var = out_dist
-                    if loss_with_var:
-                        result = loss(out_mean, out_var, targets).item()
-                    else:
-                        result = loss(out_mean, targets).item()
-                else:
-                    result = loss(out_dist, targets).item()
             except RuntimeError:
                 result = np.inf
 
+            if progress_bar:
+                pbar.set_description(f'[prior_prec: {prior_prec:.3e}, loss: {result:.3f}]')
+
             results.append(result)
             prior_precs.append(prior_prec)
         return prior_precs[np.argmin(results)]
3 changes: 2 additions & 1 deletion laplace/utils/__init__.py
@@ -3,6 +3,7 @@
 from laplace.utils.matrix import Kron, KronDecomposed
 from laplace.utils.swag import fit_diagonal_swag_var
 from laplace.utils.subnetmask import SubnetMask, RandomSubnetMask, LargestMagnitudeSubnetMask, LargestVarianceDiagLaplaceSubnetMask, LargestVarianceSWAGSubnetMask, ParamNameSubnetMask, ModuleNameSubnetMask, LastLayerSubnetMask
+from laplace.utils.metrics import RunningNLLMetric
 
 
 __all__ = ['get_nll', 'validate', 'parameters_per_layer', 'invsqrt_precision', 'kron',
@@ -11,4 +12,4 @@
            'Kron', 'KronDecomposed',
            'fit_diagonal_swag_var',
            'SubnetMask', 'RandomSubnetMask', 'LargestMagnitudeSubnetMask', 'LargestVarianceDiagLaplaceSubnetMask',
-           'LargestVarianceSWAGSubnetMask', 'ParamNameSubnetMask', 'ModuleNameSubnetMask', 'LastLayerSubnetMask']
+           'LargestVarianceSWAGSubnetMask', 'ParamNameSubnetMask', 'ModuleNameSubnetMask', 'LastLayerSubnetMask', 'RunningNLLMetric']
40 changes: 40 additions & 0 deletions laplace/utils/metrics.py
@@ -0,0 +1,40 @@
import torch
from torch.nn import functional as F
from torchmetrics import Metric


class RunningNLLMetric(Metric):
"""
    NLL metric that accumulates the summed negative log-likelihood and the
    number of valid labels across batches, so the full set of predictions
    never has to be held in memory.

Parameters
----------
ignore_index: int, default = -100
which class label to ignore when computing the NLL loss
"""
def __init__(self, ignore_index=-100):
super().__init__()
self.add_state('nll_sum', default=torch.tensor(0.), dist_reduce_fx='sum')
self.add_state('n_valid_labels', default=torch.tensor(0.), dist_reduce_fx='sum')
self.ignore_index = ignore_index

def update(self, probs: torch.Tensor, targets: torch.Tensor) -> None:
"""
Parameters
----------
probs: torch.Tensor
probability tensor of shape (..., n_classes)
targets: torch.Tensor
integer tensor of shape (...)
"""
probs = probs.view(-1, probs.shape[-1])
targets = targets.view(-1)

self.nll_sum += F.nll_loss(
probs.log(), targets, ignore_index=self.ignore_index, reduction='sum'
)
self.n_valid_labels += (targets != self.ignore_index).sum()

def compute(self):
return self.nll_sum / self.n_valid_labels
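
A minimal usage sketch of the metric (the batch shapes are illustrative): each update folds a batch into the two running sums, and compute returns the summed NLL divided by the number of valid labels.

import torch
from laplace.utils import RunningNLLMetric

metric = RunningNLLMetric()
for _ in range(4):  # e.g. four validation batches
    probs = torch.softmax(torch.randn(8, 10), dim=-1)  # (batch, n_classes)
    targets = torch.randint(10, size=(8,))
    metric.update(probs, targets)  # only updates nll_sum and n_valid_labels
print(metric.compute().item())  # mean NLL over all valid labels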
53 changes: 37 additions & 16 deletions laplace/utils/utils.py
@@ -6,6 +6,9 @@
 from torch.nn.utils import parameters_to_vector
 from torch.nn import BatchNorm1d, BatchNorm2d, BatchNorm3d
 from torch.distributions.multivariate_normal import _precision_to_scale_tril
+from torchmetrics import Metric
+from collections import UserDict
+import math
 
 
 __all__ = ['get_nll', 'validate', 'parameters_per_layer', 'invsqrt_precision', 'kron',
@@ -17,29 +20,47 @@ def get_nll(out_dist, targets):
 
 
 @torch.no_grad()
-def validate(laplace, val_loader, pred_type='glm', link_approx='probit', n_samples=100):
+def validate(laplace, val_loader, loss, pred_type='glm', link_approx='probit', n_samples=100, loss_with_var=False) -> float:
     laplace.model.eval()
-    output_means, output_vars = list(), list()
-    targets = list()
-    for X, y in val_loader:
+    assert callable(loss) or isinstance(loss, Metric)
+    is_offline = not isinstance(loss, Metric)
+
+    if is_offline:
+        output_means, output_vars = list(), list()
+        targets = list()
+
+    for data in val_loader:
+        X, y = (data['input_ids'], data['labels']) if isinstance(data, UserDict) else data
         X, y = X.to(laplace._device), y.to(laplace._device)
         out = laplace(
             X, pred_type=pred_type,
             link_approx=link_approx,
             n_samples=n_samples)
 
         if type(out) == tuple:
-            output_means.append(out[0])
-            output_vars.append(out[1])
+            if is_offline:
+                output_means.append(out[0])
+                output_vars.append(out[1])
+                targets.append(y)
+            else:
+                loss.update(*out, y)
         else:
-            output_means.append(out)
-
-        targets.append(y)
-
-    if len(output_vars) == 0:
-        return torch.cat(output_means, dim=0), torch.cat(targets, dim=0)
-    return ((torch.cat(output_means, dim=0), torch.cat(output_vars, dim=0)),
-            torch.cat(targets, dim=0))
+            if is_offline:
+                output_means.append(out)
+                targets.append(y)
+            else:
+                loss.update(out, y)
+
+    if is_offline:
+        if len(output_vars) == 0:
+            preds, targets = torch.cat(output_means, dim=0), torch.cat(targets, dim=0)
+            return loss(preds, targets).item()
+
+        means, variances = torch.cat(output_means, dim=0), torch.cat(output_vars, dim=0)
+        targets = torch.cat(targets, dim=0)
+        return loss(means, variances, targets).item()
+    else:
+        return loss.compute().item()
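
The UserDict branch in the loop lets validate consume Hugging Face-style batches as well as plain (X, y) tuples. A sketch of a loader that would exercise that branch; the collate format is an assumption inferred from the input_ids/labels keys used above:

import torch
from collections import UserDict
from torch.utils.data import DataLoader, TensorDataset

def collate(batch):
    # Pack (x, y) pairs the way a Hugging Face tokenizer batch looks:
    # a UserDict that validate() unpacks via data['input_ids'], data['labels'].
    xs, ys = zip(*batch)
    return UserDict({'input_ids': torch.stack(xs), 'labels': torch.stack(ys)})

X, y = torch.randn(20, 10), torch.randint(3, size=(20,))
loader = DataLoader(TensorDataset(X, y), batch_size=5, collate_fn=collate)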


def parameters_per_layer(model):
@@ -236,9 +257,9 @@ def normal_samples(mean, var, n_samples, generator=None):
     """
     assert mean.ndim == 2, 'Invalid input shape of mean, should be 2-dimensional.'
     _, output_dim = mean.shape
-    randn_samples = torch.randn((output_dim, n_samples), device=mean.device, 
+    randn_samples = torch.randn((output_dim, n_samples), device=mean.device,
                                 dtype=mean.dtype, generator=generator)
 
     if mean.shape == var.shape:
         # diagonal covariance
         scaled_samples = var.sqrt().unsqueeze(-1) * randn_samples.unsqueeze(0)
1 change: 1 addition & 0 deletions setup.cfg
@@ -37,6 +37,7 @@ install_requires =
     torchaudio
     backpack-for-pytorch
     asdl
+    torchmetrics
 # Require a specific Python version, e.g. Python 2.7 or >= 3.4
 python_requires = >=3.8

47 changes: 47 additions & 0 deletions tests/test_metrics.py
@@ -0,0 +1,47 @@
import torch
from torch.nn import functional as F
from laplace.utils import RunningNLLMetric
import math


def test_running_nll_metric():
metric = RunningNLLMetric()
all_probs, all_targets = [], []

for _ in range(10):
probs = torch.softmax(torch.randn(3, 5, 10), dim=-1)
targets = torch.randint(10, size=(3, 5))
metric.update(probs, targets)
all_probs.append(probs)
all_targets.append(targets)

all_probs, all_targets = torch.cat(all_probs, 0), torch.cat(all_targets, 0)

nll_running = metric.compute().item()
nll_offline = F.nll_loss(all_probs.log().flatten(end_dim=-2), all_targets.flatten()).item()

assert math.isclose(nll_running, nll_offline, rel_tol=1e-7)


def test_running_nll_metric_ignore_idx():
ignore_idx = -1232
metric = RunningNLLMetric(ignore_index=ignore_idx)
all_probs, all_targets = [], []

for _ in range(10):
probs = torch.softmax(torch.randn(3, 5, 10), dim=-1)
targets = torch.randint(10, size=(3, 5))
        mask = torch.FloatTensor(*targets.shape).uniform_() > 0.5  # ~50% True
targets[mask] = ignore_idx # ~50% changed to ignore_idx
all_probs.append(probs)
all_targets.append(targets)
metric.update(probs, targets)

all_probs, all_targets = torch.cat(all_probs, 0), torch.cat(all_targets, 0)

nll_running = metric.compute().item()
nll_offline = F.nll_loss(all_probs.log().flatten(end_dim=-2), all_targets.flatten(), ignore_index=ignore_idx).item()

print(nll_running, nll_offline)

assert math.isclose(nll_running, nll_offline, rel_tol=1e-7)
36 changes: 32 additions & 4 deletions tests/test_utils.py
@@ -1,5 +1,8 @@
 import torch
-from laplace.utils import invsqrt_precision, diagonal_add_scalar, symeig, normal_samples
+from torch.utils.data import TensorDataset, DataLoader
+from laplace import Laplace
+from laplace.utils import invsqrt_precision, diagonal_add_scalar, symeig, normal_samples, validate, get_nll, RunningNLLMetric
+import math
 
 
 def test_sqrt_precision():
@@ -23,7 +26,7 @@ def test_symeig_custom():
l2, W2 = symeig(M)
assert torch.allclose(l1, l2)
assert torch.allclose(W1, W2)


def test_symeig_custom_low_rank():
X = torch.randn(1000, 10)
@@ -35,7 +38,7 @@ def test_symeig_custom_low_rank():
# test clamping to zeros
assert torch.all(l2 >= 0.0)


def test_diagonal_normal_samples():
mean = torch.randn(10, 2)
var = torch.exp(torch.randn(10, 2))
@@ -48,7 +51,7 @@ def test_diagonal_normal_samples():
same_samples = normal_samples(mean, var, n_samples=100, generator=generator)
assert torch.allclose(samples, same_samples)


def test_multivariate_normal_samples():
mean = torch.randn(10, 2)
rndns = torch.randn(10, 2, 10) / 100
@@ -61,3 +64,28 @@ def test_multivariate_normal_samples():
generator.set_state(gen_state)
same_samples = normal_samples(mean, var, n_samples=100, generator=generator)
assert torch.allclose(samples, same_samples)


def test_validate():
X = torch.randn(50, 10)
y = torch.randint(3, size=(50,))
dataloader = DataLoader(TensorDataset(X, y), batch_size=10)

model = torch.nn.Sequential(torch.nn.Linear(10, 20), torch.nn.ReLU(), torch.nn.Linear(20, 3))
la = Laplace(model, 'classification', 'all')
la.fit(dataloader)

res = validate(
la, dataloader, get_nll, pred_type='nn', link_approx='mc', n_samples=10
)
    assert not math.isnan(res)
assert isinstance(res, float)
assert res > 0

res = validate(
la, dataloader, RunningNLLMetric(), pred_type='nn', link_approx='mc', n_samples=10
)
    assert not math.isnan(res)
assert isinstance(res, float)
assert res > 0
