
Add last-layer Laplace flavors #2

Merged
merged 19 commits on Apr 21, 2021
Add support for last-layer with BackPackEF backend and add _model property to BackPackInterface
runame committed Apr 20, 2021
commit f7a84634e11839021fc064c6a067dbea65ce712e
26 changes: 11 additions & 15 deletions laplace/curvature.py
@@ -52,9 +52,13 @@ class BackPackInterface(CurvatureInterface):
     def __init__(self, model, likelihood, last_layer=False):
         super().__init__(model, likelihood)
         self.last_layer = last_layer
-        extend(self.model.last_layer) if last_layer else extend(self.model)
+        extend(self._model)
         extend(self.lossfunc)
 
+    @property
+    def _model(self):
+        return self.model.last_layer if self.last_layer else self.model
+
 
 class BackPackGGN(BackPackInterface):
     """[summary]
@@ -69,24 +73,16 @@ def __init__(self, model, likelihood, last_layer=False, stochastic=False):
         self.stochastic = stochastic
 
     def _get_diag_ggn(self):
-        if self.last_layer:
-            model = self.model.last_layer
-        else:
-            model = self.model
         if self.stochastic:
-            return torch.cat([p.diag_ggn_mc.data.flatten() for p in model.parameters()])
+            return torch.cat([p.diag_ggn_mc.data.flatten() for p in self._model.parameters()])
         else:
-            return torch.cat([p.diag_ggn_exact.data.flatten() for p in model.parameters()])
+            return torch.cat([p.diag_ggn_exact.data.flatten() for p in self._model.parameters()])
 
     def _get_kron_factors(self):
-        if self.last_layer:
-            model = self.model.last_layer
-        else:
-            model = self.model
         if self.stochastic:
-            return Kron([p.kfac for p in model.parameters()])
+            return Kron([p.kfac for p in self._model.parameters()])
         else:
-            return Kron([p.kflr for p in model.parameters()])
+            return Kron([p.kflr for p in self._model.parameters()])
 
     @staticmethod
     def _rescale_kron_factors(kron, M, N):
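
For context on the attributes read above: `diag_ggn_mc`, `diag_ggn_exact`, `kfac`, and `kflr` are set on each parameter by BackPACK during a backward pass inside its context manager. A hedged, self-contained sketch of that mechanism with a toy model (standard BackPACK usage, not code from this PR; `Kron` is the PR's own container and is elided here):

# Toy illustration of where the per-parameter curvature attributes come from.
import torch
from backpack import backpack, extend
from backpack.extensions import DiagGGNExact, KFLR

model = extend(torch.nn.Linear(3, 2))
lossfunc = extend(torch.nn.CrossEntropyLoss(reduction='sum'))

X, y = torch.randn(8, 3), torch.randint(0, 2, (8,))
loss = lossfunc(model(X), y)
with backpack(DiagGGNExact(), KFLR()):
    loss.backward()

# After backward, each parameter carries the requested curvature quantities.
diag_ggn = torch.cat([p.diag_ggn_exact.data.flatten() for p in model.parameters()])
kron_factors = [p.kflr for p in model.parameters()]  # lists of Kronecker factors
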
@@ -135,15 +131,15 @@ class BackPackEF(BackPackInterface):
 
     def _get_individual_gradients(self):
         return torch.cat([p.grad_batch.data.flatten(start_dim=1)
-                          for p in self.model.parameters()], dim=1)
+                          for p in self._model.parameters()], dim=1)
 
     def diag(self, X, y, **kwargs):
         f = self.model(X)
         loss = self.lossfunc(f, y)
         with backpack(SumGradSquared()):
             loss.backward()
         diag_EF = torch.cat([p.sum_grad_squared.data.flatten()
-                             for p in self.model.parameters()])
+                             for p in self._model.parameters()])
 
         return self.factor * loss.detach(), self.factor ** 2 * diag_EF
 
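One subtlety in this last hunk: `diag` still runs the forward pass through the full `self.model`, while the empirical-Fisher diagonal is gathered only from `self._model.parameters()`, i.e. from the extended last layer when `last_layer=True`. For reference, a hedged toy sketch of what BackPACK's `SumGradSquared` produces there (real extension, made-up model; not code from this PR):

# Empirical Fisher diagonal: per-sample gradients squared, summed over the batch.
import torch
from backpack import backpack, extend
from backpack.extensions import SumGradSquared

model = extend(torch.nn.Linear(3, 1))
lossfunc = extend(torch.nn.MSELoss(reduction='sum'))

X, y = torch.randn(8, 3), torch.randn(8, 1)
loss = lossfunc(model(X), y)
with backpack(SumGradSquared()):
    loss.backward()

diag_ef = torch.cat([p.sum_grad_squared.data.flatten() for p in model.parameters()])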