Feature caching mechanism in LLLA #170

Merged · 3 commits · Jun 11, 2024
29 changes: 16 additions & 13 deletions examples/calibration_example.py
@@ -1,15 +1,16 @@
import warnings
warnings.simplefilter("ignore", UserWarning)

import torch
import torch.distributions as dists
import numpy as np
import helper.wideresnet as wrn
import helper.dataloaders as dl
from helper import util
from netcal.metrics import ECE
warnings.simplefilter('ignore', UserWarning)

from laplace import Laplace
import torch # noqa: E402
import torch.distributions as dists # noqa: E402
import numpy as np # noqa: E402
import helper.wideresnet as wrn # noqa: E402
import helper.dataloaders as dl # noqa: E402
from helper import util # noqa: E402
from netcal.metrics import ECE # noqa: E402

from laplace import Laplace # noqa: E402


np.random.seed(7777)
@@ -50,9 +51,9 @@ def predict(dataloader, model, laplace=False):
print(f'[MAP] Acc.: {acc_map:.1%}; ECE: {ece_map:.1%}; NLL: {nll_map:.3}')

# Laplace
la = Laplace(model, 'classification',
subset_of_weights='last_layer',
hessian_structure='kron')
la = Laplace(
model, 'classification', subset_of_weights='last_layer', hessian_structure='kron'
)
la.fit(train_loader)
la.optimize_prior_precision(method='marglik')

@@ -61,4 +62,6 @@ def predict(dataloader, model, laplace=False):
ece_laplace = ECE(bins=15).measure(probs_laplace.numpy(), targets.numpy())
nll_laplace = -dists.Categorical(probs_laplace).log_prob(targets).mean()

print(f'[Laplace] Acc.: {acc_laplace:.1%}; ECE: {ece_laplace:.1%}; NLL: {nll_laplace:.3}')
print(
f'[Laplace] Acc.: {acc_laplace:.1%}; ECE: {ece_laplace:.1%}; NLL: {nll_laplace:.3}'
)
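
As a side note, here is a toy check of the two metrics used in this example; the probabilities and labels are made up, and only the calls to dists.Categorical and netcal's ECE mirror the script above:

import torch
import torch.distributions as dists
from netcal.metrics import ECE

# Made-up predictive probabilities for 4 inputs over 3 classes, plus labels.
probs = torch.tensor([[0.7, 0.2, 0.1],
                      [0.1, 0.8, 0.1],
                      [0.3, 0.3, 0.4],
                      [0.6, 0.3, 0.1]])
targets = torch.tensor([0, 1, 2, 1])

# Negative log-likelihood, as in the script: mean of -log p(y | x).
nll = -dists.Categorical(probs).log_prob(targets).mean()
# Expected calibration error over 15 confidence bins.
ece = ECE(bins=15).measure(probs.numpy(), targets.numpy())
print(f'NLL: {nll:.3}; ECE: {ece:.1%}')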
26 changes: 14 additions & 12 deletions examples/helper/dataloaders.py
@@ -11,24 +11,28 @@


def CIFAR10(train=True, batch_size=None, augm_flag=True):
if batch_size == None:
if batch_size is None:
if train:
batch_size=train_batch_size
batch_size = train_batch_size
else:
batch_size=test_batch_size
batch_size = test_batch_size

transform_base = [transforms.ToTensor()]
transform_train = transforms.Compose([
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
] + transform_base)
transform_train = transforms.Compose(
[
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4, padding_mode='reflect'),
]
+ transform_base
)
transform_test = transforms.Compose(transform_base)
transform_train = transforms.RandomChoice([transform_train, transform_test])
transform = transform_train if (augm_flag and train) else transform_test

dataset = datasets.CIFAR10(path, train=train, transform=transform, download=True)
loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size,
shuffle=train, num_workers=4)
loader = torch.utils.data.DataLoader(
dataset, batch_size=batch_size, shuffle=train, num_workers=0
)

return loader

@@ -38,9 +42,7 @@ def get_sinusoid_example(n_data=150, sigma_noise=0.3, batch_size=150):
X_train = (torch.rand(n_data) * 8).unsqueeze(-1)
y_train = torch.sin(X_train) + torch.randn_like(X_train) * sigma_noise
train_loader = data_utils.DataLoader(
data_utils.TensorDataset(X_train, y_train),
batch_size=batch_size
data_utils.TensorDataset(X_train, y_train), batch_size=batch_size
)
X_test = torch.linspace(-5, 13, 500).unsqueeze(-1)
return X_train, y_train, train_loader, X_test

37 changes: 30 additions & 7 deletions laplace/lllaplace.py
@@ -166,25 +166,48 @@ def _glm_predictive_distribution(self, X, joint=False):

def _nn_predictive_samples(self, X, n_samples=100, generator=None, **model_kwargs):
fs = list()
for sample in self.sample(n_samples, generator):

for i, sample in enumerate(self.sample(n_samples, generator)):
vector_to_parameters(sample, self.model.last_layer.parameters())
f = self.model(X.to(self._device), **model_kwargs)

if i == 0:
# Cache features at the first iteration
f, feats = self.model.forward_with_features(
X.to(self._device), **model_kwargs
)
else:
# Use the cached features for the remaining iterations
f = self.model.last_layer(feats)

fs.append(f.detach() if not self.enable_backprop else f)

vector_to_parameters(self.mean, self.model.last_layer.parameters())
fs = torch.stack(fs)

if self.likelihood == 'classification':
fs = torch.softmax(fs, dim=-1)

return fs

def _nn_predictive_classification(
self, X, n_samples=100, generator=None, **model_kwargs
):
py = 0
for sample in self.sample(n_samples, generator):

for i, sample in enumerate(self.sample(n_samples, generator)):
vector_to_parameters(sample, self.model.last_layer.parameters())
# TODO: Implement with a single forward pass until last layer.
logits = self.model(X.to(self._device), **model_kwargs).detach()
py += torch.softmax(logits, dim=-1) / n_samples

if i == 0:
# Cache features at the first iteration
logits, feats = self.model.forward_with_features(
X.to(self._device), **model_kwargs
)
else:
# Use the cached features for the remaining iterations
logits = self.model.last_layer(feats)

py += torch.softmax(logits.detach(), dim=-1) / n_samples

vector_to_parameters(self.mean, self.model.last_layer.parameters())
return py

@@ -279,7 +302,7 @@ def __init__(
backend=None,
last_layer_name=None,
damping=False,
**backend_kwargs
**backend_kwargs,
):
self.damping = damping
super().__init__(
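
For reference, a minimal standalone sketch of the caching idea introduced above; the toy trunk/last-layer split and the random weight perturbation are illustrative stand-ins for the library's feature extractor and posterior samples:

import torch
import torch.nn as nn

# Hypothetical stand-ins for the feature trunk and the last layer of the model.
trunk = nn.Sequential(nn.Linear(4, 16), nn.ReLU())
last_layer = nn.Linear(16, 3)

X = torch.randn(8, 4)
n_samples = 10
feats = None
fs = []

for i in range(n_samples):
    # The PR loads a fresh posterior sample into the last layer each iteration;
    # a small random perturbation mimics that here.
    with torch.no_grad():
        last_layer.weight.add_(0.01 * torch.randn_like(last_layer.weight))

    if i == 0:
        # First iteration: run the full trunk once and cache the features.
        feats = trunk(X)
    # Every iteration: only the (re-sampled) last layer is evaluated on the cache.
    f = last_layer(feats)
    fs.append(f.detach())

fs = torch.stack(fs)  # shape: (n_samples, batch_size, n_outputs)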
17 changes: 13 additions & 4 deletions laplace/utils/feature_extractor.py
@@ -24,9 +24,13 @@ class FeatureExtractor(nn.Module):
if the name of the last layer is already known, otherwise it will
be determined automatically.
"""

def __init__(
self, model: nn.Module, last_layer_name: Optional[str] = None,
enable_backprop: bool = False) -> None:
self,
model: nn.Module,
last_layer_name: Optional[str] = None,
enable_backprop: bool = False,
) -> None:
super().__init__()
self.model = model
self._features = dict()
@@ -54,7 +58,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
out = self.model(x)
return out

def forward_with_features(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
def forward_with_features(
self, x: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward pass which returns the output of the penultimate layer along
with the output of the last layer. If the last layer is not known yet,
it will be determined when this function is called for the first time.
@@ -90,9 +96,10 @@ def _get_hook(self, name: str) -> Callable:
def hook(_, input, __):
# only accepts one input (expects linear layer)
self._features[name] = input[0]

if not self.enable_backprop:
self._features[name] = self._features[name].detach()

return hook

def find_last_layer(self, x: torch.Tensor) -> torch.Tensor:
@@ -112,6 +119,7 @@ def find_last_layer(self, x: torch.Tensor) -> torch.Tensor:
raise ValueError('Last layer is already known.')

act_out = dict()

def get_act_hook(name):
def act_hook(_, input, __):
# only accepts one input (expects linear layer)
@@ -121,6 +129,7 @@ def act_hook(_, input, __):
act_out[name] = None
# remove hook
handles[name].remove()

return act_hook

# set hooks for all modules
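
And a small usage sketch of forward_with_features, assuming a hypothetical toy nn.Sequential classifier and importing from the file path shown above:

import torch
import torch.nn as nn
from laplace.utils.feature_extractor import FeatureExtractor

# Toy classifier; the trailing nn.Linear should be detected as the last layer.
model = nn.Sequential(nn.Linear(10, 32), nn.ReLU(), nn.Linear(32, 2))
fe = FeatureExtractor(model)

x = torch.randn(5, 10)
# Returns the logits together with the cached penultimate features.
out, feats = fe.forward_with_features(x)
print(out.shape, feats.shape)  # expected: torch.Size([5, 2]) torch.Size([5, 32])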