Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add last-layer Laplace flavors #2

Merged
merged 19 commits into from
Apr 21, 2021
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 64 additions & 26 deletions laplace/curvature.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,23 @@
from backpack import backpack, extend
from backpack.extensions import DiagGGNExact, DiagGGNMC, KFAC, KFLR, SumGradSquared, BatchGrad

from laplace.jacobians import Jacobians
from laplace.jacobians import Jacobians, LLJacobians
aleximmer marked this conversation as resolved.
Show resolved Hide resolved
from laplace.matrix import Kron


class CurvatureInterface(ABC):

def __init__(self, model, likelihood):
assert likelihood in ['regression', 'classification']
self.likelihood = likelihood
self.model = model
if likelihood == 'regression':
self.lossfunc = MSELoss(reduction='sum')
self.factor = 0.5 # convert to standard Gauss. log N(y|f,1)
else:
self.lossfunc = CrossEntropyLoss(reduction='sum')
self.factor = 1.

@abstractmethod
def full(self, X, y, **kwargs):
pass
Expand All @@ -24,19 +35,44 @@ def kron(self, X, y, **kwargs):
def diag(self, X, y, **kwargs):
pass

def _get_full_ggn(self, Js, f, y):
loss = self.factor * self.lossfunc(f, y)
if self.likelihood == 'regression':
H_ggn = torch.einsum('mkp,mkq->pq', Js, Js)
else:
# second derivative of log lik is diag(p) - pp^T
ps = torch.softmax(f, dim=-1)
H_lik = torch.diag_embed(ps) - torch.einsum('mk,mc->mck', ps, ps)
H_ggn = torch.einsum('mcp,mck,mkq->pq', Js, H_lik, Js)
return loss.detach(), H_ggn


class LastLayer(CurvatureInterface):

def __init__(self, model, likelihood, backend, **kwargs):
super().__init__(model, likelihood)
self.backend = backend(self.model, self.likelihood, last_layer=True, **kwargs)
runame marked this conversation as resolved.
Show resolved Hide resolved

def diag(self, X, y, **kwargs):
return self.backend.diag(X, y, **kwargs)

def kron(self, X, y, **kwargs):
return self.backend.kron(X, y, **kwargs)

def full(self, X, y, **kwargs):
Js, f = LLJacobians(self.model, X)
loss, H = self._get_full_ggn(Js, f, y)

return loss, H


class BackPackInterface(CurvatureInterface):

def __init__(self, model, likelihood):
assert likelihood in ['regression', 'classification']
self.likelihood = likelihood
self.model = extend(model)
if likelihood == 'regression':
self.lossfunc = extend(MSELoss(reduction='sum'))
self.factor = 0.5 # convert to standard Gauss. log N(y|f,1)
else:
self.lossfunc = extend(CrossEntropyLoss(reduction='sum'))
self.factor = 1.
def __init__(self, model, likelihood, last_layer=False):
super().__init__(model, likelihood)
self.last_layer = last_layer
extend(self.model.last_layer) if last_layer else extend(self.model)
extend(self.lossfunc)


class BackPackGGN(BackPackInterface):
Expand All @@ -47,21 +83,29 @@ class BackPackGGN(BackPackInterface):
--> factor for regression is 0.5 for loss and ggn
"""

def __init__(self, model, likelihood, stochastic=False):
super().__init__(model, likelihood)
def __init__(self, model, likelihood, last_layer=False, stochastic=False):
super().__init__(model, likelihood, last_layer)
self.stochastic = stochastic

def _get_diag_ggn(self):
if self.last_layer:
model = self.model.last_layer
else:
model = self.model
if self.stochastic:
return torch.cat([p.diag_ggn_mc.data.flatten() for p in self.model.parameters()])
return torch.cat([p.diag_ggn_mc.data.flatten() for p in model.parameters()])
else:
return torch.cat([p.diag_ggn_exact.data.flatten() for p in self.model.parameters()])
return torch.cat([p.diag_ggn_exact.data.flatten() for p in model.parameters()])

def _get_kron_factors(self):
if self.last_layer:
model = self.model.last_layer
else:
model = self.model
if self.stochastic:
return Kron([p.kfac for p in self.model.parameters()])
return Kron([p.kfac for p in model.parameters()])
else:
return Kron([p.kflr for p in self.model.parameters()])
return Kron([p.kflr for p in model.parameters()])

@staticmethod
def _rescale_kron_factors(kron, M, N):
Expand Down Expand Up @@ -98,15 +142,9 @@ def full(self, X, y, **kwargs):
raise ValueError('Stochastic approximation not implemented for full GGN.')

Js, f = Jacobians(self.model, X)
loss = self.factor * self.lossfunc(f, y)
if self.likelihood == 'regression':
H_ggn = torch.einsum('mkp,mkq->pq', Js, Js)
else:
# second derivative of log lik is diag(p) - pp^T
ps = torch.softmax(f, dim=-1)
H_lik = torch.diag_embed(ps) - torch.einsum('mk,mc->mck', ps, ps)
H_ggn = torch.einsum('mcp,mck,mkq->pq', Js, H_lik, Js)
return loss.detach(), H_ggn
loss, H_ggn = self._get_full_ggn(Js, f, y)

return loss, H_ggn


class BackPackEF(BackPackInterface):
Expand Down
123 changes: 123 additions & 0 deletions laplace/feature_extractor.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
import torch
import torch.nn as nn
from typing import Tuple, Callable, Optional


class FeatureExtractor(nn.Module):
"""Feature extractor for a PyTorch neural network.
A wrapper which returns the output of the penultimate layer in addition to
the output of the last layer for each forward pass. It assumes that the
last layer is linear.
Based on https://gist.github.com/fkodom/27ed045c9051a39102e8bcf4ce31df76.

