Skip to content

Commit 1d9800f

Browse files
committed
Merge remote-tracking branch 'upstream/dev' into 4922-step-1
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
2 parents 52b097e + 8b1f0c3 commit 1d9800f

9 files changed

+148
-68
lines changed

monai/losses/dice.py

Lines changed: 43 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -60,12 +60,12 @@ def __init__(
6060
include_background: if False, channel index 0 (background category) is excluded from the calculation.
6161
if the non-background segmentations are small compared to the total image size they can get overwhelmed
6262
by the signal from the background so excluding it in such cases helps convergence.
63-
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
63+
to_onehot_y: whether to convert the ``target`` into the one-hot format,
64+
using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
6465
sigmoid: if True, apply a sigmoid function to the prediction.
6566
softmax: if True, apply a softmax function to the prediction.
66-
other_act: if don't want to use `sigmoid` or `softmax`, use other callable function to execute
67-
other activation layers, Defaults to ``None``. for example:
68-
`other_act = torch.tanh`.
67+
other_act: callable function to execute other activation layers. Defaults to ``None``. For example:
68+
``other_act = torch.tanh``.
6969
squared_pred: use squared versions of targets and predictions in the denominator or not.
7070
jaccard: compute Jaccard Index (soft IoU) instead of dice or not.
7171
reduction: {``"none"``, ``"mean"``, ``"sum"``}
@@ -247,12 +247,12 @@ def __init__(
247247
"""
248248
Args:
249249
include_background: If False channel index 0 (background category) is excluded from the calculation.
250-
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
250+
to_onehot_y: whether to convert the ``target`` into the one-hot format,
251+
using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
251252
sigmoid: If True, apply a sigmoid function to the prediction.
252253
softmax: If True, apply a softmax function to the prediction.
253-
other_act: if don't want to use `sigmoid` or `softmax`, use other callable function to execute
254-
other activation layers, Defaults to ``None``. for example:
255-
`other_act = torch.tanh`.
254+
other_act: callable function to execute other activation layers. Defaults to ``None``. For example:
255+
``other_act = torch.tanh``.
256256
w_type: {``"square"``, ``"simple"``, ``"uniform"``}
257257
Type of function to transform ground truth volume to a weight factor. Defaults to ``"square"``.
258258
reduction: {``"none"``, ``"mean"``, ``"sum"``}
@@ -639,14 +639,14 @@ def __init__(
639639
``reduction`` is used for both losses and other parameters are only used for dice loss.
640640
641641
include_background: if False channel index 0 (background category) is excluded from the calculation.
642-
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
642+
to_onehot_y: whether to convert the ``target`` into the one-hot format,
643+
using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
643644
sigmoid: if True, apply a sigmoid function to the prediction, only used by the `DiceLoss`,
644645
don't need to specify activation function for `CrossEntropyLoss`.
645646
softmax: if True, apply a softmax function to the prediction, only used by the `DiceLoss`,
646647
don't need to specify activation function for `CrossEntropyLoss`.
647-
other_act: if don't want to use `sigmoid` or `softmax`, use other callable function to execute
648-
other activation layers, Defaults to ``None``. for example: `other_act = torch.tanh`.
649-
only used by the `DiceLoss`, don't need to specify activation function for `CrossEntropyLoss`.
648+
other_act: callable function to execute other activation layers. Defaults to ``None``. For example:
649+
``other_act = torch.tanh``. Only used by the `DiceLoss`, not by the `CrossEntropyLoss`.
650650
squared_pred: use squared versions of targets and predictions in the denominator or not.
651651
jaccard: compute Jaccard Index (soft IoU) instead of dice or not.
652652
reduction: {``"mean"``, ``"sum"``}
@@ -728,7 +728,10 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
728728
729729
"""
730730
if len(input.shape) != len(target.shape):
731-
raise ValueError("the number of dimensions for input and target should be the same.")
731+
raise ValueError(
732+
"the number of dimensions for input and target should be the same, "
733+
f"got shape {input.shape} and {target.shape}."
734+
)
732735

733736
dice_loss = self.dice(input, target)
734737
ce_loss = self.ce(input, target)
@@ -743,6 +746,10 @@ class DiceFocalLoss(_Loss):
743746
The details of Dice loss are shown in ``monai.losses.DiceLoss``.
744747
The details of Focal Loss are shown in ``monai.losses.FocalLoss``.
745748
749+
``gamma``, ``focal_weight`` and ``lambda_focal`` are only used for the focal loss.
750+
``include_background`` and ``reduction`` are used for both losses
751+
and other parameters are only used for dice loss.
752+
746753
"""
747754

748755
def __init__(
@@ -765,18 +772,15 @@ def __init__(
765772
) -> None:
766773
"""
767774
Args:
768-
``gamma``, ``focal_weight`` and ``lambda_focal`` are only used for focal loss.
769-
``include_background``, ``to_onehot_y``and ``reduction`` are used for both losses
770-
and other parameters are only used for dice loss.
771775
include_background: if False channel index 0 (background category) is excluded from the calculation.
772-
to_onehot_y: whether to convert `y` into the one-hot format. Defaults to False.
776+
to_onehot_y: whether to convert the ``target`` into the one-hot format,
777+
using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
773778
sigmoid: if True, apply a sigmoid function to the prediction, only used by the `DiceLoss`,
774779
don't need to specify activation function for `FocalLoss`.
775780
softmax: if True, apply a softmax function to the prediction, only used by the `DiceLoss`,
776781
don't need to specify activation function for `FocalLoss`.
777-
other_act: if don't want to use `sigmoid` or `softmax`, use other callable function to execute
778-
other activation layers, Defaults to ``None``. for example: `other_act = torch.tanh`.
779-
only used by the `DiceLoss`, don't need to specify activation function for `FocalLoss`.
782+
other_act: callable function to execute other activation layers. Defaults to ``None``.
783+
For example: `other_act = torch.tanh`. Only used by the `DiceLoss`, not by `FocalLoss`.
780784
squared_pred: use squared versions of targets and predictions in the denominator or not.
781785
jaccard: compute Jaccard Index (soft IoU) instead of dice or not.
782786
reduction: {``"none"``, ``"mean"``, ``"sum"``}
@@ -803,6 +807,8 @@ def __init__(
803807
"""
804808
super().__init__()
805809
self.dice = DiceLoss(
810+
include_background=include_background,
811+
to_onehot_y=False,
806812
sigmoid=sigmoid,
807813
softmax=softmax,
808814
other_act=other_act,
@@ -813,15 +819,20 @@ def __init__(
813819
smooth_dr=smooth_dr,
814820
batch=batch,
815821
)
816-
self.focal = FocalLoss(gamma=gamma, weight=focal_weight, reduction=reduction)
822+
self.focal = FocalLoss(
823+
include_background=include_background,
824+
to_onehot_y=False,
825+
gamma=gamma,
826+
weight=focal_weight,
827+
reduction=reduction,
828+
)
817829
if lambda_dice < 0.0:
818830
raise ValueError("lambda_dice should be no less than 0.0.")
819831
if lambda_focal < 0.0:
820832
raise ValueError("lambda_focal should be no less than 0.0.")
821833
self.lambda_dice = lambda_dice
822834
self.lambda_focal = lambda_focal
823835
self.to_onehot_y = to_onehot_y
824-
self.include_background = include_background
825836

826837
def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
827838
"""
@@ -836,24 +847,16 @@ def forward(self, input: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
836847
837848
"""
838849
if len(input.shape) != len(target.shape):
839-
raise ValueError("the number of dimensions for input and target should be the same.")
840-
841-
n_pred_ch = input.shape[1]
842-
850+
raise ValueError(
851+
"the number of dimensions for input and target should be the same, "
852+
f"got shape {input.shape} and {target.shape}."
853+
)
843854
if self.to_onehot_y:
855+
n_pred_ch = input.shape[1]
844856
if n_pred_ch == 1:
845857
warnings.warn("single channel prediction, `to_onehot_y=True` ignored.")
846858
else:
847859
target = one_hot(target, num_classes=n_pred_ch)
848-
849-
if not self.include_background:
850-
if n_pred_ch == 1:
851-
warnings.warn("single channel prediction, `include_background=False` ignored.")
852-
else:
853-
# if skipping background, removing first channel
854-
target = target[:, 1:]
855-
input = input[:, 1:]
856-
857860
dice_loss = self.dice(input, target)
858861
focal_loss = self.focal(input, target)
859862
total_loss: torch.Tensor = self.lambda_dice * dice_loss + self.lambda_focal * focal_loss
@@ -867,11 +870,13 @@ class GeneralizedDiceFocalLoss(torch.nn.modules.loss._Loss):
867870
Args:
868871
include_background (bool, optional): if False channel index 0 (background category) is excluded from the calculation.
869872
Defaults to True.
870-
to_onehot_y (bool, optional): whether to convert `y` into the one-hot format. Defaults to False.
873+
to_onehot_y: whether to convert the ``target`` into the one-hot format,
874+
using the number of classes inferred from `input` (``input.shape[1]``). Defaults to False.
871875
sigmoid (bool, optional): if True, apply a sigmoid function to the prediction. Defaults to False.
872876
softmax (bool, optional): if True, apply a softmax function to the prediction. Defaults to False.
873-
other_act (Optional[Callable], optional): if don't want to use sigmoid or softmax, use other callable
874-
function to execute other activation layers. Defaults to None.
877+
other_act (Optional[Callable], optional): callable function to execute other activation layers.
878+
Defaults to ``None``. For example: `other_act = torch.tanh`.
879+
Only used by the `GeneralizedDiceLoss`, not by the `FocalLoss`.
875880
w_type (Union[Weight, str], optional): {``"square"``, ``"simple"``, ``"uniform"``}. Type of function to transform
876881
ground-truth volume to a weight factor. Defaults to ``"square"``.
877882
reduction (Union[LossReduction, str], optional): {``"none"``, ``"mean"``, ``"sum"``}. Specified the reduction to

monai/networks/nets/attentionunet.py

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -143,12 +143,27 @@ def forward(self, g: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
143143

144144

145145
class AttentionLayer(nn.Module):
146-
def __init__(self, spatial_dims: int, in_channels: int, out_channels: int, submodule: nn.Module, dropout=0.0):
146+
def __init__(
147+
self,
148+
spatial_dims: int,
149+
in_channels: int,
150+
out_channels: int,
151+
submodule: nn.Module,
152+
up_kernel_size=3,
153+
strides=2,
154+
dropout=0.0,
155+
):
147156
super().__init__()
148157
self.attention = AttentionBlock(
149158
spatial_dims=spatial_dims, f_g=in_channels, f_l=in_channels, f_int=in_channels // 2
150159
)
151-
self.upconv = UpConv(spatial_dims=spatial_dims, in_channels=out_channels, out_channels=in_channels, strides=2)
160+
self.upconv = UpConv(
161+
spatial_dims=spatial_dims,
162+
in_channels=out_channels,
163+
out_channels=in_channels,
164+
strides=strides,
165+
kernel_size=up_kernel_size,
166+
)
152167
self.merge = Convolution(
153168
spatial_dims=spatial_dims, in_channels=2 * in_channels, out_channels=in_channels, dropout=dropout
154169
)
@@ -174,7 +189,7 @@ class AttentionUnet(nn.Module):
174189
channels (Sequence[int]): sequence of channels. Top block first. The length of `channels` should be no less than 2.
175190
strides (Sequence[int]): stride to use for convolutions.
176191
kernel_size: convolution kernel size.
177-
upsample_kernel_size: convolution kernel size for transposed convolution layers.
192+
up_kernel_size: convolution kernel size for transposed convolution layers.
178193
dropout: dropout ratio. Defaults to no dropout.
179194
"""
180195

@@ -210,9 +225,9 @@ def __init__(
210225
)
211226
self.up_kernel_size = up_kernel_size
212227

213-
def _create_block(channels: Sequence[int], strides: Sequence[int], level: int = 0) -> nn.Module:
228+
def _create_block(channels: Sequence[int], strides: Sequence[int]) -> nn.Module:
214229
if len(channels) > 2:
215-
subblock = _create_block(channels[1:], strides[1:], level=level + 1)
230+
subblock = _create_block(channels[1:], strides[1:])
216231
return AttentionLayer(
217232
spatial_dims=spatial_dims,
218233
in_channels=channels[0],
@@ -227,17 +242,19 @@ def _create_block(channels: Sequence[int], strides: Sequence[int], level: int =
227242
),
228243
subblock,
229244
),
245+
up_kernel_size=self.up_kernel_size,
246+
strides=strides[0],
230247
dropout=dropout,
231248
)
232249
else:
233250
# the next layer is the bottom so stop recursion,
234-
# create the bottom layer as the sublock for this layer
235-
return self._get_bottom_layer(channels[0], channels[1], strides[0], level=level + 1)
251+
# create the bottom layer as the subblock for this layer
252+
return self._get_bottom_layer(channels[0], channels[1], strides[0])
236253

237254
encdec = _create_block(self.channels, self.strides)
238255
self.model = nn.Sequential(head, encdec, reduce_channels)
239256

240-
def _get_bottom_layer(self, in_channels: int, out_channels: int, strides: int, level: int) -> nn.Module:
257+
def _get_bottom_layer(self, in_channels: int, out_channels: int, strides: int) -> nn.Module:
241258
return AttentionLayer(
242259
spatial_dims=self.dimensions,
243260
in_channels=in_channels,
@@ -249,6 +266,8 @@ def _get_bottom_layer(self, in_channels: int, out_channels: int, strides: int, l
249266
strides=strides,
250267
dropout=self.dropout,
251268
),
269+
up_kernel_size=self.up_kernel_size,
270+
strides=strides,
252271
dropout=self.dropout,
253272
)
254273

monai/transforms/utils.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -505,11 +505,11 @@ def generate_pos_neg_label_crop_centers(
505505
raise ValueError("No sampling location available.")
506506

507507
if len(fg_indices) == 0 or len(bg_indices) == 0:
508+
pos_ratio = 0 if len(fg_indices) == 0 else 1
508509
warnings.warn(
509-
f"N foreground {len(fg_indices)}, N background {len(bg_indices)},"
510-
"unable to generate class balanced samples."
510+
f"Num foregrounds {len(fg_indices)}, Num backgrounds {len(bg_indices)}, "
511+
f"unable to generate class balanced samples, setting `pos_ratio` to {pos_ratio}."
511512
)
512-
pos_ratio = 0 if fg_indices.size == 0 else 1
513513

514514
for _ in range(num_samples):
515515
indices_to_use = fg_indices if rand_state.rand() < pos_ratio else bg_indices

monai/visualize/gradient_based.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def get_grad(self, x: torch.Tensor, index: torch.Tensor | int | None, retain_gra
9090
x.requires_grad = True
9191

9292
self._model(x, class_idx=index, retain_graph=retain_graph, **kwargs)
93-
grad: torch.Tensor = x.grad.detach()
93+
grad: torch.Tensor = x.grad.detach() # type: ignore
9494
return grad
9595

9696
def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None, **kwargs) -> torch.Tensor:

tests/test_attentionunet.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def test_attentionunet(self):
3939
shape = (3, 1) + (92,) * dims
4040
input = torch.rand(*shape)
4141
model = att.AttentionUnet(
42-
spatial_dims=dims, in_channels=1, out_channels=2, channels=(3, 4, 5), strides=(2, 2)
42+
spatial_dims=dims, in_channels=1, out_channels=2, channels=(3, 4, 5), up_kernel_size=5, strides=(1, 2)
4343
)
4444
output = model(input)
4545
self.assertEqual(output.shape[2:], input.shape[2:])

tests/test_dice_focal_loss.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313

1414
import numpy as np
1515
import torch
16+
from parameterized import parameterized
1617

1718
from monai.losses import DiceFocalLoss, DiceLoss, FocalLoss
1819
from tests.utils import test_script_save
@@ -36,17 +37,24 @@ def test_result_onehot_target_include_bg(self):
3637
expected_val = dice(pred, label) + lambda_focal * focal(pred, label)
3738
np.testing.assert_allclose(result, expected_val)
3839

39-
def test_result_no_onehot_no_bg(self):
40-
size = [3, 3, 5, 5]
41-
label = torch.randint(low=0, high=2, size=size)
42-
label = torch.argmax(label, dim=1, keepdim=True)
40+
@parameterized.expand([[[3, 3, 5, 5], True], [[3, 2, 5, 5], False]])
41+
def test_result_no_onehot_no_bg(self, size, onehot):
42+
label = torch.randint(low=0, high=size[1] - 1, size=size)
43+
if onehot:
44+
label = torch.argmax(label, dim=1, keepdim=True)
4345
pred = torch.randn(size)
4446
for reduction in ["sum", "mean", "none"]:
45-
common_params = {"include_background": False, "to_onehot_y": True, "reduction": reduction}
46-
for focal_weight in [2.0, torch.tensor([1.0, 2.0]), (2.0, 1)]:
47+
for focal_weight in [2.0] + [] if size[1] != 3 else [torch.tensor([1.0, 2.0]), (2.0, 1)]:
4748
for lambda_focal in [0.5, 1.0, 1.5]:
49+
common_params = {
50+
"include_background": False,
51+
"softmax": True,
52+
"to_onehot_y": onehot,
53+
"reduction": reduction,
54+
}
4855
dice_focal = DiceFocalLoss(focal_weight=focal_weight, lambda_focal=lambda_focal, **common_params)
4956
dice = DiceLoss(**common_params)
57+
common_params.pop("softmax", None)
5058
focal = FocalLoss(weight=focal_weight, **common_params)
5159
result = dice_focal(pred, label)
5260
expected_val = dice(pred, label) + lambda_focal * focal(pred, label)

tests/test_generate_pos_neg_label_crop_centers.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,20 @@
3131
list,
3232
2,
3333
3,
34-
]
34+
],
35+
[
36+
{
37+
"spatial_size": [2, 2, 2],
38+
"num_samples": 2,
39+
"pos_ratio": 0.0,
40+
"label_spatial_shape": [3, 3, 3],
41+
"fg_indices": [],
42+
"bg_indices": [3, 12, 21],
43+
},
44+
list,
45+
2,
46+
3,
47+
],
3548
]
3649

3750

0 commit comments

Comments
 (0)