Skip to content

Commit

Permalink
review of unittests
Browse files Browse the repository at this point in the history
  • Loading branch information
thibaultdvx committed Sep 19, 2024
1 parent 068e7ed commit 00b7127
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 34 deletions.
122 changes: 94 additions & 28 deletions tests/unittests/losses/test_config.py
Original file line number Diff line number Diff line change
@@ -1,35 +1,101 @@
import pytest
import torch
from pydantic import ValidationError

from clinicadl.losses import LossConfig
from clinicadl.losses import ImplementedLoss
from clinicadl.losses.config import (
BCEConfig,
BCEWithLogitsConfig,
CrossEntropyConfig,
HuberConfig,
KLDivConfig,
L1Config,
MSEConfig,
MultiMarginConfig,
NLLConfig,
SmoothL1Config,
create_loss_config,
)


def test_LossConfig():
config = LossConfig(
loss="SmoothL1Loss", margin=10.0, delta=2.0, reduction="none", weight=None
@pytest.mark.parametrize(
"config,args",
[
(L1Config, {"reduction": "none"}),
(MSEConfig, {"reduction": "none"}),
(CrossEntropyConfig, {"reduction": "none"}),
(CrossEntropyConfig, {"weight": [1, -1, 2]}),
(CrossEntropyConfig, {"ignore_index": -1}),
(CrossEntropyConfig, {"label_smoothing": 1.1}),
(NLLConfig, {"reduction": "none"}),
(NLLConfig, {"weight": [1, -1, 2]}),
(NLLConfig, {"ignore_index": -1}),
(KLDivConfig, {"reduction": "none"}),
(BCEConfig, {"reduction": "none"}),
(BCEConfig, {"weight": [0, 1]}),
(BCEWithLogitsConfig, {"reduction": "none"}),
(BCEWithLogitsConfig, {"weight": [0, 1]}),
(BCEWithLogitsConfig, {"pos_weight": [[1, -1, 2]]}),
(BCEWithLogitsConfig, {"pos_weight": [["a", "b"]]}),
(HuberConfig, {"reduction": "none"}),
(HuberConfig, {"delta": 0.0}),
(SmoothL1Config, {"reduction": "none"}),
(SmoothL1Config, {"beta": -1.0}),
(MultiMarginConfig, {"reduction": "none"}),
(MultiMarginConfig, {"p": 3}),
(MultiMarginConfig, {"weight": [1, -1, 2]}),
],
)
def test_validation_fail(config, args):
with pytest.raises((ValidationError, ValueError)):
config(**args)


@pytest.mark.parametrize(
"config,args",
[
(L1Config, {"reduction": "mean"}),
(MSEConfig, {"reduction": "mean"}),
(
CrossEntropyConfig,
{
"reduction": "mean",
"weight": [1, 0, 2],
"ignore_index": 1,
"label_smoothing": 0.5,
},
),
(NLLConfig, {"reduction": "mean", "weight": [1, 0, 2], "ignore_index": 1}),
(KLDivConfig, {"reduction": "mean", "log_target": True}),
(BCEConfig, {"reduction": "sum", "weight": None}),
(
BCEWithLogitsConfig,
{"reduction": "sum", "weight": None, "pos_weight": [[1, 0, 2]]},
),
(HuberConfig, {"reduction": "sum", "delta": 0.1}),
(SmoothL1Config, {"reduction": "sum", "beta": 0.0}),
(
MultiMarginConfig,
{"reduction": "sum", "p": 1, "margin": -0.1, "weight": [1, 0, 2]},
),
],
)
def test_validation_pass(config, args):
c = config(**args)
for arg, value in args.items():
assert getattr(c, arg) == value


def test_create_loss_config():
for loss in ImplementedLoss:
create_loss_config(loss)

config_class = create_loss_config("Multi Margin")
config = config_class(
margin=0.1,
reduction="sum",
)
assert config.loss == "SmoothL1Loss"
assert config.margin == 10.0
assert config.delta == 2.0
assert config.reduction == "none"
assert isinstance(config, MultiMarginConfig)
assert config.p == "DefaultFromLibrary"

with pytest.raises(ValueError):
LossConfig(loss="abc")
with pytest.raises(ValueError):
LossConfig(weight=[0.1, -0.1, 0.8])
with pytest.raises(ValueError):
LossConfig(p=3)
with pytest.raises(ValueError):
LossConfig(reduction="abc")
with pytest.raises(ValidationError):
LossConfig(label_smoothing=1.1)
with pytest.raises(ValidationError):
LossConfig(ignore_index=-1)
with pytest.raises(ValidationError):
LossConfig(loss="BCEWithLogitsLoss", weight=[1, 2, 3])
with pytest.raises(ValidationError):
LossConfig(loss="BCELoss", weight=[1, 2, 3])

LossConfig(loss="BCELoss")
LossConfig(loss="BCEWithLogitsLoss", weight=None)
assert config.margin == 0.1
assert config.reduction == "sum"
12 changes: 6 additions & 6 deletions tests/unittests/losses/test_factory.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from torch import Tensor
from torch.nn import BCEWithLogitsLoss, MultiMarginLoss

from clinicadl.losses import ImplementedLoss, LossConfig, get_loss_function
from clinicadl.losses import ImplementedLoss, create_loss_config, get_loss_function


def test_get_loss_function():
for loss in [e.value for e in ImplementedLoss]:
config = LossConfig(loss=loss)
get_loss_function(config)
for loss in ImplementedLoss:
config = create_loss_config(loss=loss)()
_ = get_loss_function(config)

config = LossConfig(loss="MultiMarginLoss", reduction="sum", weight=[1, 2, 3], p=2)
config = create_loss_config("Multi Margin")(reduction="sum", weight=[1, 2, 3], p=2)
loss, updated_config = get_loss_function(config)
assert isinstance(loss, MultiMarginLoss)
assert loss.reduction == "sum"
Expand All @@ -23,7 +23,7 @@ def test_get_loss_function():
assert updated_config.margin == 1.0
assert updated_config.weight == [1, 2, 3]

config = LossConfig(loss="BCEWithLogitsLoss", pos_weight=[1, 2, 3])
config = create_loss_config("BCE With Logits")(pos_weight=[1, 2, 3])
loss, updated_config = get_loss_function(config)
assert isinstance(loss, BCEWithLogitsLoss)
assert (loss.pos_weight == Tensor([1, 2, 3])).all()
Expand Down

0 comments on commit 00b7127

Please sign in to comment.