
Commit 625967c

Authored by Lucas-rbnt, KumoLiu, and pre-commit-ci[bot]
harmonization and clarification of dice losses variants docs and associated tests (#7587)
### Description

This PR aims to clarify and harmonise the code for the `DiceLoss` variants in the `monai/losses/dice.py` file.

With the `to_onehot_y`, `softmax` and `sigmoid` arguments, I didn't immediately understand the `ValueError` raised when I passed a target of shape NH[WD]. The documentation was hard to read on this point: I thought input and target had to have exactly the same shape, rather than just the same number of dimensions, so I added that clarification.

Besides, the documentation states:

```python
"""
Raises:
    ValueError: When number of channels for target is neither 1 nor the same as input.
"""
```

Trying to reproduce this, we give an input with $N$ channels and a target with $M$ channels, where $M \neq N$ and $M > 1$:

```python
loss = DiceCELoss()
input = torch.rand(1, 4, 3, 3)
target = torch.randn(1, 2, 3, 3)
loss(input, target)
# AssertionError: ground truth has different shape (torch.Size([1, 2, 3, 3])) from input (torch.Size([1, 4, 3, 3]))
```

The error actually raised by the Dice computation is an `AssertionError`, not the documented `ValueError`, and its message does not clearly explain the problem. The classes concerned and harmonised are `DiceFocalLoss`, `DiceCELoss` and `GeneralizedDiceFocalLoss`, with new tests covering this behaviour. Also, feel free to modify or suggest changes to the docstring wording to make it more understandable (that is my view; other readers and users may well see it differently).

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Lucas Robinet <robinet.lucas@iuct-oncopole.fr>
Signed-off-by: Lucas Robinet <luca.robinet@gmail.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
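For illustration, a minimal sketch of the behaviour the harmonised checks are intended to produce, assuming a MONAI build that includes this change:

```python
import torch
from monai.losses import DiceCELoss

loss = DiceCELoss()

# Channel mismatch (4 vs 2, and neither target nor input has a single channel):
# with this change the explicit check raises a ValueError with a descriptive
# message instead of the previous AssertionError from the Dice computation.
input = torch.rand(1, 4, 3, 3)
target = torch.randn(1, 2, 3, 3)
try:
    loss(input, target)
except ValueError as e:
    print(e)  # "number of channels for target is neither 1 (without one-hot encoding) nor the same as input, ..."

# Dimension mismatch (4D input vs 3D target) also raises a ValueError,
# pointing the user towards a B1H[WD] target when it is not one-hot encoded.
try:
    loss(torch.rand(1, 4, 3, 3), torch.randint(0, 4, (1, 3, 3)).float())
except ValueError as e:
    print(e)
```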
1 parent 195d7dd · commit 625967c

File tree

4 files changed: +73 −15 lines

monai/losses/dice.py

Lines changed: 35 additions & 7 deletions

```diff
@@ -778,12 +778,22 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
 
         Raises:
             ValueError: When number of dimensions for input and target are different.
-            ValueError: When number of channels for target is neither 1 nor the same as input.
+            ValueError: When number of channels for target is neither 1 (without one-hot encoding) nor the same as input.
+
+        Returns:
+            torch.Tensor: value of the loss.
 
         """
-        if len(input.shape) != len(target.shape):
+        if input.dim() != target.dim():
             raise ValueError(
                 "the number of dimensions for input and target should be the same, "
+                f"got shape {input.shape} (nb dims: {len(input.shape)}) and {target.shape} (nb dims: {len(target.shape)}). "
+                "if target is not one-hot encoded, please provide a tensor with shape B1H[WD]."
+            )
+
+        if target.shape[1] != 1 and target.shape[1] != input.shape[1]:
+            raise ValueError(
+                "number of channels for target is neither 1 (without one-hot encoding) nor the same as input, "
                 f"got shape {input.shape} and {target.shape}."
             )
 
@@ -899,14 +909,24 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
 
         Raises:
             ValueError: When number of dimensions for input and target are different.
-            ValueError: When number of channels for target is neither 1 nor the same as input.
+            ValueError: When number of channels for target is neither 1 (without one-hot encoding) nor the same as input.
 
+        Returns:
+            torch.Tensor: value of the loss.
         """
-        if len(input.shape) != len(target.shape):
+        if input.dim() != target.dim():
             raise ValueError(
                 "the number of dimensions for input and target should be the same, "
+                f"got shape {input.shape} (nb dims: {len(input.shape)}) and {target.shape} (nb dims: {len(target.shape)}). "
+                "if target is not one-hot encoded, please provide a tensor with shape B1H[WD]."
+            )
+
+        if target.shape[1] != 1 and target.shape[1] != input.shape[1]:
+            raise ValueError(
+                "number of channels for target is neither 1 (without one-hot encoding) nor the same as input, "
                 f"got shape {input.shape} and {target.shape}."
             )
+
         if self.to_onehot_y:
             n_pred_ch = input.shape[1]
             if n_pred_ch == 1:
@@ -1015,15 +1035,23 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
             target (torch.Tensor): the shape should be BNH[WD] or B1H[WD].
 
         Raises:
-            ValueError: When the input and target tensors have different numbers of dimensions, or the target
-                channel isn't either one-hot encoded or categorical with the same shape of the input.
+            ValueError: When number of dimensions for input and target are different.
+            ValueError: When number of channels for target is neither 1 (without one-hot encoding) nor the same as input.
 
         Returns:
             torch.Tensor: value of the loss.
         """
         if input.dim() != target.dim():
             raise ValueError(
-                f"Input - {input.shape} - and target - {target.shape} - must have the same number of dimensions."
+                "the number of dimensions for input and target should be the same, "
+                f"got shape {input.shape} (nb dims: {len(input.shape)}) and {target.shape} (nb dims: {len(target.shape)}). "
+                "if target is not one-hot encoded, please provide a tensor with shape B1H[WD]."
+            )
+
+        if target.shape[1] != 1 and target.shape[1] != input.shape[1]:
+            raise ValueError(
+                "number of channels for target is neither 1 (without one-hot encoding) nor the same as input, "
+                f"got shape {input.shape} and {target.shape}."
             )
 
         gdl_loss = self.generalized_dice(input, target)
```
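As a usage note, the checks added above encode the target-shape convention the docstrings describe: the target is either one-hot encoded with the same number of channels as the input (BNH[WD]) or a single-channel label map (B1H[WD]), typically combined with `to_onehot_y=True`. Below is a minimal sketch of shapes that pass the new validation, assuming a recent MONAI/PyTorch where `CrossEntropyLoss` also accepts one-hot (probabilistic) targets:

```python
import torch
from monai.losses import DiceCELoss

b, c, h, w = 2, 3, 8, 8
pred = torch.rand(b, c, h, w)  # BNH[W] network output with N=3 channels

# Option 1: one-hot encoded target with the same number of channels as the prediction (BNH[W]).
labels = torch.randint(0, c, (b, 1, h, w))                      # B1H[W] integer label map
one_hot_target = torch.zeros(b, c, h, w).scatter_(1, labels, 1)  # BNH[W] one-hot encoding
loss_onehot = DiceCELoss(softmax=True)
print(loss_onehot(pred, one_hot_target))

# Option 2: single-channel label map (B1H[W]); let the loss do the one-hot conversion.
loss_label = DiceCELoss(to_onehot_y=True, softmax=True)
print(loss_label(pred, labels.float()))
```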

tests/test_dice_ce_loss.py

Lines changed: 14 additions & 4 deletions

```diff
@@ -93,10 +93,20 @@ def test_result(self, input_param, input_data, expected_val):
         result = diceceloss(**input_data)
         np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, atol=1e-4, rtol=1e-4)
 
-    # def test_ill_shape(self):
-    #     loss = DiceCELoss()
-    #     with self.assertRaisesRegex(ValueError, ""):
-    #         loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))
+    def test_ill_shape(self):
+        loss = DiceCELoss()
+        with self.assertRaises(AssertionError):
+            loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 2, 5)))
+
+    def test_ill_shape2(self):
+        loss = DiceCELoss()
+        with self.assertRaises(ValueError):
+            loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))
+
+    def test_ill_shape3(self):
+        loss = DiceCELoss()
+        with self.assertRaises(ValueError):
+            loss.forward(torch.ones((1, 3, 4, 4)), torch.ones((1, 2, 4, 4)))
 
     # def test_ill_reduction(self):
     #     with self.assertRaisesRegex(ValueError, ""):
```

tests/test_dice_focal_loss.py

Lines changed: 12 additions & 2 deletions

```diff
@@ -69,8 +69,18 @@ def test_result_no_onehot_no_bg(self, size, onehot):
 
     def test_ill_shape(self):
         loss = DiceFocalLoss()
-        with self.assertRaisesRegex(ValueError, ""):
-            loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))
+        with self.assertRaises(AssertionError):
+            loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 2, 5)))
+
+    def test_ill_shape2(self):
+        loss = DiceFocalLoss()
+        with self.assertRaises(ValueError):
+            loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))
+
+    def test_ill_shape3(self):
+        loss = DiceFocalLoss()
+        with self.assertRaises(ValueError):
+            loss.forward(torch.ones((1, 3, 4, 4)), torch.ones((1, 2, 4, 4)))
 
     def test_ill_lambda(self):
         with self.assertRaisesRegex(ValueError, ""):
```

tests/test_generalized_dice_focal_loss.py

Lines changed: 12 additions & 2 deletions

```diff
@@ -59,8 +59,18 @@ def test_result_no_onehot_no_bg(self):
 
     def test_ill_shape(self):
         loss = GeneralizedDiceFocalLoss()
-        with self.assertRaisesRegex(ValueError, ""):
-            loss(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))
+        with self.assertRaises(AssertionError):
+            loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 2, 5)))
+
+    def test_ill_shape2(self):
+        loss = GeneralizedDiceFocalLoss()
+        with self.assertRaises(ValueError):
+            loss.forward(torch.ones((1, 2, 3)), torch.ones((1, 1, 2, 3)))
+
+    def test_ill_shape3(self):
+        loss = GeneralizedDiceFocalLoss()
+        with self.assertRaises(ValueError):
+            loss.forward(torch.ones((1, 3, 4, 4)), torch.ones((1, 2, 4, 4)))
 
     def test_ill_lambda(self):
         with self.assertRaisesRegex(ValueError, ""):
```
