
Commit

Merge pull request #2 from AlexImmer/llla-dev
Add last-layer Laplace flavors
aleximmer authored Apr 21, 2021
2 parents 2ab0079 + f7a8463 commit fd0b869
Showing 14 changed files with 834 additions and 82 deletions.
74 changes: 46 additions & 28 deletions laplace/curvature.py
@@ -6,12 +6,23 @@
from backpack import backpack, extend
from backpack.extensions import DiagGGNExact, DiagGGNMC, KFAC, KFLR, SumGradSquared, BatchGrad

from laplace.jacobians import Jacobians
from laplace.jacobians import jacobians, last_layer_jacobians
from laplace.matrix import Kron


class CurvatureInterface(ABC):

def __init__(self, model, likelihood):
assert likelihood in ['regression', 'classification']
self.likelihood = likelihood
self.model = model
if likelihood == 'regression':
self.lossfunc = MSELoss(reduction='sum')
self.factor = 0.5 # convert to standard Gauss. log N(y|f,1)
else:
self.lossfunc = CrossEntropyLoss(reduction='sum')
self.factor = 1.

@abstractmethod
def full(self, X, y, **kwargs):
pass
@@ -24,19 +35,29 @@ def kron(self, X, y, **kwargs):
def diag(self, X, y, **kwargs):
pass

def _get_full_ggn(self, Js, f, y):
loss = self.factor * self.lossfunc(f, y)
if self.likelihood == 'regression':
H_ggn = torch.einsum('mkp,mkq->pq', Js, Js)
else:
# second derivative of log lik is diag(p) - pp^T
ps = torch.softmax(f, dim=-1)
H_lik = torch.diag_embed(ps) - torch.einsum('mk,mc->mck', ps, ps)
H_ggn = torch.einsum('mcp,mck,mkq->pq', Js, H_lik, Js)
return loss.detach(), H_ggn


class BackPackInterface(CurvatureInterface):

def __init__(self, model, likelihood):
assert likelihood in ['regression', 'classification']
self.likelihood = likelihood
self.model = extend(model)
if likelihood == 'regression':
self.lossfunc = extend(MSELoss(reduction='sum'))
self.factor = 0.5 # convert to standard Gauss. log N(y|f,1)
else:
self.lossfunc = extend(CrossEntropyLoss(reduction='sum'))
self.factor = 1.
def __init__(self, model, likelihood, last_layer=False):
super().__init__(model, likelihood)
self.last_layer = last_layer
extend(self._model)
extend(self.lossfunc)

@property
def _model(self):
return self.model.last_layer if self.last_layer else self.model


class BackPackGGN(BackPackInterface):
@@ -47,21 +68,21 @@ class BackPackGGN(BackPackInterface):
--> factor for regression is 0.5 for loss and ggn
"""

def __init__(self, model, likelihood, stochastic=False):
super().__init__(model, likelihood)
def __init__(self, model, likelihood, last_layer=False, stochastic=False):
super().__init__(model, likelihood, last_layer)
self.stochastic = stochastic

def _get_diag_ggn(self):
if self.stochastic:
return torch.cat([p.diag_ggn_mc.data.flatten() for p in self.model.parameters()])
return torch.cat([p.diag_ggn_mc.data.flatten() for p in self._model.parameters()])
else:
return torch.cat([p.diag_ggn_exact.data.flatten() for p in self.model.parameters()])
return torch.cat([p.diag_ggn_exact.data.flatten() for p in self._model.parameters()])

def _get_kron_factors(self):
if self.stochastic:
return Kron([p.kfac for p in self.model.parameters()])
return Kron([p.kfac for p in self._model.parameters()])
else:
return Kron([p.kflr for p in self.model.parameters()])
return Kron([p.kflr for p in self._model.parameters()])

@staticmethod
def _rescale_kron_factors(kron, M, N):
@@ -97,31 +118,28 @@ def full(self, X, y, **kwargs):
if self.stochastic:
raise ValueError('Stochastic approximation not implemented for full GGN.')

Js, f = Jacobians(self.model, X)
loss = self.factor * self.lossfunc(f, y)
if self.likelihood == 'regression':
H_ggn = torch.einsum('mkp,mkq->pq', Js, Js)
if self.last_layer:
Js, f = last_layer_jacobians(self.model, X)
else:
# second derivative of log lik is diag(p) - pp^T
ps = torch.softmax(f, dim=-1)
H_lik = torch.diag_embed(ps) - torch.einsum('mk,mc->mck', ps, ps)
H_ggn = torch.einsum('mcp,mck,mkq->pq', Js, H_lik, Js)
return loss.detach(), H_ggn
Js, f = jacobians(self.model, X)
loss, H_ggn = self._get_full_ggn(Js, f, y)

return loss, H_ggn


class BackPackEF(BackPackInterface):

def _get_individual_gradients(self):
return torch.cat([p.grad_batch.data.flatten(start_dim=1)
for p in self.model.parameters()], dim=1)
for p in self._model.parameters()], dim=1)

def diag(self, X, y, **kwargs):
f = self.model(X)
loss = self.lossfunc(f, y)
with backpack(SumGradSquared()):
loss.backward()
diag_EF = torch.cat([p.sum_grad_squared.data.flatten()
for p in self.model.parameters()])
for p in self._model.parameters()])

return self.factor * loss.detach(), self.factor ** 2 * diag_EF

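A minimal sketch of how the extended backend might be driven with the new last_layer flag. The toy classifier, the random batch, and the wrapping in FeatureExtractor (which the backend relies on for model.last_layer and forward_with_features) are illustrative assumptions, not part of this commit:

import torch
import torch.nn as nn

from laplace.curvature import BackPackGGN
from laplace.feature_extractor import FeatureExtractor

# illustrative toy classifier and batch; '2' names the final nn.Linear in the Sequential
model = FeatureExtractor(nn.Sequential(nn.Linear(20, 50), nn.ReLU(), nn.Linear(50, 3)),
                         last_layer_name='2')
X, y = torch.randn(8, 20), torch.randint(0, 3, (8,))

# GGN restricted to the last layer: Jacobians come from last_layer_jacobians,
# the likelihood Hessian from the softmax probabilities (diag(p) - pp^T)
backend = BackPackGGN(model, 'classification', last_layer=True)
loss, H_ggn = backend.full(X, y)
print(H_ggn.shape)  # torch.Size([153, 153]): 3*50 weights + 3 biases of the last layer
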
123 changes: 123 additions & 0 deletions laplace/feature_extractor.py
@@ -0,0 +1,123 @@
import torch
import torch.nn as nn
from typing import Tuple, Callable, Optional


class FeatureExtractor(nn.Module):
"""Feature extractor for a PyTorch neural network.
A wrapper which returns the output of the penultimate layer in addition to
the output of the last layer for each forward pass. It assumes that the
last layer is linear.
Based on https://gist.github.com/fkodom/27ed045c9051a39102e8bcf4ce31df76.
Arguments
----------
model : torch.nn.Module
PyTorch model
last_layer_name (optional) : str, default=None
Name of the last layer, if already known by the user; otherwise it will
be determined automatically on the first forward pass.
Attributes
----------
model : torch.nn.Module
The underlying PyTorch model.
last_layer : torch.nn.Module
The torch module corresponding to the last layer (has to be instance
of torch.nn.Linear).
Examples
--------
Notes
-----
Limitations:
- Assumes that the last layer is always the same for any forward pass
- Assumes that the last layer is an instance of torch.nn.Linear
"""
def __init__(self, model: nn.Module, last_layer_name: Optional[str] = None) -> None:
super().__init__()
self.model = model
self._features = dict()
if last_layer_name is None:
self._found = False
else:
self.set_last_layer(last_layer_name)

def forward(self, x: torch.Tensor) -> torch.Tensor:
if self._found:
# if last and penultimate layers are already known
out = self.model(x)
else:
# if this is the first forward pass
out = self.find_last_layer(x)
return out

def forward_with_features(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
out = self.forward(x)
features = self._features[self._last_layer_name]
return out, features

def set_last_layer(self, last_layer_name: str) -> None:
# set last_layer attributes and check if it is linear
self._last_layer_name = last_layer_name
self.last_layer = dict(self.model.named_modules())[last_layer_name]
if not isinstance(self.last_layer, nn.Linear):
raise ValueError('Use model with a linear last layer.')

# set forward hook to extract features in future forward passes
self.last_layer.register_forward_hook(self._get_hook(last_layer_name))

# last layer is now identified and hook is set
self._found = True

def _get_hook(self, name: str) -> Callable:
def hook(_, input, __):
# only accepts one input (expects linear layer)
self._features[name] = input[0].detach()
return hook

def find_last_layer(self, x: torch.Tensor) -> torch.Tensor:
if self._found:
raise ValueError('Last layer is already known.')

act_out = dict()
def get_act_hook(name):
def act_hook(_, input, __):
# only accepts one input (expects linear layer)
if isinstance(input[0], torch.Tensor):
act_out[name] = input[0].detach()
else:
act_out[name] = None
# remove hook
handles[name].remove()
return act_hook

# set hooks for all modules
handles = dict()
for name, module in self.model.named_modules():
handles[name] = module.register_forward_hook(get_act_hook(name))

# check if model has more than one module
# (there might be pathological exceptions)
if len(handles) <= 2:
raise ValueError('The model only has one module.')

# forward pass to find execution order
out = self.model(x)

# find the last layer, store features, return output of forward pass
keys = list(act_out.keys())
for key in reversed(keys):
layer = dict(self.model.named_modules())[key]
if len(list(layer.children())) == 0:
self.set_last_layer(key)

# save features from first forward pass
self._features[key] = act_out[key]

return out

raise ValueError('Something went wrong (all modules have children).')
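
A minimal usage sketch for the wrapper; the two-layer toy model is an illustrative assumption. Without last_layer_name, the last linear layer is detected on the first forward pass and its input is cached as the feature vector:

import torch
import torch.nn as nn

from laplace.feature_extractor import FeatureExtractor

# illustrative toy model; any network ending in an nn.Linear works
model = nn.Sequential(nn.Linear(20, 50), nn.ReLU(), nn.Linear(50, 3))
fe = FeatureExtractor(model)            # no name given -> found lazily

x = torch.randn(8, 20)
out, phi = fe.forward_with_features(x)  # first pass triggers find_last_layer

print(out.shape)      # torch.Size([8, 3])   logits of the full model
print(phi.shape)      # torch.Size([8, 50])  input to the detected last layer
print(fe.last_layer)  # Linear(in_features=50, out_features=3, bias=True)
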
17 changes: 16 additions & 1 deletion laplace/jacobians.py
@@ -12,7 +12,7 @@ def cleanup(module):
memory_cleanup(module)


def Jacobians(model, data):
def jacobians(model, data):
# Jacobians are batch x output x params
model = extend(model)
to_stack = []
@@ -40,3 +40,18 @@ def Jacobians(model, data):
return torch.stack(to_stack, dim=2).transpose(1, 2), f
else:
return Jk.unsqueeze(-1).transpose(1, 2), f


def last_layer_jacobians(model, data):
f, phi = model.forward_with_features(data)
bsize = len(data)
output_size = f.shape[-1]

# calculate Jacobians using the feature vector 'phi'
identity = torch.eye(output_size, device=data.device).unsqueeze(0).tile(bsize, 1, 1)
# Jacobians are batch x output x params
Js = torch.einsum('kp,kij->kijp', phi, identity).reshape(bsize, output_size, -1)
if model.last_layer.bias is not None:
Js = torch.cat([Js, identity], dim=2)

return Js, f
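
For a linear last layer f = W phi(x) + b, the Jacobian with respect to the flattened last-layer parameters is the feature vector phi combined with an identity over the outputs, which is what the einsum above assembles. A small shape check under an illustrative toy setup (not part of this commit):

import torch
import torch.nn as nn

from laplace.feature_extractor import FeatureExtractor
from laplace.jacobians import last_layer_jacobians

# illustrative model; the FeatureExtractor finds the last layer on the first pass
model = FeatureExtractor(nn.Sequential(nn.Linear(20, 50), nn.ReLU(), nn.Linear(50, 3)))
X = torch.randn(8, 20)

Js, f = last_layer_jacobians(model, X)
print(Js.shape)  # torch.Size([8, 3, 153]): batch x outputs x (3*50 weights + 3 biases)
print(f.shape)   # torch.Size([8, 3])
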
39 changes: 17 additions & 22 deletions laplace/laplace.py
@@ -8,7 +8,7 @@
from laplace.utils import parameters_per_layer, invsqrt_precision
from laplace.matrix import Kron
from laplace.curvature import BackPackGGN
from laplace.jacobians import Jacobians
from laplace.jacobians import jacobians


__all__ = ['FullLaplace', 'KronLaplace', 'DiagLaplace']
@@ -83,7 +83,7 @@ def __init__(self, model, likelihood, sigma_noise=1., prior_precision=1.,
def _curv_closure(self, X, y, N):
pass

def fit(self, train_loader, compute_scale=True):
def fit(self, train_loader):
"""Fit the local Laplace approximation at the parameters of the model.
Parameters
@@ -109,13 +109,6 @@ def fit(self, train_loader, compute_scale=True):
self.n_data = N

self._fit = True
# compute optimal representation of posterior Cov/Prec.
if compute_scale:
self.compute_scale()

@abstractmethod
def compute_scale(self):
pass

def marginal_likelihood(self, prior_precision=None, sigma_noise=None):
"""Compute the Laplace approximation to the marginal likelihood.
@@ -225,15 +218,15 @@ def predictive_samples(self, X, pred_type='glm', n_samples=100):
if self.likelihood == 'regression':
return samples
return torch.softmax(samples, dim=-1)

else: # 'nn'
return self.nn_predictive_samples(X, n_samples)

def glm_predictive_distribution(self, X):
Js, f_mu = Jacobians(self.model, X)
Js, f_mu = jacobians(self.model, X)
f_var = self.functional_variance(Js)
return f_mu.detach(), f_var.detach()

def nn_predictive_samples(self, X, n_samples=100):
fs = list()
for sample in self.sample(n_samples):
@@ -371,16 +364,24 @@ def __init__(self, model, likelihood, sigma_noise=1., prior_precision=1.,
super().__init__(model, likelihood, sigma_noise, prior_precision,
temperature, backend)
self.H = torch.zeros(self.n_params, self.n_params, device=self._device)
self._posterior_scale = None

def _curv_closure(self, X, y, N):
return self.backend.full(X, y, N=N)

def compute_scale(self):
self.posterior_scale = invsqrt_precision(self.posterior_precision)
def _compute_scale(self):
self._posterior_scale = invsqrt_precision(self.posterior_precision)

@property
def posterior_scale(self):
if self._posterior_scale is None:
self._compute_scale()
return self._posterior_scale

@property
def posterior_covariance(self):
return self.posterior_scale @ self.posterior_scale.T
scale = self.posterior_scale
return scale @ scale.T

@property
def posterior_precision(self):
@@ -419,10 +420,8 @@ def _curv_closure(self, X, y, N):
return self.backend.kron(X, y, N=N)

def fit(self, train_loader):
super().fit(train_loader)
# Kron requires postprocessing as all quantities depend on the decomposition.
super().fit(train_loader, compute_scale=True)

def compute_scale(self):
self.H = self.H.decompose()

@property
@@ -476,10 +475,6 @@ def posterior_precision(self):

return self.H_factor * self.H + self.prior_precision_diag

def compute_scale(self):
# For diagonal this is implemented lazily since computing is for free.
pass

@property
def posterior_scale(self):
return 1 / self.posterior_precision.sqrt()
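
With this change, FullLaplace no longer computes the posterior scale inside fit; it is derived lazily from posterior_precision on first access to posterior_scale or posterior_covariance. A rough sketch of the resulting workflow, where the toy model, the data loader, and the reliance on constructor defaults for the remaining arguments are illustrative assumptions rather than guarantees of this diff:

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from laplace.laplace import FullLaplace

# illustrative model and data
model = nn.Sequential(nn.Linear(20, 50), nn.ReLU(), nn.Linear(50, 3))
dataset = TensorDataset(torch.randn(32, 20), torch.randint(0, 3, (32,)))
loader = DataLoader(dataset, batch_size=8)

la = FullLaplace(model, 'classification')
la.fit(loader)                 # accumulates the loss and the full GGN only
cov = la.posterior_covariance  # first access triggers invsqrt_precision(...)
print(cov.shape)               # n_params x n_params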