From e03ecd4f474bec099af56e415dcbfb178164c080 Mon Sep 17 00:00:00 2001 From: YunLiu <55491388+KumoLiu@users.noreply.github.com> Date: Fri, 4 Nov 2022 22:29:44 +0800 Subject: [PATCH] fix `GeneralizedDiceLoss` (#5468) Signed-off-by: KumoLiu Fixes #5466. ### Description In `GeneralizedDiceLoss`, now it do channel reduction before reduction, remove channel reduction in this PR. https://github.com/Project-MONAI/MONAI/blob/c38d503a587f1779914bd071a1b2d66a6d9080c2/monai/losses/dice.py#L360-L363 ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: KumoLiu --- monai/losses/dice.py | 5 ++--- tests/test_generalized_dice_loss.py | 12 ++++++------ 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/monai/losses/dice.py b/monai/losses/dice.py index 827b87c2e1..1af4c519b3 100644 --- a/monai/losses/dice.py +++ b/monai/losses/dice.py @@ -357,9 +357,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor: max_values = torch.max(w, dim=1)[0].unsqueeze(dim=1) w = w + infs * max_values - final_reduce_dim = 0 if self.batch else 1 - numer = 2.0 * (intersection * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr - denom = (denominator * w).sum(final_reduce_dim, keepdim=True) + self.smooth_dr + numer = 2.0 * (intersection * w) + self.smooth_nr + denom = (denominator * w) + self.smooth_dr f: torch.Tensor = 1.0 - (numer / denom) if self.reduction == LossReduction.MEAN.value: diff --git a/tests/test_generalized_dice_loss.py b/tests/test_generalized_dice_loss.py index 81f8f4c0b0..619814037b 100644 --- a/tests/test_generalized_dice_loss.py +++ b/tests/test_generalized_dice_loss.py @@ -46,7 +46,7 @@ "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), }, - 0.469964, + 0.435035, ], [ # shape: (2, 2, 3), (2, 1, 3) {"include_background": True, "to_onehot_y": True, "softmax": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, @@ -54,7 +54,7 @@ "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), }, - 0.414507, + 0.3837, ], [ # shape: (2, 2, 3), (2, 1, 3) { @@ -69,7 +69,7 @@ "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), }, - 0.829015, + 1.5348, ], [ # shape: (2, 2, 3), (2, 1, 3) { @@ -84,7 +84,7 @@ "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), }, - [[[0.273476]], [[0.555539]]], + [[[0.210949], [0.295351]], [[0.599976], [0.428522]]], ], [ # shape: (2, 2, 3), (2, 1, 3) {"include_background": False, "to_onehot_y": True, "smooth_nr": 1e-8, "smooth_dr": 1e-8}, @@ -112,7 +112,7 @@ "input": torch.tensor([[[0.0, 10.0, 10.0, 10.0], [10.0, 0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1, 1, 0, 0]]]), }, - 0.250023, + 0.26669, ], [ # shape: (2, 1, 2, 2), (2, 1, 2, 2) {"include_background": True, "other_act": torch.tanh, "smooth_nr": 1e-4, "smooth_dr": 1e-4}, @@ -134,7 +134,7 @@ "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]), "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]), }, - -0.097833, + -8.55485, ], ]