Arguments
----------
model : torch.nn.Module
PyTorch model

last_layer_name (optional) : str, default=None
If the user already knows the name of the last layer, otherwise it will
be determined automatically.

Attributes
----------
model : torch.nn.Module
The underlying PyTorch model.

last_layer : torch.nn.module
The torch module corresponding to the last layer (has to be instance
of torch.nn.Linear).

Examples
--------

Notes
-----
Limitations:
- Assumes that the last layer is always the same for any forward pass
- Assumes that the last layer is an instance of torch.nn.Linear
"""
def __init__(self, model: nn.Module, last_layer_name: Optional[str] = None) -> None:
super().__init__()
self.model = model
self._features = dict()
if last_layer_name is None:
self._found = False
else:
self.set_last_layer(last_layer_name)

def forward(self, x: torch.Tensor) -> torch.Tensor:
if self._found:
# if last and penultimate layers are already known
out = self.model(x)
else:
# if this is the first forward pass
out = self.find_last_layer(x)
return out

def forward_with_features(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
out = self.forward(x)
features = self._features[self._last_layer_name]
return out, features

def set_last_layer(self, last_layer_name: str) -> None:
# set last_layer attributes and check if it is linear
self._last_layer_name = last_layer_name
self.last_layer = dict(self.model.named_modules())[last_layer_name]
if not isinstance(self.last_layer, nn.Linear):
raise ValueError('Use model with a linear last layer.')

# set forward hook to extract features in future forward passes
self.last_layer.register_forward_hook(self._get_hook(last_layer_name))

# last layer is now identified and hook is set
self._found = True

def _get_hook(self, name: str) -> Callable:
def hook(_, input, __):
# only accepts one input (expects linear layer)
self._features[name] = input[0].detach()
return hook

def find_last_layer(self, x: torch.Tensor) -> torch.Tensor:
if self._found:
raise ValueError('Last layer is already known.')

act_out = dict()
def get_act_hook(name):
def act_hook(_, input, __):
# only accepts one input (expects linear layer)
if isinstance(input[0], torch.Tensor):
act_out[name] = input[0].detach()
else:
act_out[name] = None
# remove hook
handles[name].remove()
return act_hook

# set hooks for all modules
handles = dict()
for name, module in self.model.named_modules():
handles[name] = module.register_forward_hook(get_act_hook(name))

# check if model has more than one module
# (there might be pathological exceptions)
if len(handles) <= 2:
raise ValueError('The model only has one module.')

# forward pass to find execution order
out = self.model(x)

# find the last layer, store features, return output of forward pass
keys = list(act_out.keys())
for key in reversed(keys):
layer = dict(self.model.named_modules())[key]
if len(list(layer.children())) == 0:
self.set_last_layer(key)

# save features from first forward pass
self._features[key] = act_out[key]

return out

raise ValueError('Something went wrong (all modules have children).')
15 changes: 15 additions & 0 deletions laplace/jacobians.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,3 +40,18 @@ def Jacobians(model, data):
return torch.stack(to_stack, dim=2).transpose(1, 2), f
else:
return Jk.unsqueeze(-1).transpose(1, 2), f


def LLJacobians(model, data):
f, phi = model.forward_with_features(data)
bsize = len(data)
output_size = f.shape[-1]

# calculate Jacobians using the feature vector 'phi'
identity = torch.eye(output_size, device=data.device).unsqueeze(0).tile(bsize, 1, 1)
# Jacobians are batch x output x params
Js = torch.einsum('kp,kij->kijp', phi, identity).reshape(bsize, output_size, -1)
if model.last_layer.bias is not None:
Js = torch.cat([Js, identity], dim=2)

return Js, f
Loading