Skip to content

Commit 50e66fa

Browse files
authored
Add support for BCEWithLogitsLoss in DiceCELoss (#6924)
Fixes #6923. ### Description Add support for `BCEWithLogitsLoss` for single-class. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: KumoLiu <yunl@nvidia.com>
1 parent 3b27bb6 commit 50e66fa

File tree

3 files changed

+31
-11
lines changed

3 files changed

+31
-11
lines changed

monai/losses/dice.py

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -614,8 +614,8 @@ class DiceCELoss(_Loss):
614614
"""
615615
Compute both Dice loss and Cross Entropy Loss, and return the weighted sum of these two losses.
616616
The details of Dice loss is shown in ``monai.losses.DiceLoss``.
617-
The details of Cross Entropy Loss is shown in ``torch.nn.CrossEntropyLoss``. In this implementation,
618-
two deprecated parameters ``size_average`` and ``reduce``, and the parameter ``ignore_index`` are
617+
The details of Cross Entropy Loss is shown in ``torch.nn.CrossEntropyLoss`` and ``torch.nn.BCEWithLogitsLoss()``.
618+
In this implementation, two deprecated parameters ``size_average`` and ``reduce``, and the parameter ``ignore_index`` are
619619
not supported.
620620
621621
"""
@@ -646,11 +646,11 @@ def __init__(
646646
to_onehot_y: whether to convert the ``target`` into the one-hot format,
647647
using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
648648
sigmoid: if True, apply a sigmoid function to the prediction, only used by the `DiceLoss`,
649-
don't need to specify activation function for `CrossEntropyLoss`.
649+
don't need to specify activation function for `CrossEntropyLoss` and `BCEWithLogitsLoss`.
650650
softmax: if True, apply a softmax function to the prediction, only used by the `DiceLoss`,
651-
don't need to specify activation function for `CrossEntropyLoss`.
651+
don't need to specify activation function for `CrossEntropyLoss` and `BCEWithLogitsLoss`.
652652
other_act: callable function to execute other activation layers, Defaults to ``None``. for example:
653-
``other_act = torch.tanh``. only used by the `DiceLoss`, not for the `CrossEntropyLoss`.
653+
``other_act = torch.tanh``. only used by the `DiceLoss`, not for the `CrossEntropyLoss` and `BCEWithLogitsLoss`.
654654
squared_pred: use squared versions of targets and predictions in the denominator or not.
655655
jaccard: compute Jaccard Index (soft IoU) instead of dice or not.
656656
reduction: {``"mean"``, ``"sum"``}
@@ -666,8 +666,9 @@ def __init__(
666666
batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
667667
Defaults to False, a Dice loss value is computed independently from each item in the batch
668668
before any `reduction`.
669-
ce_weight: a rescaling weight given to each class for cross entropy loss.
670-
See ``torch.nn.CrossEntropyLoss()`` for more information.
669+
ce_weight: a rescaling weight given to each class for cross entropy loss for `CrossEntropyLoss`.
670+
or a rescaling weight given to the loss of each batch element for `BCEWithLogitsLoss`.
671+
See ``torch.nn.CrossEntropyLoss()`` or ``torch.nn.BCEWithLogitsLoss()`` for more information.
671672
lambda_dice: the trade-off weight value for dice loss. The value should be no less than 0.0.
672673
Defaults to 1.0.
673674
lambda_ce: the trade-off weight value for cross entropy loss. The value should be no less than 0.0.
@@ -690,6 +691,7 @@ def __init__(
690691
batch=batch,
691692
)
692693
self.cross_entropy = nn.CrossEntropyLoss(weight=ce_weight, reduction=reduction)
694+
self.binary_cross_entropy = nn.BCEWithLogitsLoss(weight=ce_weight, reduction=reduction)
693695
if lambda_dice < 0.0:
694696
raise ValueError("lambda_dice should be no less than 0.0.")
695697
if lambda_ce < 0.0:
@@ -700,7 +702,7 @@ def __init__(
700702

701703
def ce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
702704
"""
703-
Compute CrossEntropy loss for the input and target.
705+
Compute CrossEntropy loss for the input logits and target.
704706
Will remove the channel dim according to PyTorch CrossEntropyLoss:
705707
https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?#torch.nn.CrossEntropyLoss.
706708
@@ -720,6 +722,16 @@ def ce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
720722

721723
return self.cross_entropy(input, target) # type: ignore[no-any-return]
722724

725+
def bce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
726+
"""
727+
Compute Binary CrossEntropy loss for the input logits and target in one single class.
728+
729+
"""
730+
if not torch.is_floating_point(target):
731+
target = target.to(dtype=input.dtype)
732+
733+
return self.binary_cross_entropy(input, target) # type: ignore[no-any-return]
734+
723735
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
724736
"""
725737
Args:
@@ -738,7 +750,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
738750
)
739751

740752
dice_loss = self.dice(input, target)
741-
ce_loss = self.ce(input, target)
753+
ce_loss = self.ce(input, target) if input.shape[1] != 1 else self.bce(input, target)
742754
total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_ce * ce_loss
743755

744756
return total_loss

tests/test_dice_ce_loss.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,14 @@
7575
},
7676
0.3133,
7777
],
78+
[ # shape: (2, 1, 3), (2, 1, 3), bceloss
79+
{"ce_weight": torch.tensor([1.0, 1.0, 1.0]), "sigmoid": True},
80+
{
81+
"input": torch.tensor([[[0.8, 0.6, 0.0]], [[0.0, 0.0, 0.9]]]),
82+
"target": torch.tensor([[[0.0, 0.0, 1.0]], [[0.0, 1.0, 0.0]]]),
83+
},
84+
1.5608,
85+
],
7886
]
7987

8088

@@ -97,7 +105,7 @@ def test_ill_reduction(self):
97105

98106
def test_script(self):
99107
loss = DiceCELoss()
100-
test_input = torch.ones(2, 1, 8, 8)
108+
test_input = torch.ones(2, 2, 8, 8)
101109
test_script_save(loss, test_input, test_input)
102110

103111

tests/test_ds_loss.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -154,7 +154,7 @@ def test_ill_reduction(self):
154154
@SkipIfBeforePyTorchVersion((1, 10))
155155
def test_script(self):
156156
loss = DeepSupervisionLoss(DiceCELoss())
157-
test_input = torch.ones(2, 1, 8, 8)
157+
test_input = torch.ones(2, 2, 8, 8)
158158
test_script_save(loss, test_input, test_input)
159159

160160

0 commit comments

Comments
 (0)