diff --git a/tests/unittests/losses/test_config.py b/tests/unittests/losses/test_config.py index cb00adbf3..49372baf1 100644 --- a/tests/unittests/losses/test_config.py +++ b/tests/unittests/losses/test_config.py @@ -1,35 +1,101 @@ import pytest +import torch from pydantic import ValidationError -from clinicadl.losses import LossConfig +from clinicadl.losses import ImplementedLoss +from clinicadl.losses.config import ( + BCEConfig, + BCEWithLogitsConfig, + CrossEntropyConfig, + HuberConfig, + KLDivConfig, + L1Config, + MSEConfig, + MultiMarginConfig, + NLLConfig, + SmoothL1Config, + create_loss_config, +) -def test_LossConfig(): - config = LossConfig( - loss="SmoothL1Loss", margin=10.0, delta=2.0, reduction="none", weight=None +@pytest.mark.parametrize( + "config,args", + [ + (L1Config, {"reduction": "none"}), + (MSEConfig, {"reduction": "none"}), + (CrossEntropyConfig, {"reduction": "none"}), + (CrossEntropyConfig, {"weight": [1, -1, 2]}), + (CrossEntropyConfig, {"ignore_index": -1}), + (CrossEntropyConfig, {"label_smoothing": 1.1}), + (NLLConfig, {"reduction": "none"}), + (NLLConfig, {"weight": [1, -1, 2]}), + (NLLConfig, {"ignore_index": -1}), + (KLDivConfig, {"reduction": "none"}), + (BCEConfig, {"reduction": "none"}), + (BCEConfig, {"weight": [0, 1]}), + (BCEWithLogitsConfig, {"reduction": "none"}), + (BCEWithLogitsConfig, {"weight": [0, 1]}), + (BCEWithLogitsConfig, {"pos_weight": [[1, -1, 2]]}), + (BCEWithLogitsConfig, {"pos_weight": [["a", "b"]]}), + (HuberConfig, {"reduction": "none"}), + (HuberConfig, {"delta": 0.0}), + (SmoothL1Config, {"reduction": "none"}), + (SmoothL1Config, {"beta": -1.0}), + (MultiMarginConfig, {"reduction": "none"}), + (MultiMarginConfig, {"p": 3}), + (MultiMarginConfig, {"weight": [1, -1, 2]}), + ], +) +def test_validation_fail(config, args): + with pytest.raises((ValidationError, ValueError)): + config(**args) + + +@pytest.mark.parametrize( + "config,args", + [ + (L1Config, {"reduction": "mean"}), + (MSEConfig, {"reduction": "mean"}), + ( + CrossEntropyConfig, + { + "reduction": "mean", + "weight": [1, 0, 2], + "ignore_index": 1, + "label_smoothing": 0.5, + }, + ), + (NLLConfig, {"reduction": "mean", "weight": [1, 0, 2], "ignore_index": 1}), + (KLDivConfig, {"reduction": "mean", "log_target": True}), + (BCEConfig, {"reduction": "sum", "weight": None}), + ( + BCEWithLogitsConfig, + {"reduction": "sum", "weight": None, "pos_weight": [[1, 0, 2]]}, + ), + (HuberConfig, {"reduction": "sum", "delta": 0.1}), + (SmoothL1Config, {"reduction": "sum", "beta": 0.0}), + ( + MultiMarginConfig, + {"reduction": "sum", "p": 1, "margin": -0.1, "weight": [1, 0, 2]}, + ), + ], +) +def test_validation_pass(config, args): + c = config(**args) + for arg, value in args.items(): + assert getattr(c, arg) == value + + +def test_create_loss_config(): + for loss in ImplementedLoss: + create_loss_config(loss) + + config_class = create_loss_config("Multi Margin") + config = config_class( + margin=0.1, + reduction="sum", ) - assert config.loss == "SmoothL1Loss" - assert config.margin == 10.0 - assert config.delta == 2.0 - assert config.reduction == "none" + assert isinstance(config, MultiMarginConfig) assert config.p == "DefaultFromLibrary" - - with pytest.raises(ValueError): - LossConfig(loss="abc") - with pytest.raises(ValueError): - LossConfig(weight=[0.1, -0.1, 0.8]) - with pytest.raises(ValueError): - LossConfig(p=3) - with pytest.raises(ValueError): - LossConfig(reduction="abc") - with pytest.raises(ValidationError): - LossConfig(label_smoothing=1.1) - with pytest.raises(ValidationError): - LossConfig(ignore_index=-1) - with pytest.raises(ValidationError): - LossConfig(loss="BCEWithLogitsLoss", weight=[1, 2, 3]) - with pytest.raises(ValidationError): - LossConfig(loss="BCELoss", weight=[1, 2, 3]) - - LossConfig(loss="BCELoss") - LossConfig(loss="BCEWithLogitsLoss", weight=None) + assert config.margin == 0.1 + assert config.reduction == "sum" diff --git a/tests/unittests/losses/test_factory.py b/tests/unittests/losses/test_factory.py index 5ac786deb..5396bac3d 100644 --- a/tests/unittests/losses/test_factory.py +++ b/tests/unittests/losses/test_factory.py @@ -1,15 +1,15 @@ from torch import Tensor from torch.nn import BCEWithLogitsLoss, MultiMarginLoss -from clinicadl.losses import ImplementedLoss, LossConfig, get_loss_function +from clinicadl.losses import ImplementedLoss, create_loss_config, get_loss_function def test_get_loss_function(): - for loss in [e.value for e in ImplementedLoss]: - config = LossConfig(loss=loss) - get_loss_function(config) + for loss in ImplementedLoss: + config = create_loss_config(loss=loss)() + _ = get_loss_function(config) - config = LossConfig(loss="MultiMarginLoss", reduction="sum", weight=[1, 2, 3], p=2) + config = create_loss_config("Multi Margin")(reduction="sum", weight=[1, 2, 3], p=2) loss, updated_config = get_loss_function(config) assert isinstance(loss, MultiMarginLoss) assert loss.reduction == "sum" @@ -23,7 +23,7 @@ def test_get_loss_function(): assert updated_config.margin == 1.0 assert updated_config.weight == [1, 2, 3] - config = LossConfig(loss="BCEWithLogitsLoss", pos_weight=[1, 2, 3]) + config = create_loss_config("BCE With Logits")(pos_weight=[1, 2, 3]) loss, updated_config = get_loss_function(config) assert isinstance(loss, BCEWithLogitsLoss) assert (loss.pos_weight == Tensor([1, 2, 3])).all()