Skip to content

Commit adff162

Browse files
committed
Adjust execution order of activation and masking in MaskedDiceLoss
Signed-off-by: ytl0623 <david89062388@gmail.com>
1 parent 57fdd59 commit adff162

File tree

2 files changed

+32
-4
lines changed

2 files changed

+32
-4
lines changed

monai/losses/dice.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,22 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
244244
Args follow :py:class:`monai.losses.DiceLoss`.
245245
"""
246246
super().__init__(*args, **kwargs)
247-
self.spatial_weighted = MaskedLoss(loss=super().forward)
247+
self.dice = DiceLoss(
248+
include_background=self.include_background,
249+
to_onehot_y=self.to_onehot_y,
250+
sigmoid=False,
251+
softmax=False,
252+
other_act=None,
253+
squared_pred=self.squared_pred,
254+
jaccard=self.jaccard,
255+
reduction=self.reduction,
256+
smooth_nr=self.smooth_nr,
257+
smooth_dr=self.smooth_dr,
258+
batch=self.batch,
259+
weight=self.class_weight,
260+
soft_label=self.soft_label,
261+
)
262+
self.spatial_weighted = MaskedLoss(loss=self.dice.forward)
248263

249264
def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
250265
"""
@@ -253,6 +268,19 @@ def forward(self, input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor
253268
target: the shape should be BNH[WD].
254269
mask: the shape should B1H[WD] or 11H[WD].
255270
"""
271+
272+
if self.sigmoid:
273+
input = torch.sigmoid(input)
274+
275+
n_pred_ch = input.shape[1]
276+
if self.softmax:
277+
if n_pred_ch == 1:
278+
warnings.warn("single channel prediction, `softmax=True` ignored.")
279+
else:
280+
input = torch.softmax(input, 1)
281+
282+
if self.other_act is not None:
283+
input = self.other_act(input)
256284
return self.spatial_weighted(input=input, target=target, mask=mask) # type: ignore[no-any-return]
257285

258286

tests/losses/test_masked_dice_loss.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
"target": torch.tensor([[[[1.0, 0.0], [1.0, 1.0]]]]),
2828
"mask": torch.tensor([[[[0.0, 0.0], [1.0, 1.0]]]]),
2929
},
30-
0.500,
30+
0.333333,
3131
],
3232
[ # shape: (2, 1, 2, 2), (2, 1, 2, 2)
3333
{"include_background": True, "sigmoid": True, "smooth_nr": 1e-4, "smooth_dr": 1e-4},
@@ -36,7 +36,7 @@
3636
"target": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 0.0], [1.0, 0.0]]]]),
3737
"mask": torch.tensor([[[[1.0, 1.0], [1.0, 1.0]]], [[[1.0, 1.0], [0.0, 0.0]]]]),
3838
},
39-
0.422969,
39+
0.301128,
4040
],
4141
[ # shape: (2, 2, 3), (2, 1, 3)
4242
{"include_background": False, "to_onehot_y": True, "smooth_nr": 0, "smooth_dr": 0},
@@ -54,7 +54,7 @@
5454
"target": torch.tensor([[[1.0, 0.0, 0.0]], [[1.0, 1.0, 0.0]]]),
5555
"mask": torch.tensor([[[1.0, 1.0, 0.0]]]),
5656
},
57-
0.47033,
57+
0.579184,
5858
],
5959
[ # shape: (2, 2, 3), (2, 1, 3)
6060
{

0 commit comments

Comments
 (0)