Skip to content

Commit

Permalink
✨ Improve coverage and code quality
Browse files Browse the repository at this point in the history
  • Loading branch information
o-laurent committed Oct 15, 2023
1 parent 79e54d3 commit 068929a
Show file tree
Hide file tree
Showing 15 changed files with 169 additions and 26 deletions.
4 changes: 2 additions & 2 deletions tests/_dummies/datamodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,15 @@
# fmt: on
class DummyClassificationDataModule(LightningDataModule):
num_channels = 1
image_size: int = 8
image_size: int = 4
training_task = "classification"

def __init__(
self,
root: Union[str, Path],
ood_detection: bool,
batch_size: int,
num_classes: int = 10,
num_classes: int = 2,
num_workers: int = 1,
pin_memory: bool = True,
persistent_workers: bool = True,
Expand Down
25 changes: 25 additions & 0 deletions tests/baselines/test_deep_ensembles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# fmt:off
from argparse import ArgumentParser

from torch_uncertainty.baselines import DeepEnsembles


# fmt:on
class TestDeepEnsembles:
"""Testing the Deep Ensembles baseline class."""

def test_standard(self):
DeepEnsembles(
task="classification",
log_path=".",
checkpoint_ids=[],
backbone="resnet",
in_channels=3,
num_classes=10,
version="vanilla",
arch=18,
style="cifar",
groups=1,
)
parser = ArgumentParser()
DeepEnsembles.add_model_specific_args(parser)
4 changes: 3 additions & 1 deletion tests/datamodules/test_uci_regression_datamodule.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# fmt:off
from argparse import ArgumentParser
from functools import partial

from torch_uncertainty.datamodules import UCIDataModule

Expand All @@ -18,8 +19,9 @@ def test_UCIRegression(self):

dm = UCIDataModule(dataset_name="kin8nm", **vars(args))

dm.dataset = DummyRegressionDataset
dm.dataset = partial(DummyRegressionDataset, num_samples=64)
dm.prepare_data()
dm.val_split = 0.5
dm.setup()
dm.setup("test")

Expand Down
4 changes: 4 additions & 0 deletions tests/layers/test_packed_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ def test_conv_one_estimator(self, seq_input: torch.Tensor):
layer = PackedConv1d(6, 2, alpha=1, num_estimators=1, kernel_size=1)
out = layer(seq_input)
assert out.shape == torch.Size([5, 2, 3])
assert layer.weight.shape == torch.Size([2, 6, 1])
assert layer.bias.shape == torch.Size([2])

def test_conv_two_estimators(self, seq_input: torch.Tensor):
layer = PackedConv1d(6, 2, alpha=1, num_estimators=2, kernel_size=1)
Expand Down Expand Up @@ -210,6 +212,8 @@ def test_conv_one_estimator(self, voxels_input: torch.Tensor):
layer = PackedConv3d(6, 2, alpha=1, num_estimators=1, kernel_size=1)
out = layer(voxels_input)
assert out.shape == torch.Size([5, 2, 3, 3, 3])
assert layer.weight.shape == torch.Size([2, 6, 1, 1, 1])
assert layer.bias.shape == torch.Size([2])

def test_conv_two_estimators(self, voxels_input: torch.Tensor):
layer = PackedConv3d(6, 2, alpha=1, num_estimators=2, kernel_size=1)
Expand Down
2 changes: 1 addition & 1 deletion tests/metrics/test_brier_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def vec2D_5classes_target1D() -> torch.Tensor:
@pytest.fixture
def vec3D() -> torch.Tensor:
"""
Return a torch tensor with a mean BrierScore of 0 and an BrierScore of
Return a torch tensor with a mean BrierScore of 0 and a BrierScore of
the mean of 0.5 to test the `ensemble` parameter of `BrierScore`.
"""
vec = torch.as_tensor([[0.0, 1.0], [1.0, 0.0]])
Expand Down
76 changes: 74 additions & 2 deletions tests/routines/test_classification.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
# fmt:off
from functools import partial
from pathlib import Path

import pytest
from cli_test_helpers import ArgvContext
from torch import nn

from torch_uncertainty import cli_main, init_args
from torch_uncertainty.losses import ELBOLoss
from torch_uncertainty.optimization_procedures import optim_cifar10_resnet18
from torch_uncertainty.routines.classification import (
ClassificationEnsemble,
Expand Down Expand Up @@ -46,6 +48,31 @@ def test_cli_main_dummy_binary(self):

def test_cli_main_dummy_ood(self):
root = Path(__file__).parent.absolute().parents[0]
with ArgvContext("file.py"):
args = init_args(
DummyClassificationBaseline, DummyClassificationDataModule
)

# datamodule
args.root = str(root / "data")
dm = DummyClassificationDataModule(**vars(args))
loss = partial(
ELBOLoss,
criterion=nn.CrossEntropyLoss(),
kl_weight=1e-5,
num_samples=2,
)
model = DummyClassificationBaseline(
num_classes=dm.num_classes,
in_channels=dm.num_channels,
loss=loss,
optimization_procedure=optim_cifar10_resnet18,
baseline_type="single",
**vars(args),
)

cli_main(model, dm, root, "dummy", args)

with ArgvContext("file.py", "--evaluate_ood", "--entropy"):
args = init_args(
DummyClassificationBaseline, DummyClassificationDataModule
Expand Down Expand Up @@ -89,11 +116,16 @@ def test_cli_main_dummy_binary(self):
# datamodule
args.root = str(root / "data")
dm = DummyClassificationDataModule(num_classes=1, **vars(args))

loss = partial(
ELBOLoss,
criterion=nn.CrossEntropyLoss(),
kl_weight=1e-5,
num_samples=1,
)
model = DummyClassificationBaseline(
num_classes=dm.num_classes,
in_channels=dm.num_channels,
loss=nn.BCEWithLogitsLoss,
loss=loss,
optimization_procedure=optim_cifar10_resnet18,
baseline_type="ensemble",
**vars(args),
Expand Down Expand Up @@ -123,6 +155,46 @@ def test_cli_main_dummy_binary(self):

def test_cli_main_dummy_ood(self):
root = Path(__file__).parent.absolute().parents[0]
with ArgvContext("file.py", "--logits"):
args = init_args(
DummyClassificationBaseline, DummyClassificationDataModule
)

# datamodule
args.root = str(root / "data")
dm = DummyClassificationDataModule(**vars(args))

model = DummyClassificationBaseline(
num_classes=dm.num_classes,
in_channels=dm.num_channels,
loss=nn.CrossEntropyLoss,
optimization_procedure=optim_cifar10_resnet18,
baseline_type="ensemble",
**vars(args),
)

cli_main(model, dm, root, "dummy", args)

with ArgvContext("file.py", "--evaluate_ood", "--entropy"):
args = init_args(
DummyClassificationBaseline, DummyClassificationDataModule
)

# datamodule
args.root = str(root / "data")
dm = DummyClassificationDataModule(**vars(args))

model = DummyClassificationBaseline(
num_classes=dm.num_classes,
in_channels=dm.num_channels,
loss=nn.CrossEntropyLoss,
optimization_procedure=optim_cifar10_resnet18,
baseline_type="ensemble",
**vars(args),
)

cli_main(model, dm, root, "dummy", args)

with ArgvContext("file.py", "--evaluate_ood", "--variation_ratio"):
args = init_args(
DummyClassificationBaseline, DummyClassificationDataModule
Expand Down
20 changes: 20 additions & 0 deletions tests/test_losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,23 @@ def test_main(self):
assert loss(*inputs.split(1, dim=-1), targets) == pytest.approx(
2 * math.log(2)
)

loss = NIGLoss(
reg_weight=1e-2,
reduction="sum",
)

assert loss(
*inputs.repeat(2, 1).split(1, dim=-1),
targets.repeat(2, 1),
) == pytest.approx(4 * math.log(2))

loss = NIGLoss(
reg_weight=1e-2,
reduction="none",
)

assert loss(
*inputs.repeat(2, 1).split(1, dim=-1),
targets.repeat(2, 1),
) == pytest.approx([2 * math.log(2), 2 * math.log(2)])
4 changes: 2 additions & 2 deletions torch_uncertainty/baselines/deep_ensembles.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def __new__(
cls,
task: Literal["classification", "regression"],
log_path: Union[str, Path],
versions: List[int],
checkpoint_ids: List[int],
backbone: Literal["mlp", "resnet", "vgg", "wideresnet"],
# num_estimators: int,
in_channels: Optional[int] = None,
Expand All @@ -43,7 +43,7 @@ def __new__(
backbone_cls = cls.backbones[backbone]

models = []
for version in versions:
for version in checkpoint_ids:
ckpt_file, hparams_file = get_version(
root=log_path, version=version
)
Expand Down
4 changes: 2 additions & 2 deletions torch_uncertainty/datamodules/cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,8 +153,8 @@ def setup(self, stage: Optional[str] = None) -> None:
self.train, self.val = random_split(
full,
[
int(len(full) * (1 - self.val_split)),
len(full) - int(len(full) * (1 - self.val_split)),
1 - self.val_split,
self.val_split,
],
)
if self.val_split == 0:
Expand Down
4 changes: 2 additions & 2 deletions torch_uncertainty/datamodules/cifar100.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,8 +154,8 @@ def setup(self, stage: Optional[str] = None) -> None:
self.train, self.val = random_split(
full,
[
int(len(full) * (1 - self.val_split)),
len(full) - int(len(full) * (1 - self.val_split)),
1 - self.val_split,
self.val_split,
],
)
if self.val_split == 0:
Expand Down
4 changes: 2 additions & 2 deletions torch_uncertainty/datamodules/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,8 +121,8 @@ def setup(self, stage: Optional[str] = None) -> None:
self.train, self.val = random_split(
full,
[
int(len(full) * (1 - self.val_split)),
len(full) - int(len(full) * (1 - self.val_split)),
1 - self.val_split,
self.val_split,
],
)
if self.val_split == 0:
Expand Down
8 changes: 3 additions & 5 deletions torch_uncertainty/datamodules/uci_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,11 +78,9 @@ def setup(self, stage: Optional[str] = None) -> None:
self.train, self.test, self.val = random_split(
full,
[
int(len(full) * (0.8 - self.val_split)),
int(len(full) * 0.2),
len(full)
- int(len(full) * 0.2)
- int(len(full) * (0.8 - self.val_split)),
0.8 - self.val_split,
0.2,
self.val_split,
],
generator=self.gen,
)
Expand Down
28 changes: 25 additions & 3 deletions torch_uncertainty/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,12 @@ def __init__(
super().__init__()
self.model = model
self._kl_div = KLDiv(model)

if isinstance(criterion, type):
raise ValueError(
"The criterion should be an instance of a class."
f"Got {criterion}."
)
self.criterion = criterion

if kl_weight < 0:
Expand Down Expand Up @@ -123,7 +129,14 @@ def __init__(
raise ValueError(f"{reduction} is not a valid value for reduction.")
self.reduction = reduction

def _nig_nll(self, gamma, v, alpha, beta, targets):
def _nig_nll(
self,
gamma: Tensor,
v: Tensor,
alpha: Tensor,
beta: Tensor,
targets: Tensor,
) -> Tensor:
Gamma = 2 * beta * (1 + v)
nll = (
0.5 * torch.log(torch.pi / v)
Expand All @@ -134,13 +147,22 @@ def _nig_nll(self, gamma, v, alpha, beta, targets):
)
return nll

def _nig_reg(self, gamma, v, alpha, targets):
def _nig_reg(
self, gamma: Tensor, v: Tensor, alpha: Tensor, targets: Tensor
) -> Tensor:
reg = torch.norm(targets - gamma, 1, dim=1, keepdim=True) * (
2 * v + alpha
)
return reg

def forward(self, gamma, v, alpha, beta, targets):
def forward(
self,
gamma: Tensor,
v: Tensor,
alpha: Tensor,
beta: Tensor,
targets: Tensor,
) -> Tensor:
loss_nll = self._nig_nll(gamma, v, alpha, beta, targets)
loss_reg = self._nig_reg(gamma, v, alpha, targets)
loss = loss_nll + self.reg_weight * loss_reg
Expand Down
4 changes: 2 additions & 2 deletions torch_uncertainty/models/vgg/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,9 +113,9 @@ def _init_weights(self):
)
if m.bias is not None: # coverage: ignore
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
elif isinstance(m, nn.BatchNorm2d): # coverage: ignore
nn.init.constant_(m.weight, 1)
if m.bias is not None: # coverage: ignore
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear) or isinstance(m, PackedLinear):
nn.init.normal_(m.weight, 0, 0.01)
Expand Down
4 changes: 2 additions & 2 deletions torch_uncertainty/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def __init__(
self.fill = fill

def forward(
self, img: Union[Tensor, Image.Image], level: float
self, img: Union[Tensor, Image.Image], level: int
) -> Union[Tensor, Image.Image]:
if (
self.random_direction and np.random.uniform() > 0.5
Expand Down Expand Up @@ -213,7 +213,7 @@ def __init__(self):
super().__init__()

def forward(
self, img: Union[Tensor, Image.Image], level=float
self, img: Union[Tensor, Image.Image], level: float
) -> Union[Tensor, Image.Image]:
if level < 0:
raise ValueError("Level must be greater than 0.")
Expand Down

0 comments on commit 068929a

Please sign in to comment.