Add label smoothing param in DiceCELoss (Project-MONAI#8000)
Fixes Project-MONAI#7957 

### Description
This change makes the following modifications:
1. Added a `label_smoothing: float = 0.0` parameter to the `__init__` method
(defaults to 0.0).
2. Passed the `label_smoothing` parameter through to `nn.CrossEntropyLoss`
when creating the `self.cross_entropy` instance.
3. Stored the value as `self.label_smoothing` in `__init__` so it remains
accessible when needed.

For example:
```python
from monai.losses import DiceCELoss

# Before
criterion = DiceCELoss()
criterion.cross_entropy.label_smoothing = 0.1

# Now
criterion = DiceCELoss(label_smoothing=0.1)
```
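For intuition: label smoothing replaces the one-hot target with a mixture of the one-hot vector and a uniform distribution — with smoothing factor ε and C classes, the true class gets weight 1 − ε + ε/C and every other class gets ε/C. A minimal pure-Python sketch (illustrative only, not MONAI code) of the per-sample loss that `nn.CrossEntropyLoss(label_smoothing=eps)` computes:

```python
import math

def smoothed_cross_entropy(logits, target, label_smoothing=0.0):
    """Cross entropy with label smoothing for a single sample (sketch)."""
    n = len(logits)
    # numerically stable log-softmax
    m = max(logits)
    log_sum = m + math.log(sum(math.exp(x - m) for x in logits))
    log_probs = [x - log_sum for x in logits]
    # smoothed target: (1 - eps) * one_hot + eps/n uniform mass
    eps = label_smoothing
    loss = 0.0
    for i, lp in enumerate(log_probs):
        w = (1.0 - eps) + eps / n if i == target else eps / n
        loss -= w * lp
    return loss
```

With `label_smoothing=0.0` this reduces to ordinary cross entropy; with a positive value, the loss on a confidently correct prediction rises slightly, discouraging overconfident logits.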

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

---------

Signed-off-by: ytl0623 <david89062388@gmail.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
2 people authored and rcremese committed Sep 2, 2024
1 parent 231c585 commit 7bfb46c
Showing 1 changed file with 10 additions and 1 deletion.
11 changes: 10 additions & 1 deletion monai/losses/dice.py
```diff
@@ -666,6 +666,7 @@ def __init__(
         weight: torch.Tensor | None = None,
         lambda_dice: float = 1.0,
         lambda_ce: float = 1.0,
+        label_smoothing: float = 0.0,
     ) -> None:
         """
         Args:
@@ -704,6 +705,9 @@ def __init__(
                 Defaults to 1.0.
             lambda_ce: the trade-off weight value for cross entropy loss. The value should be no less than 0.0.
                 Defaults to 1.0.
+            label_smoothing: a value in [0, 1] range. If > 0, the labels are smoothed
+                by the given factor to reduce overfitting.
+                Defaults to 0.0.
         """
         super().__init__()
@@ -728,7 +732,12 @@ def __init__(
             batch=batch,
             weight=dice_weight,
         )
-        self.cross_entropy = nn.CrossEntropyLoss(weight=weight, reduction=reduction)
+        if pytorch_after(1, 10):
+            self.cross_entropy = nn.CrossEntropyLoss(
+                weight=weight, reduction=reduction, label_smoothing=label_smoothing
+            )
+        else:
+            self.cross_entropy = nn.CrossEntropyLoss(weight=weight, reduction=reduction)
         self.binary_cross_entropy = nn.BCEWithLogitsLoss(pos_weight=weight, reduction=reduction)
         if lambda_dice < 0.0:
             raise ValueError("lambda_dice should be no less than 0.0.")
```
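The `pytorch_after(1, 10)` guard exists because `nn.CrossEntropyLoss` only gained the `label_smoothing` argument in PyTorch 1.10; on older versions the loss is constructed without it, so the smoothing value is silently ignored rather than raising a `TypeError`. MONAI's real helper lives in `monai.utils`; the function below is an illustrative sketch of such a version check (its name and parsing are assumptions, not MONAI's implementation):

```python
import re

def _leading_int(part):
    # extract leading digits, e.g. "10rc1" -> 10; default to 0
    m = re.match(r"\d+", part)
    return int(m.group()) if m else 0

def version_at_least(major, minor, version_str):
    """Sketch: True if version_str is at least major.minor.

    Ignores local build suffixes such as "+cu118" and pre-release tags.
    """
    core = version_str.split("+")[0]
    parts = core.split(".")
    v_major = _leading_int(parts[0]) if parts else 0
    v_minor = _leading_int(parts[1]) if len(parts) > 1 else 0
    return (v_major, v_minor) >= (major, minor)
```

Falling back to the no-argument constructor keeps the change non-breaking for users pinned to older PyTorch releases, at the cost of the smoothing parameter having no effect there.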
