@@ -614,8 +614,8 @@ class DiceCELoss(_Loss):
     """
     Compute both Dice loss and Cross Entropy Loss, and return the weighted sum of these two losses.
     The details of Dice loss is shown in ``monai.losses.DiceLoss``.
-    The details of Cross Entropy Loss is shown in ``torch.nn.CrossEntropyLoss``. In this implementation,
-    two deprecated parameters ``size_average`` and ``reduce``, and the parameter ``ignore_index`` are
+    The details of Cross Entropy Loss are shown in ``torch.nn.CrossEntropyLoss`` and ``torch.nn.BCEWithLogitsLoss``.
+    In this implementation, two deprecated parameters ``size_average`` and ``reduce``, and the parameter ``ignore_index`` are
     not supported.

     """
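For context on what the weighted sum looks like from the caller's side, here is a minimal sketch (not part of the commit; parameter names are taken from the signature below, shapes are illustrative):

    import torch
    from monai.losses import DiceCELoss

    # Multi-class segmentation: logits with C=3 channels, target as a class-index map.
    loss_fn = DiceCELoss(to_onehot_y=True, softmax=True, lambda_dice=1.0, lambda_ce=1.0)
    pred = torch.randn(4, 3, 32, 32)              # (batch, classes, H, W) raw logits
    target = torch.randint(0, 3, (4, 1, 32, 32))  # class indices with a channel dim
    loss = loss_fn(pred, target)                  # lambda_dice * dice + lambda_ce * ce
    print(loss)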
@@ -646,11 +646,11 @@ def __init__(
             to_onehot_y: whether to convert the ``target`` into the one-hot format,
                 using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
             sigmoid: if True, apply a sigmoid function to the prediction, only used by the `DiceLoss`,
-                don't need to specify activation function for `CrossEntropyLoss`.
+                don't need to specify activation function for `CrossEntropyLoss` or `BCEWithLogitsLoss`.
             softmax: if True, apply a softmax function to the prediction, only used by the `DiceLoss`,
-                don't need to specify activation function for `CrossEntropyLoss`.
+                don't need to specify activation function for `CrossEntropyLoss` or `BCEWithLogitsLoss`.
             other_act: callable function to execute other activation layers, Defaults to ``None``. for example:
-                ``other_act = torch.tanh``. only used by the `DiceLoss`, not for the `CrossEntropyLoss`.
+                ``other_act = torch.tanh``. only used by the `DiceLoss`, not for `CrossEntropyLoss` or `BCEWithLogitsLoss`.
             squared_pred: use squared versions of targets and predictions in the denominator or not.
             jaccard: compute Jaccard Index (soft IoU) instead of dice or not.
             reduction: {``"mean"``, ``"sum"``}
@@ -666,8 +666,9 @@ def __init__(
             batch: whether to sum the intersection and union areas over the batch dimension before the dividing.
                 Defaults to False, a Dice loss value is computed independently from each item in the batch
                 before any `reduction`.
-            ce_weight: a rescaling weight given to each class for cross entropy loss.
-                See ``torch.nn.CrossEntropyLoss()`` for more information.
+            ce_weight: a rescaling weight given to each class for `CrossEntropyLoss`,
+                or a rescaling weight given to the loss of each batch element for `BCEWithLogitsLoss`.
+                See ``torch.nn.CrossEntropyLoss()`` or ``torch.nn.BCEWithLogitsLoss()`` for more information.
             lambda_dice: the trade-off weight value for dice loss. The value should be no less than 0.0.
                 Defaults to 1.0.
             lambda_ce: the trade-off weight value for cross entropy loss. The value should be no less than 0.0.
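The two `weight` semantics described in the updated ce_weight note can be seen directly in plain PyTorch; a short illustration (shapes are assumptions for the example, not from the commit):

    import torch
    import torch.nn as nn

    # CrossEntropyLoss: `weight` is per-class, one entry per channel (C = 3 here).
    ce = nn.CrossEntropyLoss(weight=torch.tensor([1.0, 2.0, 0.5]))
    print(ce(torch.randn(4, 3), torch.randint(0, 3, (4,))))

    # BCEWithLogitsLoss: `weight` rescales the loss element-wise and must
    # broadcast against the input shape rather than index classes.
    bce = nn.BCEWithLogitsLoss(weight=torch.tensor([1.0, 2.0, 0.5, 1.0]))
    print(bce(torch.randn(4), torch.randint(0, 2, (4,)).float()))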
@@ -690,6 +691,7 @@ def __init__(
             batch=batch,
         )
         self.cross_entropy = nn.CrossEntropyLoss(weight=ce_weight, reduction=reduction)
+        self.binary_cross_entropy = nn.BCEWithLogitsLoss(weight=ce_weight, reduction=reduction)
         if lambda_dice < 0.0:
             raise ValueError("lambda_dice should be no less than 0.0.")
         if lambda_ce < 0.0:
@@ -700,7 +702,7 @@ def __init__(

     def ce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
-        Compute CrossEntropy loss for the input and target.
+        Compute CrossEntropy loss for the input logits and target.
         Will remove the channel dim according to PyTorch CrossEntropyLoss:
         https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html?#torch.nn.CrossEntropyLoss.

@@ -720,6 +722,16 @@ def ce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:

         return self.cross_entropy(input, target)  # type: ignore[no-any-return]

+    def bce(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+        """
+        Compute Binary CrossEntropy loss for the single-channel input logits and target.
+
+        """
+        if not torch.is_floating_point(target):
+            target = target.to(dtype=input.dtype)
+
+        return self.binary_cross_entropy(input, target)  # type: ignore[no-any-return]
+
     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
         Args:
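The dtype guard in the new `bce` helper exists because ``torch.nn.BCEWithLogitsLoss`` requires floating-point targets; passing integer labels raises a dtype error. A standalone check of that behavior (not from the commit):

    import torch
    import torch.nn as nn

    bce = nn.BCEWithLogitsLoss()
    logits = torch.randn(2, 1, 8, 8)
    labels = torch.randint(0, 2, (2, 1, 8, 8))  # integer labels, as segmentation targets often are

    # bce(logits, labels) would raise, since the target must be floating point;
    # the helper's `target.to(dtype=input.dtype)` performs exactly this cast.
    print(bce(logits, labels.to(dtype=logits.dtype)))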
@@ -738,7 +750,7 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
             )

         dice_loss = self.dice(input, target)
-        ce_loss = self.ce(input, target)
+        ce_loss = self.ce(input, target) if input.shape[1] != 1 else self.bce(input, target)
         total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_ce * ce_loss

         return total_loss
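With this change the caller never picks between CE and BCE explicitly; the channel count of the logits decides. A sketch of both paths through the patched class (shapes are illustrative, not from the commit):

    import torch
    from monai.losses import DiceCELoss

    # Single-channel logits: forward takes the bce() / BCEWithLogitsLoss path.
    binary_loss = DiceCELoss(sigmoid=True)
    print(binary_loss(torch.randn(2, 1, 16, 16), torch.randint(0, 2, (2, 1, 16, 16))))

    # Multi-channel logits: forward takes the ce() / CrossEntropyLoss path.
    multi_loss = DiceCELoss(to_onehot_y=True, softmax=True)
    print(multi_loss(torch.randn(2, 3, 16, 16), torch.randint(0, 3, (2, 1, 16, 16))))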