Add Curvlinops backend & add default functorch implementations of many curvature quantities #146

Merged · 19 commits · Mar 9, 2024
Changes from 9 commits
8 changes: 7 additions & 1 deletion laplace/curvature/__init__.py
@@ -12,6 +12,12 @@
except ModuleNotFoundError:
logging.info('asdfghjkl backend not available.')

try:
from laplace.curvature.curvlinops import CurvlinopsHessian, CurvlinopsGGN, CurvlinopsEF, CurvlinopsInterface
except ModuleNotFoundError:
logging.info('curvlinops backend not available.')

__all__ = ['CurvatureInterface', 'GGNInterface', 'EFInterface',
'BackPackInterface', 'BackPackGGN', 'BackPackEF',
'AsdlInterface', 'AsdlGGN', 'AsdlEF', 'AsdlHessian']
'AsdlInterface', 'AsdlGGN', 'AsdlEF', 'AsdlHessian',
'CurvlinopsInterface', 'CurvlinopsGGN', 'CurvlinopsEF', 'CurvlinopsHessian']
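For orientation, here is a minimal sketch of how the new backend would be selected from user code, assuming the usual `Laplace` factory and its `backend` keyword (the same mechanism used for the existing BackPack and ASDL backends); the toy model and hyperparameters are placeholders, not part of this PR:

import torch
from laplace import Laplace
from laplace.curvature import CurvlinopsGGN

model = torch.nn.Sequential(torch.nn.Linear(10, 20), torch.nn.ReLU(), torch.nn.Linear(20, 2))
la = Laplace(model, 'classification',
             subset_of_weights='all', hessian_structure='kron',
             backend=CurvlinopsGGN)
# la.fit(train_loader)  # curvature is then computed through curvlinops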
2 changes: 1 addition & 1 deletion laplace/curvature/asdl.py
@@ -121,7 +121,7 @@ def diag(self, X, y, **kwargs):
diag_ggn = diag_ggn[self.subnetwork_indices]
return self.factor * loss, self.factor * diag_ggn

def kron(self, X, y, N, **wkwargs):
def kron(self, X, y, N, **kwargs):
with torch.no_grad():
if self.last_layer:
f, X = self.model.forward_with_features(X)
12 changes: 7 additions & 5 deletions laplace/curvature/backpack.py
@@ -18,7 +18,8 @@ def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None)

def jacobians(self, x, enable_backprop=False):
"""Compute Jacobians \\(\\nabla_{\\theta} f(x;\\theta)\\) at current parameter \\(\\theta\\)
using backpack's BatchGrad per output dimension.
using backpack's BatchGrad per output dimension. Note that BackPACK doesn't play well
with torch.func, so this method has to be overridden.

Parameters
----------
@@ -42,12 +43,12 @@ def jacobians(self, x, enable_backprop=False):
with backpack(BatchGrad()):
if model.output_size > 1:
out[:, i].sum().backward(
create_graph=enable_backprop,
create_graph=enable_backprop,
retain_graph=enable_backprop
)
else:
out.sum().backward(
create_graph=enable_backprop,
create_graph=enable_backprop,
retain_graph=enable_backprop
)
to_cat = []
@@ -71,7 +72,8 @@ def jacobians(self, x, enable_backprop=False):

def gradients(self, x, y):
"""Compute gradients \\(\\nabla_\\theta \\ell(f(x;\\theta, y)\\) at current parameter
\\(\\theta\\) using Backpack's BatchGrad.
\\(\\theta\\) using Backpack's BatchGrad. Note that BackPACK doesn't play well
with torch.func, so this method has to be overridden.

Parameters
----------
@@ -81,9 +83,9 @@ def gradients(self, x, y):

Returns
-------
loss : torch.Tensor
Gs : torch.Tensor
gradients `(batch, parameters)`
loss : torch.Tensor
"""
f = self.model(x)
loss = self.lossfunc(f, y)
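As background for the new docstring notes: BackPACK collects per-sample quantities through hooks on extend()-ed modules during an ordinary backward pass, which does not compose with the functional transforms in torch.func, hence the explicit overrides here. The following standalone sketch of the BatchGrad pattern these overrides rely on uses a toy model and random data; it assumes BackPACK's documented extend/backpack/BatchGrad API:

import torch
from backpack import backpack, extend
from backpack.extensions import BatchGrad

model = extend(torch.nn.Linear(10, 2))
lossfunc = extend(torch.nn.CrossEntropyLoss(reduction='sum'))

X, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
loss = lossfunc(model(X), y)
with backpack(BatchGrad()):
    loss.backward()

# each parameter now carries per-sample gradients in .grad_batch, shape (batch, *param.shape)
Gs = torch.cat([p.grad_batch.flatten(start_dim=1) for p in model.parameters()], dim=1)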
149 changes: 110 additions & 39 deletions laplace/curvature/curvature.py
@@ -39,13 +39,18 @@ def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None)
else:
self.lossfunc = CrossEntropyLoss(reduction='sum')
self.factor = 1.
self.params = [p for p in self.model.parameters() if p.requires_grad]
name_dict = {p.data_ptr(): name for name, p in self.model.named_parameters()}
self.params_dict = {name_dict[p.data_ptr()]: p for p in self.params}

@property
def _model(self):
return self.model.last_layer if self.last_layer else self.model

def jacobians(self, x, enable_backprop=False):
"""Compute Jacobians \\(\\nabla_\\theta f(x;\\theta)\\) at current parameter \\(\\theta\\).
"""
Compute Jacobians \\(\\nabla_{\\theta} f(x;\\theta)\\) at current parameter \\(\\theta\\),
via torch.func.

Parameters
----------
@@ -61,10 +66,26 @@ def jacobians(self, x, enable_backprop=False):
f : torch.Tensor
output function `(batch, outputs)`
"""
raise NotImplementedError
def model_fn_params_only(params_dict):
out = torch.func.functional_call(self.model, params_dict, x)
return out, out

Js, f = torch.func.jacrev(model_fn_params_only, has_aux=True)(self.params_dict)

# Concatenate over flattened parameters
Js = [
j.flatten(start_dim=-p.dim())
for j, p in zip(Js.values(), self.params_dict.values())
]
Js = torch.cat(Js, dim=-1)

if self.subnetwork_indices is not None:
Js = Js[:, :, self.subnetwork_indices]

return (Js, f) if enable_backprop else (Js.detach(), f.detach())

def last_layer_jacobians(self, x, enable_backprop=False):
"""Compute Jacobians \\(\\nabla_{\\theta_\\textrm{last}} f(x;\\theta_\\textrm{last})\\)
"""Compute Jacobians \\(\\nabla_{\\theta_\\textrm{last}} f(x;\\theta_\\textrm{last})\\)
only at current last-layer parameter \\(\\theta_{\\textrm{last}}\\).

Parameters
@@ -93,7 +114,8 @@ def last_layer_jacobians(self, x, enable_backprop=False):
return Js, f

def gradients(self, x, y):
"""Compute gradients \\(\\nabla_\\theta \\ell(f(x;\\theta, y)\\) at current parameter \\(\\theta\\).
"""Compute batch gradients \\(\\nabla_\\theta \\ell(f(x;\\theta, y)\\) at
current parameter \\(\\theta\\).

Parameters
----------
@@ -103,11 +125,33 @@ def gradients(self, x, y):

Returns
-------
loss : torch.Tensor
Gs : torch.Tensor
gradients `(batch, parameters)`
loss : torch.Tensor
"""
raise NotImplementedError
def loss_n(x_n, y_n, params_dict):
"""Compute the gradient for a single sample."""
output = torch.func.functional_call(self.model, params_dict, x_n)
loss = torch.func.functional_call(self.lossfunc, {}, (output, y_n))
return loss, loss

batch_grad_fn = torch.func.vmap(torch.func.grad(loss_n, argnums=2, has_aux=True))

batch_size = x.shape[0]
params_replicated_dict = {
name: p.unsqueeze(0).expand(batch_size, *(p.dim() * [-1]))
for name, p in self.params_dict.items()
}

batch_grad, batch_loss = batch_grad_fn(x, y, params_replicated_dict)
Gs = torch.cat([bg.flatten(start_dim=1) for bg in batch_grad.values()], dim=1)

if self.subnetwork_indices is not None:
Gs = Gs[:, self.subnetwork_indices]

loss = batch_loss.sum(0)

return Gs, loss

def full(self, x, y, **kwargs):
"""Compute a dense curvature (approximation) in the form of a \\(P \\times P\\) matrix
@@ -131,7 +175,7 @@ def full(self, x, y, **kwargs):
def kron(self, x, y, **kwargs):
"""Compute a Kronecker factored curvature approximation (such as KFAC).
The approximation to \\(H\\) takes the form of two Kronecker factors \\(Q, H\\),
i.e., \\(H \\approx Q \\otimes H\\) for each Module in the neural network permitting
i.e., \\(H \\approx Q \\otimes H\\) for each Module in the neural network permitting
such curvature.
\\(Q\\) is quadratic in the input-dimension of a module \\(p_{in} \\times p_{in}\\)
and \\(H\\) in the output-dimension \\(p_{out} \\times p_{out}\\).
Expand All @@ -152,7 +196,7 @@ def kron(self, x, y, **kwargs):
raise NotImplementedError

def diag(self, x, y, **kwargs):
"""Compute a diagonal Hessian approximation to \\(H\\) and is represented as a
"""Compute a diagonal Hessian approximation to \\(H\\) and is represented as a
vector of the dimensionality of parameters \\(\\theta\\).

Parameters
@@ -188,38 +232,39 @@ class GGNInterface(CurvatureInterface):
to apply the Laplace approximation over
stochastic : bool, default=False
Fisher if stochastic else GGN
num_samples : int, default=1
Number of samples used to approximate the stochastic Fisher
"""
def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None, stochastic=False):
def __init__(self, model, likelihood, last_layer=False, subnetwork_indices=None, stochastic=False, num_samples=1):
self.stochastic = stochastic
self.num_samples = num_samples
super().__init__(model, likelihood, last_layer, subnetwork_indices)

def _get_full_ggn(self, Js, f, y):
"""Compute full GGN from Jacobians.
def _sample_H_lik(self, f):
H_lik = 0

Parameters
----------
Js : torch.Tensor
Jacobians `(batch, parameters, outputs)`
f : torch.Tensor
functions `(batch, outputs)`
y : torch.Tensor
labels compatible with loss
for _ in range(self.num_samples):
if self.likelihood == 'regression':
y_sample = f + torch.randn(f.shape, device=f.device) # N(y | f, 1)
grad_sample = f - y_sample # functional MSE grad
else: # classification with softmax
y_sample = torch.distributions.Multinomial(logits=f).sample()
# First functional derivative of the loglik is p - y
p = torch.softmax(f, dim=-1)
grad_sample = p - y_sample

Returns
-------
loss : torch.Tensor
H_ggn : torch.Tensor
full GGN approximation `(parameters, parameters)`
"""
loss = self.factor * self.lossfunc(f, y)
H_lik += 1/self.num_samples * torch.einsum('bc,bk->bck', grad_sample, grad_sample)

return H_lik

def _get_exact_H_lik(self, f):
if self.likelihood == 'regression':
H_ggn = torch.einsum('mkp,mkq->pq', Js, Js)
return None
else:
# second derivative of log lik is diag(p) - pp^T
ps = torch.softmax(f, dim=-1)
H_lik = torch.diag_embed(ps) - torch.einsum('mk,mc->mck', ps, ps)
H_ggn = torch.einsum('mcp,mck,mkq->pq', Js, H_lik, Js)
return loss.detach(), H_ggn
return H_lik

def full(self, x, y, **kwargs):
"""Compute the full GGN \\(P \\times P\\) matrix as Hessian approximation
@@ -236,19 +281,35 @@ def full(self, x, y, **kwargs):
Returns
-------
loss : torch.Tensor
H_ggn : torch.Tensor
H : torch.Tensor
GGN `(parameters, parameters)`
"""
if self.stochastic:
raise ValueError('Stochastic approximation not implemented for full GGN.')
Js, f = self.last_layer_jacobians(x) if self.last_layer else self.jacobians(x)
H_lik = self._sample_H_lik(f) if self.stochastic else self._get_exact_H_lik(f)

if self.last_layer:
Js, f = self.last_layer_jacobians(x)
else:
Js, f = self.jacobians(x)
loss, H_ggn = self._get_full_ggn(Js, f, y)
if H_lik is not None:
H = torch.einsum('bcp,bck,bkq->pq', Js, H_lik, Js)
else: # The case of exact GGN for regression
H = torch.einsum('bcp,bcq->pq', Js, Js)
loss = self.factor * self.lossfunc(f, y)

return loss.detach(), H.detach()

def diag(self, X, y, **kwargs):
Js, f = self.last_layer_jacobians(X) if self.last_layer else self.jacobians(X)
loss = self.factor * self.lossfunc(f, y)

H_lik = self._sample_H_lik(f) if self.stochastic else self._get_exact_H_lik(f)

return loss, H_ggn
if H_lik is not None:
H = torch.einsum('bcp,bck,bkp->p', Js, H_lik, Js)
else: # The case of exact GGN for regression
H = torch.einsum('bcp,bcp->p', Js, Js)

if self.subnetwork_indices is not None:
H = H[self.subnetwork_indices]

return loss.detach(), H.detach()
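To make the relationship between the new full() and diag() contractions concrete, here is a small self-contained check with toy shapes (not part of the PR): given the same per-sample likelihood Hessian, the diagonal einsum used in diag() is exactly the diagonal of the matrix produced by full():

import torch

B, C, P = 4, 3, 5                      # batch, classes, parameters (toy sizes)
Js = torch.randn(B, C, P)              # Jacobians, as returned by jacobians()
f = torch.randn(B, C)                  # logits

# exact likelihood Hessian for softmax: diag(p) - p p^T, per sample
ps = torch.softmax(f, dim=-1)
H_lik = torch.diag_embed(ps) - torch.einsum('mk,mc->mck', ps, ps)

H_full = torch.einsum('bcp,bck,bkq->pq', Js, H_lik, Js)   # full GGN
H_diag = torch.einsum('bcp,bck,bkp->p', Js, H_lik, Js)    # diagonal GGN
assert torch.allclose(H_diag, H_full.diagonal(), atol=1e-6)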


class EFInterface(CurvatureInterface):
Expand Down Expand Up @@ -293,5 +354,15 @@ def full(self, x, y, **kwargs):
EF `(parameters, parameters)`
"""
Gs, loss = self.gradients(x, y)
H_ef = Gs.T @ Gs
H_ef = torch.einsum('bp,bq->pq', Gs, Gs)
return self.factor * loss.detach(), self.factor * H_ef

def diag(self, X, y, **kwargs):
# Gs is (batchsize, n_params)
Gs, loss = self.gradients(X, y)
diag_ef = torch.einsum('bp,bp->p', Gs, Gs)

if self.subnetwork_indices is not None:
diag_ef = diag_ef[self.subnetwork_indices]

return self.factor * loss.detach(), self.factor * diag_ef
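Finally, for readers new to the torch.func machinery these default implementations build on, here is a self-contained sketch of per-sample gradients and the resulting empirical Fisher diagonal. Unlike the implementation above, which replicates the parameter dict across the batch and vmaps over it, this sketch keeps the parameters unbatched via in_dims=(None, 0, 0); both are standard vmap(grad(...)) patterns, and the model and data here are toy placeholders:

import torch
from torch.func import functional_call, grad, vmap

model = torch.nn.Linear(10, 2)
lossfunc = torch.nn.CrossEntropyLoss(reduction='sum')
params = dict(model.named_parameters())

def loss_n(params, x_n, y_n):
    # per-sample loss, written functionally so vmap/grad can transform it
    out = functional_call(model, params, x_n.unsqueeze(0))
    return lossfunc(out, y_n.unsqueeze(0))

X, y = torch.randn(8, 10), torch.randint(0, 2, (8,))
per_sample_grads = vmap(grad(loss_n), in_dims=(None, 0, 0))(params, X, y)

Gs = torch.cat([g.flatten(start_dim=1) for g in per_sample_grads.values()], dim=1)
diag_ef = torch.einsum('bp,bp->p', Gs, Gs)  # empirical Fisher diagonal, shape (parameters,)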