From 7bfb46ccee1deb0edd78950685c3b44abedfecf7 Mon Sep 17 00:00:00 2001
From: ytl0623
Date: Fri, 9 Aug 2024 14:35:57 +0800
Subject: [PATCH] Add label smoothing param in DiceCELoss (#8000)

Fixes #7957

### Description

In this PR I made the following changes:

1. Added a `label_smoothing: float = 0.0` parameter to the `__init__` method, defaulting to 0.0.
2. When creating the `self.cross_entropy` instance, the `label_smoothing` parameter is passed to `nn.CrossEntropyLoss` (only on PyTorch 1.10 or later, where the argument is supported).
3. Added `self.label_smoothing = label_smoothing` in the `__init__` method to save this parameter so it can be accessed when needed.

For example:

```
from monai.losses import DiceCELoss

# Before
criterion = DiceCELoss()
criterion.cross_entropy.label_smoothing = 0.1

# Now
criterion = DiceCELoss(label_smoothing=0.1)
```

A fuller usage sketch is appended after the diff below.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: ytl0623
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
---
 monai/losses/dice.py | 11 ++++++++++-
 1 file changed, 10 insertions(+), 1 deletion(-)

diff --git a/monai/losses/dice.py b/monai/losses/dice.py
index 07a38d9572..44cde41e5d 100644
--- a/monai/losses/dice.py
+++ b/monai/losses/dice.py
@@ -666,6 +666,7 @@ def __init__(
         weight: torch.Tensor | None = None,
         lambda_dice: float = 1.0,
         lambda_ce: float = 1.0,
+        label_smoothing: float = 0.0,
     ) -> None:
         """
         Args:
@@ -704,6 +705,9 @@ def __init__(
                 Defaults to 1.0.
             lambda_ce: the trade-off weight value for cross entropy loss. The value should be no less than 0.0.
                 Defaults to 1.0.
+            label_smoothing: a value in [0, 1] range. If > 0, the labels are smoothed
+                by the given factor to reduce overfitting.
+                Defaults to 0.0.
 
         """
         super().__init__()
@@ -728,7 +732,12 @@ def __init__(
             batch=batch,
             weight=dice_weight,
         )
-        self.cross_entropy = nn.CrossEntropyLoss(weight=weight, reduction=reduction)
+        if pytorch_after(1, 10):
+            self.cross_entropy = nn.CrossEntropyLoss(
+                weight=weight, reduction=reduction, label_smoothing=label_smoothing
+            )
+        else:
+            self.cross_entropy = nn.CrossEntropyLoss(weight=weight, reduction=reduction)
         self.binary_cross_entropy = nn.BCEWithLogitsLoss(pos_weight=weight, reduction=reduction)
         if lambda_dice < 0.0:
             raise ValueError("lambda_dice should be no less than 0.0.")
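For completeness, here is a minimal end-to-end sketch of the new parameter. It assumes a MONAI build that includes this patch and PyTorch 1.10 or later (on older versions the patched code falls back to an unsmoothed `nn.CrossEntropyLoss`); the tensor shapes and the smoothing value 0.1 are illustrative only and not part of the patch.

```
import torch
from monai.losses import DiceCELoss

# Combined Dice + cross-entropy loss with label smoothing on the CE term.
# label_smoothing=0.1 is an arbitrary example value in the [0, 1] range.
criterion = DiceCELoss(to_onehot_y=True, softmax=True, label_smoothing=0.1)

pred = torch.randn(2, 3, 32, 32)              # logits: batch of 2, 3 classes, 32x32
target = torch.randint(0, 3, (2, 1, 32, 32))  # integer class labels per pixel

loss = criterion(pred, target)                # scalar: lambda_dice * Dice + lambda_ce * CE
print(float(loss))
```

Because the smoothing is applied inside `nn.CrossEntropyLoss`, only the cross-entropy term of the combined loss is affected; the Dice term still uses the original targets.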