diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 827b87c2e1..1af4c519b3 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -357,9 +357,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: max_values = torch.max(w, dim=1)[0].unsqueeze(dim=1) w = w + infs * max_values - final_reduce_dim = 0 if self.batch else 1 - numer = 2.0 * (intersection * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr - denom = (denominator * w).sum(final_reduce_dim, keepdim=True) + self.smooth_dr + numer = 2.0 * (intersection * w) + self.smooth_nr + denom = (denominator * w) + self.smooth_dr f: torch.Tensor = 1.0 - (numer / denom) if self.reduction == LossReduction.MEAN.value: diff --git a/tests/test_generalized_dice_loss.py b/tests/test_generalized_dice_loss.py index 81f8f4c0b0..619814037b 100644 --- a/tests/test_generalized_dice_loss.py +++ b/tests/test_generalized_dice_loss.py @@ -46,7 +46,7 @@ "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), }, - 0.469964, + 0.435035, ], [ # shape: (2, 2, 3), (2, 1, 3) {"include_background": True, "to_onehot_y": True, "softmax": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, @@ -54,7 +54,7 @@ "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), }, - 0.414507, + 0.3837, ], [ # shape: (2, 2, 3), (2, 1, 3) { @@ -69,7 +69,7 @@ "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), }, - 0.829015, + 1.5348, ], [ # shape: (2, 2, 3), (2, 1, 3) { @@ -84,7 +84,7 @@ "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), }, - [[[0.273476]], [[0.555539]]], + [[[0.210949], [0.295351]], [[0.599976], [0.428522]]], ], [ # shape: (2, 2, 3), (2, 1, 3) {"include_background": False, "to_onehot_y": True, "smooth_nr": 1e-8, "smooth_dr": 1e-8}, @@ -112,7 +112,7 @@ "input": torch.tensor([[[0.0, 10.0, 10.0, 10.0], [10.0, 0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1, 1, 0, 0]]]), }, - 0.250023, + 0.26669, ], [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) {"include_background": True, "other_act": torch.tanh, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, @@ -134,7 +134,7 @@ "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), }, - -0.097833, + -8.55485, ], ]