fix GeneralizedDiceLoss (#5468)
Signed-off-by: KumoLiu <yunl@nvidia.com>

Fixes #5466.

### Description
In `GeneralizedDiceLoss`, the weighted intersection and denominator are currently summed over the channel dimension before the configured reduction is applied, which collapses the per-channel generalized Dice terms too early. This PR removes that premature channel reduction.

https://github.com/Project-MONAI/MONAI/blob/c38d503a587f1779914bd071a1b2d66a6d9080c2/monai/losses/dice.py#L360-L363
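
For illustration only (a minimal sketch, not part of this PR, assuming a MONAI build that already contains this fix), the snippet below shows that with `reduction="none"` the loss now keeps one generalized Dice term per channel instead of a single channel-summed value per batch item:

```python
import torch
from monai.losses import GeneralizedDiceLoss

# toy data: prediction with 2 batch items, 2 channels, 3 voxels; label with 1 channel
pred = torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]],
                     [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]])
label = torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]])

loss = GeneralizedDiceLoss(include_background=True, to_onehot_y=True, softmax=True, reduction="none")
out = loss(pred, label)
# with the fix, `out` retains the channel axis (roughly shape (2, 2, 1) here);
# previously the channel dimension was already summed away at this point
print(out.shape)
```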

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not
applicable items -->
- [x] Non-breaking change (fix or new feature that would not break
existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing
functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u
--net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick
--unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/`
folder.

KumoLiu authored and wyli committed Nov 6, 2022
1 parent c92e1b3 commit 531a631
Showing 2 changed files with 8 additions and 9 deletions.
5 changes: 2 additions & 3 deletions monai/losses/dice.py
@@ -357,9 +357,8 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
max_values = torch.max(w, dim=1)[0].unsqueeze(dim=1)
w = w + infs * max_values

- final_reduce_dim = 0 if self.batch else 1
- numer = 2.0 * (intersection * w).sum(final_reduce_dim, keepdim=True) + self.smooth_nr
- denom = (denominator * w).sum(final_reduce_dim, keepdim=True) + self.smooth_dr
+ numer = 2.0 * (intersection * w) + self.smooth_nr
+ denom = (denominator * w) + self.smooth_dr
f: torch.Tensor = 1.0 - (numer / denom)

if self.reduction == LossReduction.MEAN.value:
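
For context only (hypothetical toy tensors, not the PR's code), the sketch below contrasts the removed channel-summed form of the numerator and denominator with the new per-channel form; in the class itself, `self.reduction` is applied to the per-channel result afterwards:

```python
import torch

# hypothetical per-(batch, channel) intersection and denominator, as computed earlier in forward()
intersection = torch.tensor([[0.6, 0.2], [0.4, 0.5]])
denominator = torch.tensor([[1.0, 0.8], [0.9, 1.1]])
w = torch.tensor([[1.0, 2.0], [1.0, 2.0]])  # per-channel weights
smooth_nr = smooth_dr = 1e-5

# old behaviour: sum the weighted terms over the channel axis first -> one value per batch item
old = 1.0 - (2.0 * (intersection * w).sum(1, keepdim=True) + smooth_nr) / (
    (denominator * w).sum(1, keepdim=True) + smooth_dr
)

# new behaviour: keep one term per channel and leave any summation to the reduction step
new = 1.0 - (2.0 * intersection * w + smooth_nr) / (denominator * w + smooth_dr)

print(old.shape, new.shape)  # torch.Size([2, 1]) torch.Size([2, 2])
```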
12 changes: 6 additions & 6 deletions tests/test_generalized_dice_loss.py
@@ -46,15 +46,15 @@
"input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
"target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
},
- 0.469964,
+ 0.435035,
],
[ # shape: (2, 2, 3), (2, 1, 3)
{"include_background": True, "to_onehot_y": True, "softmax": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4},
{
"input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
"target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
},
- 0.414507,
+ 0.3837,
],
[ # shape: (2, 2, 3), (2, 1, 3)
{
@@ -69,7 +69,7 @@
"input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
"target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
},
- 0.829015,
+ 1.5348,
],
[ # shape: (2, 2, 3), (2, 1, 3)
{
@@ -84,7 +84,7 @@
"input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
"target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
},
- [[[0.273476]], [[0.555539]]],
+ [[[0.210949], [0.295351]], [[0.599976], [0.428522]]],
],
[ # shape: (2, 2, 3), (2, 1, 3)
{"include_background": False, "to_onehot_y": True, "smooth_nr": 1e-8, "smooth_dr": 1e-8},
@@ -112,7 +112,7 @@
"input": torch.tensor([[[0.0, 10.0, 10.0, 10.0], [10.0, 0.0, 0.0, 0.0]]]),
"target": torch.tensor([[[1, 1, 0, 0]]]),
},
- 0.250023,
+ 0.26669,
],
[ # shape: (2, 1, 2, 2), (2, 1, 2, 2)
{"include_background": True, "other_act": torch.tanh, "smooth_nr": 1e-4, "smooth_dr": 1e-4},
@@ -134,7 +134,7 @@
"input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
"target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
},
- -0.097833,
+ -8.55485,
],
]
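
As a rough local sanity check (a sketch assuming a MONAI build containing this fix and the default mean reduction used by the test), the updated constants can be reproduced like this; for the softmax case shown above the printed value should be close to the new expected 0.3837:

```python
import torch
from monai.losses import GeneralizedDiceLoss

params = {"include_background": True, "to_onehot_y": True, "softmax": True,
          "smooth_nr": 1e-4, "smooth_dr": 1e-4}
data = {
    "input": torch.tensor([[[-1.0, 0.0, 1.0], [1.0, 0.0, -1.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]]]),
    "target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
}
# compare the printed value against the corresponding entry in TEST_CASES
print(GeneralizedDiceLoss(**params)(**data).item())
```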

