Functional Laplace Updated #192

Merged
merged 129 commits into from Jul 16, 2024

Changes from 107 commits

Commits (129)
e3f0481
init commit
metodj Sep 15, 2021
bcb4bfb
FunctionalLaplace blueprint
metodmove Sep 18, 2021
b2a2ff4
added a test that will serve as a sanity check for correctness
metodmove Sep 20, 2021
a793a04
test small refactor
metodmove Sep 25, 2021
1b50f1d
init GP matrices and SoD dataloader
metodmove Sep 25, 2021
84bd1a7
initial full naive GP implementation
metodmove Sep 26, 2021
8050bff
fixed bugs in full naive GP inference
metodmove Sep 27, 2021
2ac2b81
test gp_equivalence (almost) passing
metodmove Sep 29, 2021
cdbadae
add GP to regression_example
metodmove Oct 2, 2021
3e4ffd2
float32 vs float64 magic
metodmove Oct 2, 2021
11943c1
Cholesky solves float32 vs float64 conundrum
metodmove Oct 5, 2021
e43eb0e
minor
metodmove Oct 11, 2021
8b51f62
merge remove-abc into functional-laplace
metodmove Oct 19, 2021
aab27a6
fix multivariate regression GP bug
metodmove Oct 21, 2021
583cef6
GP inference parameters
metodmove Oct 21, 2021
cf00987
GP classification with diagonal_L=True
metodmove Oct 25, 2021
0e28c4a
start independent_gp_kernels=True
metodmove Oct 25, 2021
d28cc6b
block-diagonal kernel analysis
metodmove Oct 28, 2021
4b568cb
warning for multivariate regression
metodmove Nov 4, 2021
4f58447
refactor kernel method
metodmove Nov 4, 2021
bff30f4
prior factor for SoD and refactor __call__
metodmove Nov 6, 2021
d61b4fa
log_marginal_likelihood GP start
metodmove Nov 6, 2021
3a5a1be
scatter for logp(f)
metodmove Nov 8, 2021
dec9514
fix gp marginal likelihood bugs
metodmove Nov 8, 2021
4dc7016
GP predictive_samples
metodmove Nov 11, 2021
7ea6328
start with tests
metodmove Nov 12, 2021
c91048e
minor
metodmove Nov 12, 2021
7376204
more tests and refactor
metodmove Nov 13, 2021
ab5058a
docstrings
metodmove Nov 13, 2021
0ef5ae3
add last layer functional laplace
metodmove Nov 13, 2021
81742bf
minor
metodmove Nov 13, 2021
6935dcb
remove files
metodmove Nov 13, 2021
9118733
classification gp lml
metodmove Nov 13, 2021
bce2603
minor
metodmove Nov 30, 2021
475ea96
start of BackPackGP class refactor
metodmove Nov 30, 2021
7b1ae50
remove BackPackGP class
metodmove Nov 30, 2021
d7f6dae
remove gp_jacobians
metodmove Dec 1, 2021
4718328
docstrings functional laplace and public methods renaming
metodmove Dec 2, 2021
4b84f5d
minor docstring fix
metodmove Dec 2, 2021
1ff1c46
merge main
metodmove Dec 8, 2021
ed2a439
merge functional-laplace
metodmove Dec 8, 2021
4eddeaf
marginal likelihood docs
metodmove Dec 8, 2021
c9145a1
remove diagonal_L
metodmove Dec 8, 2021
d417934
BackPackGGN default backend
metodmove Dec 10, 2021
19d3761
asdl for functional laplace
metodmove Dec 10, 2021
e60c2b4
merge conflicts
metodmove Dec 10, 2021
8c83630
log_marginal_likelihood in BaseLaplace refactor
metodmove Dec 10, 2021
6c3d886
map_estimate remove from ParametricLaplace
metodmove Dec 13, 2021
4a08a4b
remove _check_fit
metodmove Dec 13, 2021
fcadad3
address merge conflicts
metodmove Feb 7, 2022
cab9983
calibration example start
metodmove Feb 7, 2022
052c058
resolve merge conflicts
Dec 19, 2022
c716e4e
more merge conflicts
Dec 19, 2022
999b250
FunctionalLaplace CIFAR calibration experiment start
Dec 19, 2022
9fc00ea
minor
Dec 19, 2022
4ae0d1b
refactor so that refitting in log_marginal_likelihood for functional …
Dec 20, 2022
26ccf2b
gp marginal likelihood fix
Dec 20, 2022
f8e8455
isotropic priors check
Dec 20, 2022
a0e41f3
cleanup examples
Dec 21, 2022
827ae21
gp calibration notebook
Dec 21, 2022
7557615
minor
Dec 21, 2022
b6bde70
inducing points CIFAR experiment
Dec 21, 2022
88eda9d
transfer model from bnn-preds repo
Dec 21, 2022
25cd5d7
inducing points FMNIST CNN
Dec 21, 2022
73448be
gp calibration example
metodj Dec 23, 2022
566c367
gp continue
metodj Dec 29, 2022
9690e6b
subset_of_weights=all experiment
Dec 29, 2022
5f2efc6
fixed prior precision
Dec 29, 2022
da7bb18
fixed prior precision
Dec 29, 2022
d82c149
ensure that input is differentiable
Dec 29, 2022
6c7397e
further optimize delta experiment
Dec 30, 2022
92bfacf
run for larger delta
Jan 3, 2023
4b08edb
last-layer debug
Jan 6, 2023
87abeeb
inference speed-up
Jan 6, 2023
7747168
minor
Jan 6, 2023
09d410a
einsum memory investigation
Jan 6, 2023
47eebfe
CV working
Jan 8, 2023
648b3ed
rebuild on Sigma_inv
Jan 9, 2023
eef9a15
clean
Jan 10, 2023
8967f9d
validate no_grad
Jan 10, 2023
edf6285
Functional laplace gp calibration (#1)
metodj Jan 10, 2023
edd83b1
fix tests
Jan 10, 2023
88b44d7
Merge branch 'functional-laplace-gp-calibration' into functional-laplace
Jan 10, 2023
44fb5c2
clean calibration-gp example
Jan 10, 2023
c9c068d
minor
Jan 10, 2023
39000ae
minor2
Jan 10, 2023
5cb1354
markdown example start
Jan 11, 2023
d8ec55b
markdown wrapup
Jan 11, 2023
03f0989
minor
Jan 12, 2023
005c3c7
minor
Jan 12, 2023
e20ac6c
Functional laplace memory investigation (#2)
metodj Feb 21, 2023
35ec107
Functional laplace memory investigation (#3)
metodj Feb 21, 2023
77a32f4
increase batch size in the example
metodj Feb 21, 2023
5de3980
minor
metodj Feb 21, 2023
d58fb43
Initial Merge from metodj Functional Laplace
Ludvins May 16, 2024
bf58f0d
Push Branch
Ludvins May 16, 2024
dd21349
Fix merge and make it functional
Ludvins May 18, 2024
7be6dda
Add quotes to pip install with tests for zsh compatibility
Ludvins May 18, 2024
06d7b73
Fix and pass unit tests for Functional Laplace
Ludvins May 18, 2024
57ff2c3
Add some comments to functions
Ludvins May 19, 2024
43e3d47
Fix Calibration GP example
Ludvins May 20, 2024
6614e02
Update Calibration GP Example
Ludvins May 24, 2024
9e6ee1f
Delete README
Ludvins May 27, 2024
7b8933d
Merge branch 'aleximmer:main' into main
Ludvins May 31, 2024
ad23e86
Update README.md
Ludvins Jun 4, 2024
f1c13f1
Update README.md
Ludvins Jun 4, 2024
8920f24
Formatting
Ludvins Jun 4, 2024
aa662ac
Typelinting, unittests and FunctionalLLLaplace serialization fixed
Ludvins Jun 10, 2024
c3d8e15
enable_backprop unused for last layer curvature
Ludvins Jun 10, 2024
ad24df3
Merge branch 'main' into main
Ludvins Jun 10, 2024
6eb5655
Merge branch 'main' into main
Ludvins Jun 11, 2024
850d75b
Merge branch 'aleximmer:main' into main
Ludvins Jun 11, 2024
35c7571
Merge branch 'main' into main
Ludvins Jun 11, 2024
03f8609
Add download link for pre-trained model
Ludvins Jun 11, 2024
989c323
Merge branch 'main' into main
Ludvins Jun 12, 2024
864b902
Correct enabe_backprop
Ludvins Jun 12, 2024
fa3f80c
Ruff tests and add pytest-mock to action requirements
Ludvins Jun 12, 2024
772c616
Ruff check and MutableDict input
Ludvins Jun 12, 2024
1320b1d
Merge branch 'main' into main
Ludvins Jun 15, 2024
67d99ba
Fix dtype in unit test
Ludvins Jun 22, 2024
c3fc541
Fix enable_backprop in LastLayer
Ludvins Jun 22, 2024
13d3afb
Fix imports and types
Ludvins Jun 22, 2024
90afca4
Delete debug prints
Ludvins Jun 22, 2024
64b1aa2
Merge branch 'main' into main
Ludvins Jun 24, 2024
9a9f546
Small refactors
Ludvins Jun 24, 2024
ab56b52
Created _glm_forward_call and _glm_predictive_samples in BaseLaplace
Ludvins Jun 24, 2024
bcdf1b4
Merge branch 'main' into main
Ludvins Jul 8, 2024
5ceb1c8
Rename Variables and Fix _glm_predictive
Ludvins Jul 8, 2024
02c58ce
Warning Consistency
Ludvins Jul 12, 2024
1 change: 1 addition & 0 deletions .gitignore
@@ -134,3 +134,4 @@ data/
.DS_Store

state_dict.bin
/temp
5 changes: 3 additions & 2 deletions README.md
@@ -4,6 +4,7 @@

[![Main](https://travis-ci.com/AlexImmer/Laplace.svg?token=rpuRxEjQS6cCZi7ptL9y&branch=main)](https://travis-ci.com/AlexImmer/Laplace)


The laplace package facilitates the application of Laplace approximations for entire neural networks, subnetworks of neural networks, or just their last layer.
The package enables posterior approximations, marginal-likelihood estimation, and various posterior predictive computations.
The library documentation is available at [https://aleximmer.github.io/Laplace](https://aleximmer.github.io/Laplace).
@@ -35,7 +36,7 @@ For development purposes, clone the repository and then install:
# or after cloning the repository for development
pip install -e .
# run tests
pip install -e .[tests]
pip install -e '.[tests]'
pytest tests/
```

@@ -273,7 +274,7 @@ torch.load(..., map_location='cpu')
## Structure
The laplace package consists of two main components:

1. The subclasses of [`laplace.BaseLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/baselaplace.py) that implement different sparsity structures: different subsets of weights (`'all'`, `'subnetwork'` and `'last_layer'`) and different structures of the Hessian approximation (`'full'`, `'kron'`, `'lowrank'` and `'diag'`). This results in _nine_ currently available options: `laplace.FullLaplace`, `laplace.KronLaplace`, `laplace.DiagLaplace`, the corresponding last-layer variations `laplace.FullLLLaplace`, `laplace.KronLLLaplace`, and `laplace.DiagLLLaplace` (which are all subclasses of [`laplace.LLLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/lllaplace.py)), [`laplace.SubnetLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/subnetlaplace.py) (which only supports `'full'` and `'diag'` Hessian approximations) and `laplace.LowRankLaplace` (which only supports inference over `'all'` weights). All of these can be conveniently accessed via the [`laplace.Laplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/laplace.py) function.
1. The subclasses of [`laplace.BaseLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/baselaplace.py) that implement different sparsity structures: different subsets of weights (`'all'`, `'subnetwork'` and `'last_layer'`) and different structures of the Hessian approximation (`'full'`, `'kron'`, `'lowrank'`, `'diag'` and `'gp'`). This results in _ten_ currently available options: `laplace.FullLaplace`, `laplace.KronLaplace`, `laplace.DiagLaplace`, `laplace.FunctionalLaplace`, the corresponding last-layer variations `laplace.FullLLLaplace`, `laplace.KronLLLaplace`, `laplace.DiagLLLaplace` and `laplace.FunctionalLLLaplace` (which are all subclasses of [`laplace.LLLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/lllaplace.py)), [`laplace.SubnetLaplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/subnetlaplace.py) (which only supports `'full'` and `'diag'` Hessian approximations) and `laplace.LowRankLaplace` (which only supports inference over `'all'` weights). All of these can be conveniently accessed via the [`laplace.Laplace`](https://github.com/AlexImmer/Laplace/blob/main/laplace/laplace.py) function.
2. The backends in [`laplace.curvature`](https://github.com/AlexImmer/Laplace/blob/main/laplace/curvature/) which provide access to Hessian approximations of
the corresponding sparsity structures, for example, the diagonal GGN.
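
As a purely illustrative sketch (assuming a trained classification `model` and a `train_loader`; see the calibration GP example added in this PR for a complete script), the new functional variant is reached through the same unified interface by selecting the `'gp'` Hessian structure:

```python
from laplace import Laplace

# functional (GP) Laplace over all weights via the unified Laplace interface
la = Laplace(model, 'classification',
             subset_of_weights='all', hessian_structure='gp')
la.fit(train_loader)
```
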

136 changes: 136 additions & 0 deletions examples/calibration_gp_example.md
@@ -0,0 +1,136 @@
## Full example: Functional Laplace (GP) on FMNIST image classifier
Applying the generalized Gauss-Newton (GGN) approximation to the Hessian in the Laplace approximation (LA) of the BNN posterior
turns the underlying probabilistic model from a BNN into a generalized linear model (GLM).
This GLM is equivalent to a Gaussian process (GP) with a particular kernel [1, 2].
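
For intuition, this kernel is built from Jacobians of the network outputs with respect to the weights at the MAP estimate. The toy sketch below (not the `laplace` library's implementation; `toy_model`, `delta`, `f` and `flat_jacobian` are illustrative assumptions, and it requires a recent PyTorch with `torch.func`) computes one output-by-output block of such a kernel for two inputs, assuming an isotropic Gaussian prior with precision `delta`:

```python
import torch
from torch.func import functional_call, jacrev

torch.manual_seed(0)
toy_model = torch.nn.Sequential(
    torch.nn.Linear(3, 8), torch.nn.Tanh(), torch.nn.Linear(8, 2)
)
params = dict(toy_model.named_parameters())
n_out = 2
delta = 1.0  # prior precision (placeholder value)

def f(p, x):
    # network outputs for a single input, as a function of the parameters
    return functional_call(toy_model, p, (x.unsqueeze(0),)).squeeze(0)

def flat_jacobian(x):
    jac = jacrev(f)(params, x)  # dict: per-parameter Jacobians of shape (n_out, *param.shape)
    return torch.cat([j.reshape(n_out, -1) for j in jac.values()], dim=1)

x1, x2 = torch.randn(3), torch.randn(3)
# one (n_out x n_out) block of the GP kernel between inputs x1 and x2
kernel_block = flat_jacobian(x1) @ flat_jacobian(x2).T / delta
print(kernel_block)
```

The `laplace` library computes and stores these kernel blocks for you when `hessian_structure='gp'` is selected, as shown later in this example.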

In this notebook, we show how to use the `laplace` library to perform GP inference on top of a *pre-trained* neural network.

Note that a GPU with CUDA support is needed for this example. We recommend using a GPU with at least 24 GB of memory. If less memory is available, we suggest reducing `BATCH_SIZE` below.
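
If you are unsure how much memory your GPU has, a quick optional check (a sketch using `torch.cuda.mem_get_info`; not required for the example) is:

```python
import torch

# report free and total GPU memory in GB; lower BATCH_SIZE below if memory is tight
free_bytes, total_bytes = torch.cuda.mem_get_info()
print(f'GPU memory: {free_bytes / 1e9:.1f} GB free of {total_bytes / 1e9:.1f} GB total')
```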

#### Data loading

First, let us load the FMNIST dataset. The helper scripts for FMNIST and the pre-trained CNN are available in the `examples/helper` directory of the main repository.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader
import torch.distributions as dists
from netcal.metrics import ECE

from helper.util_gp import get_dataset, CIFAR10Net
from laplace import Laplace

np.random.seed(7777)
torch.manual_seed(7777)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True

assert torch.cuda.is_available()

DATASET = 'FMNIST'
BATCH_SIZE = 256
ds_train, ds_test = get_dataset(DATASET, False, 'cuda')
train_loader = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(ds_test, batch_size=BATCH_SIZE, shuffle=False)
targets = torch.cat([y for x, y in test_loader], dim=0).cpu()
```

#### Load a pre-trained model

Next, we load a pre-trained CNN model. The code used to train the model can be found in the [BNN-predictions repo](https://github.com/AlexImmer/BNN-predictions).

``` python
MODEL_NAME = 'FMNIST_CNN_10_2.2e+02.pt'
model = CIFAR10Net(ds_train.channels, ds_train.K, use_tanh=True).to('cuda')
state = torch.load(f'helper/models/{MODEL_NAME}')
model.load_state_dict(state['model'])
model = model.cuda()
prior_precision = state['delta']
```

To simplify the downstream evaluation, we use the following helper function to make predictions. It simply iterates over all minibatches and collects the predictive probabilities over the FMNIST classes.

``` python
@torch.no_grad()
def predict(dataloader, model, laplace=False):
py = []

for x, _ in dataloader:
if laplace:
py.append(model(x.cuda()))
else:
py.append(torch.softmax(model(x.cuda()), dim=-1))

    return torch.cat(py).cpu()
```

#### The calibration of MAP

We are now ready to see how well calibrated the model is. The metrics we use are the expected calibration error (ECE; Naeini et al., AAAI 2015) and the negative (categorical) log-likelihood. Note that lower values are better for both metrics.

First, let us inspect the MAP model. We shall use the [`netcal`](https://github.com/fabiankueppers/calibration-framework) library to easily compute the ECE.

``` python
probs_map = predict(test_loader, model, laplace=False)
acc_map = (probs_map.argmax(-1) == targets).float().mean()
ece_map = ECE(bins=15).measure(probs_map.numpy(), targets.numpy())
nll_map = -dists.Categorical(probs_map).log_prob(targets).mean()

print(f'[MAP] Acc.: {acc_map:.1%}; ECE: {ece_map:.1%}; NLL: {nll_map:.3}')
```

Running this snippet, we would get:

```
[MAP] Acc.: 94.8%; ECE: 2.0%; NLL: 0.172
```
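
To make the ECE number above concrete, here is a manual sketch of what the metric computes, assuming 15 equal-width confidence bins (netcal's default binning scheme; this helper is not part of `netcal`): the average gap between confidence and accuracy, weighted by how many test points fall into each bin.

```python
import numpy as np

def ece_manual(probs, labels, n_bins=15):
    confs = probs.max(-1)                  # predicted confidence per example
    preds = probs.argmax(-1)
    correct = (preds == labels).astype(float)
    bins = np.linspace(0.0, 1.0, n_bins + 1)
    ece = 0.0
    for lo, hi in zip(bins[:-1], bins[1:]):
        mask = (confs > lo) & (confs <= hi)
        if mask.any():
            # weight of the bin times |accuracy - confidence| within the bin
            ece += mask.mean() * abs(correct[mask].mean() - confs[mask].mean())
    return ece

# e.g. ece_manual(probs_map.numpy(), targets.numpy()) should roughly agree
# with the netcal ECE value reported above
```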

#### The calibration of Laplace

Next, we run Laplace-GP inference to calibrate the neural network's predictions. Since exact GP inference over the full training set is computationally infeasible, we use a Subset-of-Datapoints (SoD) approximation [3] here. In the fitting code further below, `m` denotes the number of datapoints used in the SoD posterior; the sketch that follows illustrates the subsampling step.
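
For intuition only, the subsampling behind SoD amounts to something like the following sketch (this is not the library's internal code; passing `M=m` to `Laplace` below takes care of it for you):

```python
from torch.utils.data import DataLoader, Subset

m = 200  # number of retained datapoints (one of the values used below)
sod_idx = torch.randperm(len(ds_train))[:m].tolist()
sod_loader = DataLoader(Subset(ds_train, sod_idx), batch_size=BATCH_SIZE, shuffle=True)
# GP inference restricted to these m points only needs an (m x m) kernel matrix
# per output class (with diagonal_kernel=True), instead of one over the full dataset
```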

Executing the fitting cell that follows can take up to 5 minutes, depending on the exact hardware used.

``` python
for m in [50, 200, 800, 1600]:
print(f'Fitting Laplace-GP for m={m}')
la = Laplace(model, 'classification',
subset_of_weights='all',
hessian_structure='gp',
diagonal_kernel=True, M=m,
prior_precision=prior_precision)
la.fit(train_loader)

probs_laplace = predict(test_loader, la, laplace=True)
acc_laplace = (probs_laplace.argmax(-1) == targets).float().mean()
ece_laplace = ECE(bins=15).measure(probs_laplace.numpy(), targets.numpy())
nll_laplace = -dists.Categorical(probs_laplace).log_prob(targets).mean()

print(f'[Laplace-GP, m={m}] Acc.: {acc_laplace:.1%}; ECE: {ece_laplace:.1%}; NLL: {nll_laplace:.3}')
```

```
Fitting Laplace-GP for m=50
[Laplace-GP, m=50] Acc.: 91.6%; ECE: 1.5%; NLL: 0.252
Fitting Laplace-GP for m=200
[Laplace-GP, m=200] Acc.: 91.5%; ECE: 1.1%; NLL: 0.252
Fitting Laplace-GP for m=800
[Laplace-GP, m=800] Acc.: 91.4%; ECE: 0.8%; NLL: 0.254
Fitting Laplace-GP for m=1600
[Laplace-GP, m=1600] Acc.: 91.3%; ECE: 0.7%; NLL: 0.257
```

Notice that the post-hoc Laplace-GP inference does not have a significant impact on the accuracy, yet it improves the calibration (in terms of ECE) of the MAP model substantially.

### References
[1] Khan, Mohammad Emtiyaz, et al. "Approximate inference turns deep networks into Gaussian processes." Advances in Neural Information Processing Systems 32 (2019).

[2] Immer, Alexander, Maciej Korzepa, and Matthias Bauer. "Improving predictions of Bayesian neural nets via local linearization." International Conference on Artificial Intelligence and Statistics. PMLR, 2021.

[3] Rasmussen, Carl Edward. "Gaussian processes in machine learning." Springer, 2004.

78 changes: 78 additions & 0 deletions examples/calibration_gp_example.py
@@ -0,0 +1,78 @@
import warnings

import numpy as np
import torch
import torch.distributions as dists
from helper.util_gp import CIFAR10Net, get_dataset
from netcal.metrics import ECE
from torch.utils.data import DataLoader

from laplace import Laplace

np.random.seed(7777)
torch.manual_seed(7777)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = True

warnings.simplefilter('ignore', UserWarning)


assert torch.cuda.is_available()

DATASET = 'FMNIST'
BATCH_SIZE = 25
ds_train, ds_test = get_dataset(DATASET, False, 'cuda')
train_loader = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True)
test_loader = DataLoader(ds_test, batch_size=BATCH_SIZE, shuffle=False)
targets = torch.cat([y for x, y in test_loader], dim=0).cpu()

MODEL_NAME = 'FMNIST_CNN_10_2.2e+02.pt'
model = CIFAR10Net(ds_train.channels, ds_train.K, use_tanh=True).to('cuda')
state = torch.load(f'helper/models/{MODEL_NAME}')
model.load_state_dict(state['model'])
model = model.cuda()
prior_precision = state['delta']


@torch.no_grad()
def predict(dataloader, model, laplace=False):
py = []

for x, _ in dataloader:
if laplace:
py.append(model(x.cuda()))
else:
py.append(torch.softmax(model(x.cuda()), dim=-1))

return torch.cat(py).cpu()


probs_map = predict(test_loader, model, laplace=False)
acc_map = (probs_map.argmax(-1) == targets).float().mean()
ece_map = ECE(bins=15).measure(probs_map.numpy(), targets.numpy())
nll_map = -dists.Categorical(probs_map).log_prob(targets).mean()

print(f'[MAP] Acc.: {acc_map:.1%}; ECE: {ece_map:.1%}; NLL: {nll_map:.3}')

for m in [50, 200, 800, 1600]:
print(f'Fitting Laplace-GP for m={m}')
la = Laplace(
model,
'classification',
subset_of_weights='all',
hessian_structure='gp',
diagonal_kernel=True,
M=m,
prior_precision=prior_precision,
)
la.fit(train_loader)
la.optimize_prior_precision(method='marglik', progress_bar=True)

probs_laplace = predict(test_loader, la, laplace=True)
acc_laplace = (probs_laplace.argmax(-1) == targets).float().mean()
ece_laplace = ECE(bins=15).measure(probs_laplace.numpy(), targets.numpy())
nll_laplace = -dists.Categorical(probs_laplace).log_prob(targets).mean()

print(
f'[Laplace] Acc.: {acc_laplace:.1%}; ECE: {ece_laplace:.1%}; NLL: {nll_laplace:.3}'
)
Binary file added examples/helper/models/FMNIST_CNN_10_2.2e+02.pt
Binary file not shown.
119 changes: 119 additions & 0 deletions examples/helper/util_gp.py
@@ -0,0 +1,119 @@
import os

import torch
import torchvision.datasets as dset
import torchvision.transforms as transforms
from deepobs.pytorch.testproblems.testproblems_utils import (tfconv2d,
tfmaxpool2d)
from torch import nn
from torchvision.datasets import VisionDataset

PACKAGE_DIR = os.path.dirname(os.path.realpath(__file__))
ROOT = '/'.join(PACKAGE_DIR.split('/')[:-1])
DATA_DIR = ROOT + '/data'

MNIST_transform = transforms.ToTensor()


class QuickDS(VisionDataset):
def __init__(self, ds, device):
self.D = [
(ds[i][0].to(device).requires_grad_(), torch.tensor(ds[i][1]).to(device))
for i in range(len(ds))
]
self.K = ds.K
self.channels = ds.channels
self.pixels = ds.pixels

def __getitem__(self, index):
return self.D[index]

def __len__(self):
return len(self.D)


def get_dataset(dataset, double, device=None):
if dataset == 'FMNIST':
ds_train = FMNIST(train=True, double=double)
ds_test = FMNIST(train=False, double=double)
else:
raise ValueError('Invalid dataset argument')
if device is not None:
return QuickDS(ds_train, device), QuickDS(ds_test, device)
else:
return ds_train, ds_test


class FMNIST(dset.FashionMNIST):
def __init__(
self,
root=DATA_DIR,
train=True,
download=True,
transform=MNIST_transform,
double=False,
):
super().__init__(root=root, train=train, download=download, transform=transform)
self.K = 10
self.pixels = 28
self.channels = 1
if double:
self.data = self.data.double()
self.targets = self.targets.double()


class CIFAR10Net(nn.Sequential):
"""
Deepobs network with optional last sigmoid activation (instead of relu)
In Deepobs called `net_cifar10_3c3d`
"""

def __init__(self, in_channels=3, n_out=10, use_tanh=False):
super(CIFAR10Net, self).__init__()
self.output_size = n_out
activ = nn.Tanh if use_tanh else nn.ReLU

self.add_module(
'conv1', tfconv2d(in_channels=in_channels, out_channels=64, kernel_size=5)
)
self.add_module('relu1', nn.ReLU())
self.add_module(
'maxpool1', tfmaxpool2d(kernel_size=3, stride=2, tf_padding_type='same')
)

self.add_module(
'conv2', tfconv2d(in_channels=64, out_channels=96, kernel_size=3)
)
self.add_module('relu2', nn.ReLU())
self.add_module(
'maxpool2', tfmaxpool2d(kernel_size=3, stride=2, tf_padding_type='same')
)

self.add_module(
'conv3',
tfconv2d(
in_channels=96, out_channels=128, kernel_size=3, tf_padding_type='same'
),
)
self.add_module('relu3', nn.ReLU())
self.add_module(
'maxpool3', tfmaxpool2d(kernel_size=3, stride=2, tf_padding_type='same')
)

self.add_module('flatten', nn.Flatten())

self.add_module('dense1', nn.Linear(in_features=3 * 3 * 128, out_features=512))
self.add_module('relu4', activ())
self.add_module('dense2', nn.Linear(in_features=512, out_features=256))
self.add_module('relu5', activ())
self.add_module('dense3', nn.Linear(in_features=256, out_features=n_out))

# init the layers
for module in self.modules():
if isinstance(module, nn.Conv2d):
nn.init.constant_(module.bias, 0.0)
nn.init.xavier_normal_(module.weight)

if isinstance(module, nn.Linear):
nn.init.constant_(module.bias, 0.0)
nn.init.xavier_uniform_(module.weight)
3 changes: 2 additions & 1 deletion examples/requirements.txt
@@ -1,4 +1,5 @@
botorch==0.8.2
gpytorch==1.9.1
tqdm
netcal==1.1.3
netcal==1.3.5
deepobs==1.1.2
10 changes: 5 additions & 5 deletions laplace/__init__.py
@@ -7,17 +7,17 @@
REGRESSION = 'regression'
CLASSIFICATION = 'classification'

from laplace.baselaplace import BaseLaplace, ParametricLaplace, FullLaplace, KronLaplace, DiagLaplace, LowRankLaplace
from laplace.lllaplace import LLLaplace, FullLLLaplace, KronLLLaplace, DiagLLLaplace
from laplace.baselaplace import BaseLaplace, ParametricLaplace, FullLaplace, KronLaplace, DiagLaplace, LowRankLaplace, FunctionalLaplace
from laplace.lllaplace import LLLaplace, FullLLLaplace, KronLLLaplace, DiagLLLaplace, FunctionalLLLaplace
from laplace.subnetlaplace import SubnetLaplace, FullSubnetLaplace, DiagSubnetLaplace
from laplace.laplace import Laplace
from laplace.marglik_training import marglik_training

__all__ = ['Laplace', # direct access to all Laplace classes via unified interface
'BaseLaplace', 'ParametricLaplace', # base-class and its (first-level) subclasses
'BaseLaplace', 'ParametricLaplace', 'FunctionalLaplace', # base-class and its (first-level) subclasses
'FullLaplace', 'KronLaplace', 'DiagLaplace', 'LowRankLaplace', # all-weights
'LLLaplace', # base-class last-layer
'FullLLLaplace', 'KronLLLaplace', 'DiagLLLaplace', # last-layer
'SubnetLaplace', # base-class subnetwork
'FullLLLaplace', 'KronLLLaplace', 'DiagLLLaplace', 'FunctionalLLLaplace', # last-layer
'SubnetLaplace', # subnetwork
'FullSubnetLaplace', 'DiagSubnetLaplace', # subnetwork
'marglik_training'] # methods