Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SwinUNETR refactored img_size parameter and removed checkpointing dep… #7093

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 33 additions & 9 deletions monai/networks/nets/swin_unetr.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@
import torch.nn.functional as F
import torch.utils.checkpoint as checkpoint
from torch.nn import LayerNorm
from typing_extensions import Final

from monai.networks.blocks import MLPBlock as Mlp
from monai.networks.blocks import PatchEmbed, UnetOutBlock, UnetrBasicBlock, UnetrUpBlock
from monai.networks.layers import DropPath, trunc_normal_
from monai.utils import ensure_tuple_rep, look_up_option, optional_import
from monai.utils.deprecate_utils import deprecated_arg

rearrange, _ = optional_import("einops", name="rearrange")

Expand All @@ -49,6 +51,15 @@ class SwinUNETR(nn.Module):
<https://arxiv.org/abs/2201.01266>"
"""

patch_size: Final[int] = 2

@deprecated_arg(
name="img_size",
since="1.3",
removed="1.5",
msg_suffix="The img_size argument is not required anymore and "
"checks on the input size are run during forward().",
)
def __init__(
self,
img_size: Sequence[int] | int,
Expand All @@ -69,7 +80,10 @@ def __init__(
) -> None:
"""
Args:
img_size: dimension of input image.
img_size: spatial dimension of input image.
This argument is only used for checking that the input image size is divisible by the patch size.
The tensor passed to forward() can have a dynamic shape as long as its spatial dimensions are divisible by 2**5.
It will be removed in an upcoming version.
in_channels: dimension of input channels.
out_channels: dimension of output channels.
feature_size: dimension of network feature size.
Expand Down Expand Up @@ -103,16 +117,13 @@ def __init__(
super().__init__()

img_size = ensure_tuple_rep(img_size, spatial_dims)
patch_size = ensure_tuple_rep(2, spatial_dims)
patch_sizes = ensure_tuple_rep(self.patch_size, spatial_dims)
window_size = ensure_tuple_rep(7, spatial_dims)

if spatial_dims not in (2, 3):
raise ValueError("spatial dimension should be 2 or 3.")

for m, p in zip(img_size, patch_size):
for i in range(5):
if m % np.power(p, i + 1) != 0:
raise ValueError("input image size (img_size) should be divisible by stage-wise image resolution.")
self._check_input_size(img_size)

if not (0 <= drop_rate <= 1):
raise ValueError("dropout rate should be between 0 and 1.")
Expand All @@ -132,7 +143,7 @@ def __init__(
in_chans=in_channels,
embed_dim=feature_size,
window_size=window_size,
patch_size=patch_size,
patch_size=patch_sizes,
depths=depths,
num_heads=num_heads,
mlp_ratio=4.0,
Expand Down Expand Up @@ -297,7 +308,20 @@ def load_from(self, weights):
weights["state_dict"]["module.layers4.0.downsample.norm.bias"]
)

@torch.jit.unused
def _check_input_size(self, spatial_shape):
img_size = np.array(spatial_shape)
remainder = (img_size % np.power(self.patch_size, 5)) > 0
if remainder.any():
wrong_dims = (np.where(remainder)[0] + 2).tolist()
raise ValueError(
f"spatial dimensions {wrong_dims} of input image (spatial shape: {spatial_shape})"
f" must be divisible by {self.patch_size}**5."
)

def forward(self, x_in):
if not torch.jit.is_scripting():
self._check_input_size(x_in.shape[2:])
hidden_states_out = self.swinViT(x_in, self.normalize)
enc0 = self.encoder1(x_in)
enc1 = self.encoder2(hidden_states_out[0])
Expand Down Expand Up @@ -669,12 +693,12 @@ def load_from(self, weights, n_block, layer):
def forward(self, x, mask_matrix):
shortcut = x
if self.use_checkpoint:
x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix)
x = checkpoint.checkpoint(self.forward_part1, x, mask_matrix, use_reentrant=False)
else:
x = self.forward_part1(x, mask_matrix)
x = shortcut + self.drop_path(x)
if self.use_checkpoint:
x = x + checkpoint.checkpoint(self.forward_part2, x)
x = x + checkpoint.checkpoint(self.forward_part2, x, use_reentrant=False)
else:
x = x + self.forward_part2(x)
return x
Expand Down
2 changes: 1 addition & 1 deletion runtests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ do
doBlackFormat=true
doIsortFormat=true
doFlake8Format=true
doPylintFormat=true
# doPylintFormat=true # https://github.com/Project-MONAI/MONAI/issues/7094
doRuffFormat=true
doCopyRight=true
;;
Expand Down