Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove try-except from gridsearch #199

Merged
merged 9 commits into from
Jun 23, 2024
41 changes: 35 additions & 6 deletions laplace/baselaplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import torch
import torchmetrics as tm
import tqdm
from torch.linalg import LinAlgError
from torch.nn.utils import parameters_to_vector, vector_to_parameters

from laplace.curvature import CurvlinopsGGN
Expand Down Expand Up @@ -148,6 +149,11 @@ def __init__(
self.n_outputs = None
self.n_data = 0

# Useful for reward modeling where it behaves like classification during Hessian
# computation and prior-prec optimization, and behaves like regression during
# prediction
self._fitting: bool = True

@property
def _device(self):
return next(self.model.parameters()).device
Expand Down Expand Up @@ -356,6 +362,9 @@ def optimize_prior_precision_base(
whether to show a progress bar; updated at every batch-Hessian computation.
Useful for very large model and large amount of data, esp. when `subset_of_weights='all'`.
"""
if self.reward_modeling:
self.likelihood = "classification"

if method == "marglik":
self.prior_precision = init_prior_prec
if len(self.prior_precision) == 1 and prior_structure != "scalar":
Expand Down Expand Up @@ -393,9 +402,9 @@ def optimize_prior_precision_base(

if loss is None:
loss = (
tm.MeanSquaredError(num_outputs=self.n_outputs)
tm.MeanSquaredError(num_outputs=self.n_outputs).to(self._device)
if self.likelihood == "regression"
else RunningNLLMetric()
else RunningNLLMetric().to(self._device)
)

self.prior_precision = self._gridsearch(
Expand All @@ -410,6 +419,7 @@ def optimize_prior_precision_base(
)
else:
raise ValueError("For now only marglik and gridsearch is implemented.")

if verbose:
print(f"Optimized prior precision is {self.prior_precision}.")

Expand All @@ -429,8 +439,10 @@ def _gridsearch(
results = list()
prior_precs = list()
pbar = tqdm.tqdm(interval) if progress_bar else interval

for prior_prec in pbar:
self.prior_precision = prior_prec

try:
wiseodd marked this conversation as resolved.
Show resolved Hide resolved
result = validate(
self,
Expand All @@ -442,8 +454,13 @@ def _gridsearch(
loss_with_var=loss_with_var,
dict_key_y=self.dict_key_y,
)
except RuntimeError:
except LinAlgError:
runame marked this conversation as resolved.
Show resolved Hide resolved
result = np.inf
except RuntimeError as err:
if "not positive definite" in str(err):
result = np.inf
else:
raise err

if progress_bar:
pbar.set_description(
Expand Down Expand Up @@ -568,6 +585,9 @@ def fit(self, train_loader, override=True, progress_bar=False):
self.loss = 0
self.n_data = 0

if self.reward_modeling:
self.likelihood = "classification"

self.model.eval()

self.mean = parameters_to_vector(self.params)
Expand Down Expand Up @@ -617,6 +637,7 @@ def fit(self, train_loader, override=True, progress_bar=False):
self.H += H_batch

self.n_data += N
self._fitting = False

@property
def scatter(self):
Expand Down Expand Up @@ -741,6 +762,7 @@ def __call__(
n_samples=100,
diagonal_output=False,
generator=None,
fitting=False,
**model_kwargs,
):
"""Compute the posterior predictive on input data `x`.
Expand Down Expand Up @@ -779,6 +801,11 @@ def __call__(
generator : torch.Generator, optional
random number generator to control the samples (if sampling used).

fitting : bool, default=False
whether or not this predictive call is done during fitting. Only useful for
reward modeling: the likelihood is set to `"regression"` when `False` and
`"classification"` when `True`.

Returns
-------
predictive: torch.Tensor or Tuple[torch.Tensor]
Expand Down Expand Up @@ -807,9 +834,8 @@ def __call__(
):
raise ValueError("Invalid random generator (check type and device).")

# For reward modeling, replace the likelihood to regression
if self.reward_modeling and self.likelihood == "classification":
self.likelihood = "regression"
if self.reward_modeling:
self.likelihood = "classification" if fitting else "regression"

if pred_type == "glm":
f_mu, f_var = self._glm_predictive_distribution(
Expand Down Expand Up @@ -1463,6 +1489,9 @@ def fit(self, train_loader, override=True):
# LowRankLA cannot be updated since eigenvalue representation not additive
raise ValueError("LowRank LA does not support updating.")

if self.reward_modeling:
self.likelihood = "classification"

self.model.eval()
self.mean = parameters_to_vector(self.model.parameters())

Expand Down
14 changes: 11 additions & 3 deletions laplace/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,11 @@ def validate(
X = X.to(laplace._device)
y = y.to(laplace._device)
out = laplace(
X, pred_type=pred_type, link_approx=link_approx, n_samples=n_samples
X,
pred_type=pred_type,
link_approx=link_approx,
n_samples=n_samples,
fitting=True,
)

if type(out) == tuple:
Expand All @@ -63,7 +67,10 @@ def validate(
output_vars.append(out[1])
targets.append(y)
else:
loss.update(*out, y)
try:
loss.update(*out, y)
except TypeError: # If the online loss only accepts 2 args
loss.update(out[0], y)
else:
if is_offline:
output_means.append(out)
Expand All @@ -80,7 +87,8 @@ def validate(
targets = torch.cat(targets, dim=0)
return loss(means, variances, targets).item()
else:
return loss.compute().item()
# Aggregate since torchmetrics output n_classes values for the MSE metric
return loss.compute().sum().item()


def parameters_per_layer(model):
Expand Down
24 changes: 24 additions & 0 deletions tests/test_baselaplace.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,3 +739,27 @@ def test_backprop_nn(laplace, model, reg_loader, backend):
assert grad_X_var.shape == X.shape
except ValueError:
assert False


@pytest.mark.parametrize(
"likelihood", ["classification", "regression", "reward_modeling"]
)
@pytest.mark.parametrize("prior_prec_type", ["scalar", "layerwise", "diag"])
def test_gridsearch(model, likelihood, prior_prec_type, reg_loader, class_loader):
if likelihood == "regression":
dataloader = reg_loader
else:
dataloader = class_loader

if prior_prec_type == "scalar":
prior_prec = 1.0
elif prior_prec_type == "layerwise":
prior_prec = torch.ones(model.n_layers)
else:
prior_prec = torch.ones(model.n_params)

lap = DiagLaplace(model, likelihood, prior_precision=prior_prec)
lap.fit(dataloader)

# Should not raise an error
lap.optimize_prior_precision(method="gridsearch", val_loader=dataloader, n_steps=10)