Add Curvlinops backend & add default functorch implementations of many curvature quantities #146

Merged 19 commits on Mar 9, 2024.
17 changes: 10 additions & 7 deletions README.md
@@ -12,7 +12,7 @@ There is also a corresponding paper, [*Laplace Redux — Effortless Bayesian Dee
```bibtex
@inproceedings{laplace2021,
title={Laplace Redux--Effortless {B}ayesian Deep Learning},
author={Erik Daxberger and Agustinus Kristiadi and Alexander Immer
author={Erik Daxberger and Agustinus Kristiadi and Alexander Immer
and Runa Eschenhagen and Matthias Bauer and Philipp Hennig},
booktitle={{N}eur{IPS}},
year={2021}
@@ -23,6 +23,7 @@ The [code](https://github.com/runame/laplace-redux) to reproduce the experiments
## Setup

We assume `python3.8` since the package was developed with that version.
PyTorch 2.0 or later is also required for full compatibility.
To install laplace with `pip`, run the following:
```bash
pip install laplace-torch
@@ -60,10 +61,12 @@ One can also implement custom subnetwork selection strategies as new subclasses

Alternatively, extending or integrating backends (subclasses of [`curvature.curvature`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/curvature.py)) allows providing different Hessian
approximations to the Laplace approximations.
For example, currently the [`curvature.BackPackInterface`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/backpack.py) based on [BackPACK](https://github.com/f-dangel/backpack/) and [`curvature.AsdlInterface`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/asdl.py) based on [ASDL](https://github.com/kazukiosawa/asdfghjkl) are available.
The `curvature.AsdlInterface` provides a Kronecker factored empirical Fisher while the `curvature.BackPackInterface`
does not, and only the `curvature.BackPackInterface` provides access to Hessian approximations
for a regression (MSELoss) loss function.
For example, currently the [`curvature.CurvlinopsInterface`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/curvlinops.py) based on [Curvlinops](https://github.com/f-dangel/curvlinops) and the native `torch.func` (previously known as `functorch`), [`curvature.BackPackInterface`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/backpack.py) based on [BackPACK](https://github.com/f-dangel/backpack/) and [`curvature.AsdlInterface`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/asdl.py) based on [ASDL](https://github.com/kazukiosawa/asdfghjkl) are available.

The `curvature.CurvlinopsInterface` backend is the default and provides all Hessian approximation variants except the low-rank Hessian.
For the latter, `curvature.AsdlInterface` can be used.
Note that `curvature.AsdlInterface` and `curvature.BackPackInterface` are less complete and less compatible than `curvature.CurvlinopsInterface`.
So we recommend sticking with `curvature.CurvlinopsInterface` unless you have a specific need for ASDL or BackPACK.
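
For illustration, a minimal sketch of choosing a backend explicitly (assuming the `Laplace` factory forwards the `backend` keyword to the underlying class, as `BaseLaplace` does; `load_map_model` is the placeholder used in the examples below):
```python
from laplace import Laplace
from laplace.curvature import CurvlinopsGGN, AsdlGGN

model = load_map_model()  # placeholder for a pre-trained model, as in the examples below

# CurvlinopsGGN is the default, so passing it explicitly is equivalent to omitting `backend`:
la = Laplace(model, 'classification',
             subset_of_weights='all', hessian_structure='kron',
             backend=CurvlinopsGGN)

# Only switch, e.g. to ASDL, when you need something Curvlinops does not provide:
la_asdl = Laplace(model, 'classification',
                  subset_of_weights='all', hessian_structure='kron',
                  backend=AsdlGGN)
```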

## Example usage

@@ -80,7 +83,7 @@ the `'probit'` predictive for classification.
from laplace import Laplace

# Pre-trained model
model = load_map_model()
model = load_map_model()

# User-specified LA flavor
la = Laplace(model, 'classification',
@@ -104,7 +107,7 @@ the log marginal likelihood.
from laplace import Laplace

# Un- or pre-trained model
model = load_model()
model = load_model()

# Default to recommended last-layer KFAC LA:
la = Laplace(model, likelihood='regression')
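The hunk above cuts the regression example short; a minimal sketch of the usual continuation, assuming the standard `fit`/`log_marginal_likelihood` workflow from the README (the hyperparameter initialisation is illustrative):
```python
import torch

la.fit(train_loader)  # train_loader: a standard PyTorch DataLoader over (x, y) pairs

# Tune prior precision and observation noise by maximising the marginal likelihood
log_prior = torch.ones(1, requires_grad=True)
log_sigma = torch.ones(1, requires_grad=True)
hyper_optimizer = torch.optim.Adam([log_prior, log_sigma], lr=1e-1)
for _ in range(100):
    hyper_optimizer.zero_grad()
    neg_marglik = -la.log_marginal_likelihood(log_prior.exp(), log_sigma.exp())
    neg_marglik.backward()
    hyper_optimizer.step()
```
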
50 changes: 24 additions & 26 deletions laplace/baselaplace.py
@@ -2,12 +2,12 @@
import numpy as np
import torch
from torch.nn.utils import parameters_to_vector, vector_to_parameters
from torch.distributions import MultivariateNormal, Dirichlet, Normal
from torch.distributions import MultivariateNormal

from laplace.utils import (parameters_per_layer, invsqrt_precision,
from laplace.utils import (parameters_per_layer, invsqrt_precision,
get_nll, validate, Kron, normal_samples,
fix_prior_prec_structure)
from laplace.curvature import AsdlGGN, BackPackGGN, AsdlHessian
from laplace.curvature import AsdlHessian, CurvlinopsGGN


__all__ = ['BaseLaplace', 'ParametricLaplace',
@@ -36,7 +36,7 @@ class BaseLaplace:
whether to enable backprop to the input `x` through the Laplace predictive.
Useful for e.g. Bayesian optimization.
backend : subclasses of `laplace.curvature.CurvatureInterface`
backend for access to curvature/Hessian approximations
backend for access to curvature/Hessian approximations. Defaults to CurvlinopsGGN if None.
backend_kwargs : dict, default=None
arguments passed to the backend on initialization, for example to
set the number of MC samples for stochastic approximations.
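
As a hedged example of `backend_kwargs` (BackPACK's `stochastic` flag is inferred from the `self.stochastic` check in the backpack.py hunk further down; treat the exact keyword as an assumption):
```python
from laplace import Laplace
from laplace.curvature import BackPackGGN

# backend_kwargs are forwarded to the backend constructor on first use; here
# `stochastic=True` would select MC-sampled (KFAC) rather than exact (KFLR) factors.
la = Laplace(model, 'classification',
             subset_of_weights='all', hessian_structure='kron',
             backend=BackPackGGN, backend_kwargs={'stochastic': True})
```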
@@ -61,10 +61,8 @@ def __init__(self, model, likelihood, sigma_noise=1., prior_precision=1.,
self.temperature = temperature
self.enable_backprop = enable_backprop

if backend is None:
backend = AsdlGGN if likelihood == 'classification' else BackPackGGN
self._backend = None
self._backend_cls = backend
self._backend_cls = backend if backend is not None else CurvlinopsGGN
self._backend_kwargs = dict() if backend_kwargs is None else backend_kwargs

# log likelihood = g(loss)
@@ -357,7 +355,7 @@ class ParametricLaplace(BaseLaplace):
"""

def __init__(self, model, likelihood, sigma_noise=1., prior_precision=1.,
prior_mean=0., temperature=1., enable_backprop=False,
prior_mean=0., temperature=1., enable_backprop=False,
backend=None, backend_kwargs=None):
super().__init__(model, likelihood, sigma_noise, prior_precision,
prior_mean, temperature, enable_backprop, backend, backend_kwargs)
@@ -527,7 +525,7 @@ def log_marginal_likelihood(self, prior_precision=None, sigma_noise=None):

return self.log_likelihood - 0.5 * (self.log_det_ratio + self.scatter)

def __call__(self, x, pred_type='glm', joint=False, link_approx='probit',
def __call__(self, x, pred_type='glm', joint=False, link_approx='probit',
n_samples=100, diagonal_output=False, generator=None):
"""Compute the posterior predictive on input data `x`.

@@ -543,13 +541,13 @@ def __call__(self, x, pred_type='glm', joint=False, link_approx='probit',

link_approx : {'mc', 'probit', 'bridge', 'bridge_norm'}
how to approximate the classification link function for the `'glm'`.
For `pred_type='nn'`, only 'mc' is possible.
For `pred_type='nn'`, only 'mc' is possible.

joint : bool
Whether to output a joint predictive distribution in regression with
`pred_type='glm'`. If set to `True`, the predictive distribution
has the same form as GP posterior, i.e. N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]).
If `False`, then only outputs the marginal predictive distribution.
If `False`, then only outputs the marginal predictive distribution.
Only available for regression and GLM predictive.

n_samples : int
@@ -569,8 +567,8 @@ def __call__(self, x, pred_type='glm', joint=False, link_approx='probit',
a distribution over classes (similar to a Softmax).
For `likelihood='regression'`, a tuple of torch.Tensor is returned
with the mean and the predictive variance.
For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor
is returned with the mean and the predictive covariance.
For `likelihood='regression'` and `joint=True`, a tuple of torch.Tensor
is returned with the mean and the predictive covariance.
"""
if pred_type not in ['glm', 'nn']:
raise ValueError('Only glm and nn supported as prediction types.')
@@ -580,7 +578,7 @@ def __call__(self, x, pred_type='glm', joint=False, link_approx='probit',

if pred_type == 'nn' and link_approx != 'mc':
raise ValueError('Only mc link approximation is supported for nn prediction type.')

if generator is not None:
if not isinstance(generator, torch.Generator) or generator.device != x.device:
raise ValueError('Invalid random generator (check type and device).')
@@ -594,7 +592,7 @@ def __call__(self, x, pred_type='glm', joint=False, link_approx='probit',
return f_mu, f_var
# classification
if link_approx == 'mc':
return self.predictive_samples(x, pred_type='glm', n_samples=n_samples,
return self.predictive_samples(x, pred_type='glm', n_samples=n_samples,
diagonal_output=diagonal_output).mean(dim=0)
elif link_approx == 'probit':
kappa = 1 / torch.sqrt(1. + np.pi / 8 * f_var.diagonal(dim1=1, dim2=2))
@@ -623,7 +621,7 @@ def __call__(self, x, pred_type='glm', joint=False, link_approx='probit',
return samples.mean(dim=0), samples.var(dim=0)
return samples.mean(dim=0)

def predictive_samples(self, x, pred_type='glm', n_samples=100,
def predictive_samples(self, x, pred_type='glm', n_samples=100,
diagonal_output=False, generator=None):
"""Sample from the posterior predictive on input data `x`.
Can be used, for example, for Thompson sampling.
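
A minimal usage sketch for the predictive API documented above (a fitted `la` and a test batch `x` are assumed):
```python
import torch

# GLM predictive with the probit link approximation (the default for classification)
probs = la(x, pred_type='glm', link_approx='probit')

# Monte Carlo link approximation with a reproducible generator on the same device as x
g = torch.Generator(device=x.device).manual_seed(0)
probs_mc = la(x, pred_type='glm', link_approx='mc', n_samples=100, generator=g)

# Posterior-predictive samples, e.g. for Thompson sampling
samples = la.predictive_samples(x, pred_type='glm', n_samples=100)
```
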
@@ -720,7 +718,7 @@ def functional_covariance(self, Jacs):
`f_cov = Jacs @ P.inv() @ Jacs.T`, which is a batch*output x batch*output
predictive covariance matrix.

This emulates the GP posterior covariance N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]).
This emulates the GP posterior covariance N([f(x1), ...,f(xm)], Cov[f(x1), ..., f(xm)]).
Useful for joint predictions, such as in batched Bayesian optimization.

Parameters
@@ -875,7 +873,7 @@ class KronLaplace(ParametricLaplace):
_key = ('all', 'kron')

def __init__(self, model, likelihood, sigma_noise=1., prior_precision=1.,
prior_mean=0., temperature=1., enable_backprop=False, backend=None,
prior_mean=0., temperature=1., enable_backprop=False, backend=None,
damping=False, **backend_kwargs):
self.damping = damping
self.H_facs = None
@@ -965,26 +963,26 @@ def prior_precision(self, prior_precision):


class LowRankLaplace(ParametricLaplace):
"""Laplace approximation with low-rank log likelihood Hessian (approximation).
"""Laplace approximation with low-rank log likelihood Hessian (approximation).
The low-rank matrix is represented by an eigendecomposition (vecs, values).
Based on the chosen `backend`, either a true Hessian or, for example, GGN
approximation could be used.
The posterior precision is computed as
\\( P = V diag(l) V^T + P_0.\\)
To sample, compute the functional variance, and log determinant, algebraic tricks
To sample, compute the functional variance, and log determinant, algebraic tricks
are used to reduce the cost of inversion to that of a \\(K \times K\\) matrix
if we have a rank of K.

See `BaseLaplace` for the full interface.
"""
_key = ('all', 'lowrank')
def __init__(self, model, likelihood, sigma_noise=1, prior_precision=1, prior_mean=0,
def __init__(self, model, likelihood, sigma_noise=1, prior_precision=1, prior_mean=0,
temperature=1, enable_backprop=False, backend=AsdlHessian, backend_kwargs=None):
super().__init__(model, likelihood, sigma_noise=sigma_noise,
prior_precision=prior_precision, prior_mean=prior_mean,
temperature=temperature, enable_backprop=enable_backprop,
super().__init__(model, likelihood, sigma_noise=sigma_noise,
prior_precision=prior_precision, prior_mean=prior_mean,
temperature=temperature, enable_backprop=enable_backprop,
backend=backend, backend_kwargs=backend_kwargs)

def _init_H(self):
self.H = None

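For `LowRankLaplace` above, the claimed reduction to a \\(K \times K\\) inversion is just the Woodbury identity applied to the docstring's posterior precision (a sketch in the docstring's notation):
```latex
P^{-1} = (V \,\mathrm{diag}(l)\, V^\top + P_0)^{-1}
       = P_0^{-1} - P_0^{-1} V \left(\mathrm{diag}(l)^{-1} + V^\top P_0^{-1} V\right)^{-1} V^\top P_0^{-1}
```
Only the inner \\(K \times K\\) matrix has to be inverted; the prior precision \\(P_0\\) is typically diagonal, so its inverse is cheap.
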
8 changes: 7 additions & 1 deletion laplace/curvature/__init__.py
@@ -12,6 +12,12 @@
except ModuleNotFoundError:
logging.info('asdfghjkl backend not available.')

try:
from laplace.curvature.curvlinops import CurvlinopsHessian, CurvlinopsGGN, CurvlinopsEF, CurvlinopsInterface
except ModuleNotFoundError:
logging.info('curvlinops backend not available.')

__all__ = ['CurvatureInterface', 'GGNInterface', 'EFInterface',
'BackPackInterface', 'BackPackGGN', 'BackPackEF',
'AsdlInterface', 'AsdlGGN', 'AsdlEF', 'AsdlHessian']
'AsdlInterface', 'AsdlGGN', 'AsdlEF', 'AsdlHessian',
'CurvlinopsInterface', 'CurvlinopsGGN', 'CurvlinopsEF', 'CurvlinopsHessian']
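
As a consequence of the optional import above, downstream code that wants to be robust to a missing `curvlinops` install might guard its backend choice (a sketch; the fallback choice is arbitrary):
```python
try:
    from laplace.curvature import CurvlinopsGGN as Backend
except ImportError:
    # curvlinops not installed: fall back to the ASDL-based GGN backend
    from laplace.curvature import AsdlGGN as Backend
```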
4 changes: 2 additions & 2 deletions laplace/curvature/asdl.py
@@ -91,7 +91,7 @@ def _get_kron_factors(self, curv, M):
kfacs.append([stats.kron.B, stats.kron.A[:-1, :-1]])
kfacs.append([stats.kron.B * stats.kron.A[-1, -1] / M])
elif hasattr(module, 'weight'):
p, q = np.prod(stats.kron.B.shape), np.prod(stats.kron.A.shape)
p, q = stats.kron.B.numel(), stats.kron.A.numel()
if p == q == 1:
kfacs.append([stats.kron.B * stats.kron.A])
else:
@@ -121,7 +121,7 @@ def diag(self, X, y, **kwargs):
diag_ggn = diag_ggn[self.subnetwork_indices]
return self.factor * loss, self.factor * diag_ggn

def kron(self, X, y, N, **wkwargs):
def kron(self, X, y, N, **kwargs):
with torch.no_grad():
if self.last_layer:
f, X = self.model.forward_with_features(X)
15 changes: 9 additions & 6 deletions laplace/curvature/backpack.py
@@ -1,3 +1,4 @@
from typing import Tuple
import torch

from backpack import backpack, extend, memory_cleanup
@@ -18,7 +19,8 @@ def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None)

def jacobians(self, x, enable_backprop=False):
"""Compute Jacobians \\(\\nabla_{\\theta} f(x;\\theta)\\) at current parameter \\(\\theta\\)
using backpack's BatchGrad per output dimension.
using backpack's BatchGrad per output dimension. Note that BackPACK doesn't play well
with torch.func, so this method has to be overridden.

Parameters
----------
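
The BackPACK override exists because the PR's default Jacobians now go through `torch.func`; a minimal sketch of what such a functorch-style per-sample Jacobian looks like (illustrative, not the PR's exact implementation):
```python
import torch
from torch.func import functional_call, jacrev, vmap

def batched_jacobians(model, x):
    """Per-sample Jacobians d f(x; theta) / d theta, shape (batch, outputs, n_params)."""
    params = {k: v.detach() for k, v in model.named_parameters()}

    def f(p, x_single):
        # functional_call evaluates the model with the given parameter dict
        return functional_call(model, p, (x_single.unsqueeze(0),)).squeeze(0)

    # vmap over the batch, jacrev w.r.t. the parameter dict:
    jac_dict = vmap(jacrev(f), in_dims=(None, 0))(params, x)  # name -> (batch, out, *p.shape)
    return torch.cat([j.flatten(start_dim=2) for j in jac_dict.values()], dim=2)
```
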
@@ -42,12 +44,12 @@ def jacobians(self, x, enable_backprop=False):
with backpack(BatchGrad()):
if model.output_size > 1:
out[:, i].sum().backward(
create_graph=enable_backprop,
create_graph=enable_backprop,
retain_graph=enable_backprop
)
else:
out.sum().backward(
create_graph=enable_backprop,
create_graph=enable_backprop,
retain_graph=enable_backprop
)
to_cat = []
@@ -71,7 +73,8 @@ def jacobians(self, x, enable_backprop=False):

def gradients(self, x, y):
"""Compute gradients \\(\\nabla_\\theta \\ell(f(x;\\theta, y)\\) at current parameter
\\(\\theta\\) using Backpack's BatchGrad.
\\(\\theta\\) using Backpack's BatchGrad. Note that BackPACK doesn't play well
with torch.func, so this method has to be overridden.

Parameters
----------
@@ -81,9 +84,9 @@ def gradients(self, x, y):

Returns
-------
loss : torch.Tensor
Gs : torch.Tensor
gradients `(batch, parameters)`
loss : torch.Tensor
"""
f = self.model(x)
loss = self.lossfunc(f, y)
@@ -136,7 +139,7 @@ def diag(self, X, y, **kwargs):

return self.factor * loss.detach(), self.factor * dggn

def kron(self, X, y, N, **kwargs) -> [torch.Tensor, Kron]:
def kron(self, X, y, N, **kwargs) -> Tuple[torch.Tensor, Kron]:
context = KFAC if self.stochastic else KFLR
f = self.model(X)
loss = self.lossfunc(f, y)