@@ -60,12 +60,12 @@ def __init__(
             include_background: if False, channel index 0 (background category) is excluded from the calculation.
                 if the non-background segmentations are small compared to the total image size they can get overwhelmed
                 by the signal from the background so excluding it in such cases helps convergence.
-            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
+            to_onehot_y: whether to convert the ``target`` into the one-hot format,
+                using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
             sigmoid: if True, apply a sigmoid function to the prediction.
             softmax: if True, apply a softmax function to the prediction.
-            other_act: if don't want to use `sigmoid` or `softmax`, use other callable function to execute
-                other activation layers, Defaults to ``None``. for example:
-                `other_act = torch.tanh`.
+            other_act: callable function to execute other activation layers, Defaults to ``None``. for example:
+                ``other_act = torch.tanh``.
             squared_pred: use squared versions of targets and predictions in the denominator or not.
             jaccard: compute Jaccard Index (soft IoU) instead of dice or not.
             reduction: {``"none"``, ``"mean"``, ``"sum"``}
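(For context, a minimal usage sketch of the arguments touched above; a hypothetical call assuming the `monai.losses.DiceLoss` constructor documented in this hunk, not part of the change itself:)

    import torch
    from monai.losses import DiceLoss

    # multi-class logits: softmax over channels; target is one-hot encoded
    # internally using the class count from input.shape[1] (here 3)
    loss_fn = DiceLoss(to_onehot_y=True, softmax=True)
    pred = torch.randn(2, 3, 16, 16)              # (batch, channel, H, W)
    target = torch.randint(0, 3, (2, 1, 16, 16))  # class indices, one channel
    loss = loss_fn(pred, target)

    # a custom activation via other_act (sigmoid/softmax/other_act are exclusive)
    loss_fn_tanh = DiceLoss(other_act=torch.tanh)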
@@ -247,12 +247,12 @@ def __init__(
         """
         Args:
             include_background: If False channel index 0 (background category) is excluded from the calculation.
-            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
+            to_onehot_y: whether to convert the ``target`` into the one-hot format,
+                using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
             sigmoid: If True, apply a sigmoid function to the prediction.
             softmax: If True, apply a softmax function to the prediction.
-            other_act: if don't want to use `sigmoid` or `softmax`, use other callable function to execute
-                other activation layers, Defaults to ``None``. for example:
-                `other_act = torch.tanh`.
+            other_act: callable function to execute other activation layers, Defaults to ``None``. for example:
+                ``other_act = torch.tanh``.
             w_type: {``"square"``, ``"simple"``, ``"uniform"``}
                 Type of function to transform ground truth volume to a weight factor. Defaults to ``"square"``.
             reduction: {``"none"``, ``"mean"``, ``"sum"``}
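(Similarly for `GeneralizedDiceLoss`, a sketch assuming the signature documented above; `w_type` selects the ground-truth-to-weight transform, e.g. "simple" weights each class by the reciprocal of its volume:)

    import torch
    from monai.losses import GeneralizedDiceLoss

    loss_fn = GeneralizedDiceLoss(to_onehot_y=True, softmax=True, w_type="simple")
    pred = torch.randn(2, 3, 16, 16)
    target = torch.randint(0, 3, (2, 1, 16, 16))
    loss = loss_fn(pred, target)  # per-class weights derived from target volume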
@@ -639,14 +639,14 @@ def __init__(
             ``reduction`` is used for both losses and other parameters are only used for dice loss.

             include_background: if False channel index 0 (background category) is excluded from the calculation.
-            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
+            to_onehot_y: whether to convert the ``target`` into the one-hot format,
+                using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
             sigmoid: if True, apply a sigmoid function to the prediction, only used by the `DiceLoss`,
                 don't need to specify activation function for `CrossEntropyLoss`.
             softmax: if True, apply a softmax function to the prediction, only used by the `DiceLoss`,
                 don't need to specify activation function for `CrossEntropyLoss`.
-            other_act: if don't want to use `sigmoid` or `softmax`, use other callable function to execute
-                other activation layers, Defaults to ``None``. for example: `other_act = torch.tanh`.
-                only used by the `DiceLoss`, don't need to specify activation function for `CrossEntropyLoss`.
+            other_act: callable function to execute other activation layers, Defaults to ``None``. for example:
+                ``other_act = torch.tanh``. only used by the `DiceLoss`, not for the `CrossEntropyLoss`.
             squared_pred: use squared versions of targets and predictions in the denominator or not.
             jaccard: compute Jaccard Index (soft IoU) instead of dice or not.
             reduction: {``"mean"``, ``"sum"``}
@@ -728,7 +728,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:

         """
         if len(input.shape) != len(target.shape):
-            raise ValueError("the number of dimensions for input and target should be the same.")
+            raise ValueError(
+                "the number of dimensions for input and target should be the same, "
+                f"got shape {input.shape} and {target.shape}."
+            )

         dice_loss = self.dice(input, target)
         ce_loss = self.ce(input, target)
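(The richer error message can be reproduced with deliberately mismatched ranks; an illustrative snippet, assuming `monai.losses.DiceCELoss` as defined in this file:)

    import torch
    from monai.losses import DiceCELoss

    loss_fn = DiceCELoss(to_onehot_y=True, softmax=True)
    pred = torch.randn(2, 3, 16, 16)               # 4D input
    bad_target = torch.randint(0, 3, (2, 16, 16))  # 3D target, channel dim missing
    try:
        loss_fn(pred, bad_target)
    except ValueError as err:
        print(err)  # now reports both offending shapes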
@@ -743,6 +746,10 @@ class DiceFocalLoss(_Loss):
     The details of Dice loss is shown in ``monai.losses.DiceLoss``.
     The details of Focal Loss is shown in ``monai.losses.FocalLoss``.

+    ``gamma``, ``focal_weight`` and ``lambda_focal`` are only used for the focal loss.
+    ``include_background`` and ``reduction`` are used for both losses
+    and other parameters are only used for dice loss.
+
     """

     def __init__(
@@ -765,18 +772,15 @@ def __init__(
     ) -> None:
         """
         Args:
-            ``gamma``, ``focal_weight`` and ``lambda_focal`` are only used for focal loss.
-            ``include_background``, ``to_onehot_y``and ``reduction`` are used for both losses
-            and other parameters are only used for dice loss.
             include_background: if False channel index 0 (background category) is excluded from the calculation.
-            to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
+            to_onehot_y: whether to convert the ``target`` into the one-hot format,
+                using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
             sigmoid: if True, apply a sigmoid function to the prediction, only used by the `DiceLoss`,
                 don't need to specify activation function for `FocalLoss`.
             softmax: if True, apply a softmax function to the prediction, only used by the `DiceLoss`,
                 don't need to specify activation function for `FocalLoss`.
-            other_act: if don't want to use `sigmoid` or `softmax`, use other callable function to execute
-                other activation layers, Defaults to ``None``. for example: `other_act = torch.tanh`.
-                only used by the `DiceLoss`, don't need to specify activation function for `FocalLoss`.
+            other_act: callable function to execute other activation layers, Defaults to ``None``.
+                for example: `other_act = torch.tanh`. only used by the `DiceLoss`, not for `FocalLoss`.
             squared_pred: use squared versions of targets and predictions in the denominator or not.
             jaccard: compute Jaccard Index (soft IoU) instead of dice or not.
             reduction: {``"none"``, ``"mean"``, ``"sum"``}
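(A usage sketch of the argument split described above, with hypothetical values; assumes the `DiceFocalLoss` signature shown in this diff:)

    import torch
    from monai.losses import DiceFocalLoss

    loss_fn = DiceFocalLoss(
        include_background=False,  # shared by both sub-losses
        to_onehot_y=True,          # handled once in the wrapper's forward
        softmax=True,              # dice-only activation
        gamma=2.0,                 # focal-only
        lambda_dice=1.0,
        lambda_focal=0.5,
    )
    pred = torch.randn(4, 3, 32, 32)
    target = torch.randint(0, 3, (4, 1, 32, 32))
    loss = loss_fn(pred, target)  # lambda_dice * dice + lambda_focal * focal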
@@ -803,6 +807,8 @@ def __init__(
         """
         super().__init__()
         self.dice = DiceLoss(
+            include_background=include_background,
+            to_onehot_y=False,
             sigmoid=sigmoid,
             softmax=softmax,
             other_act=other_act,
@@ -813,15 +819,20 @@ def __init__(
             smooth_dr=smooth_dr,
             batch=batch,
         )
-        self.focal = FocalLoss(gamma=gamma, weight=focal_weight, reduction=reduction)
+        self.focal = FocalLoss(
+            include_background=include_background,
+            to_onehot_y=False,
+            gamma=gamma,
+            weight=focal_weight,
+            reduction=reduction,
+        )
         if lambda_dice < 0.0:
             raise ValueError("lambda_dice should be no less than 0.0.")
         if lambda_focal < 0.0:
             raise ValueError("lambda_focal should be no less than 0.0.")
         self.lambda_dice = lambda_dice
         self.lambda_focal = lambda_focal
         self.to_onehot_y = to_onehot_y
-        self.include_background = include_background

     def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         """
@@ -836,24 +847,16 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:

         """
         if len(input.shape) != len(target.shape):
-            raise ValueError("the number of dimensions for input and target should be the same.")
-
-        n_pred_ch = input.shape[1]
-
+            raise ValueError(
+                "the number of dimensions for input and target should be the same, "
+                f"got shape {input.shape} and {target.shape}."
+            )
         if self.to_onehot_y:
+            n_pred_ch = input.shape[1]
             if n_pred_ch == 1:
                 warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
             else:
                 target = one_hot(target, num_classes=n_pred_ch)
-
-        if not self.include_background:
-            if n_pred_ch == 1:
-                warnings.warn("single channel prediction, `include_background=False` ignored.")
-            else:
-                # if skipping background, removing first channel
-                target = target[:, 1:]
-                input = input[:, 1:]
-
         dice_loss = self.dice(input, target)
         focal_loss = self.focal(input, target)
         total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_focal * focal_loss
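(One edge case the simplified `forward` keeps: a single-channel prediction cannot be one-hot expanded, so the flag is ignored with a warning. A sketch:)

    import torch
    from monai.losses import DiceFocalLoss

    loss_fn = DiceFocalLoss(to_onehot_y=True, sigmoid=True)
    pred = torch.randn(2, 1, 16, 16)  # n_pred_ch == 1
    target = torch.randint(0, 2, (2, 1, 16, 16)).float()
    loss = loss_fn(pred, target)  # warns: single channel prediction, `to_onehot_y=True` ignored.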
@@ -867,11 +870,13 @@ class GeneralizedDiceFocalLoss(torch.nn.modules.loss._Loss):
     Args:
         include_background (bool, optional): if False channel index 0 (background category) is excluded from the calculation.
             Defaults to True.
-        to_onehot_y (bool, optional): whether to convert `y` into the one-hot format. Defaults to False.
+        to_onehot_y: whether to convert the ``target`` into the one-hot format,
+            using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
         sigmoid (bool, optional): if True, apply a sigmoid function to the prediction. Defaults to False.
         softmax (bool, optional): if True, apply a softmax function to the prediction. Defaults to False.
-        other_act (Optional[Callable], optional): if don't want to use sigmoid or softmax, use other callable
-            function to execute other activation layers. Defaults to None.
+        other_act (Optional[Callable], optional): callable function to execute other activation layers,
+            Defaults to ``None``. for example: `other_act = torch.tanh`.
+            only used by the `GeneralizedDiceLoss`, not for the `FocalLoss`.
         w_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform
             ground-truth volume to a weight factor. Defaults to ``"square"``.
         reduction (Union[LossReduction, str], optional): {``"none"``, ``"mean"``, ``"sum"``}. Specified the reduction to