diff --git a/monai/networks/nets/swin_unetr.py b/monai/networks/nets/swin_unetr.py index 9f8204968f..59fdb41815 100644 --- a/monai/networks/nets/swin_unetr.py +++ b/monai/networks/nets/swin_unetr.py @@ -65,6 +65,7 @@ def __init__( use_checkpoint: bool = False, spatial_dims: int = 3, downsample="merging", + use_v2=False, ) -> None: """ Args: @@ -84,6 +85,7 @@ def __init__( downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`. The default is currently `"merging"` (the original version defined in v0.9.0). + use_v2: using swinunetr_v2, which adds a residual convolution block at the beginning of each swin stage. Examples:: @@ -142,6 +144,7 @@ def __init__( use_checkpoint=use_checkpoint, spatial_dims=spatial_dims, downsample=look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample, + use_v2=use_v2, ) self.encoder1 = UnetrBasicBlock( @@ -921,6 +924,7 @@ def __init__( use_checkpoint: bool = False, spatial_dims: int = 3, downsample="merging", + use_v2=False, ) -> None: """ Args: @@ -942,6 +946,7 @@ def __init__( downsample: module used for downsampling, available options are `"mergingv2"`, `"merging"` and a user-specified `nn.Module` following the API defined in :py:class:`monai.networks.nets.PatchMerging`. The default is currently `"merging"` (the original version defined in v0.9.0). + use_v2: using swinunetr_v2, which adds a residual convolution block at the beginning of each swin stage. 
""" super().__init__() @@ -959,10 +964,16 @@ def __init__( ) self.pos_drop = nn.Dropout(p=drop_rate) dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + self.use_v2 = use_v2 self.layers1 = nn.ModuleList() self.layers2 = nn.ModuleList() self.layers3 = nn.ModuleList() self.layers4 = nn.ModuleList() + if self.use_v2: + self.layers1c = nn.ModuleList() + self.layers2c = nn.ModuleList() + self.layers3c = nn.ModuleList() + self.layers4c = nn.ModuleList() down_sample_mod = look_up_option(downsample, MERGING_MODE) if isinstance(downsample, str) else downsample for i_layer in range(self.num_layers): layer = BasicLayer( @@ -987,6 +998,25 @@ def __init__( self.layers3.append(layer) elif i_layer == 3: self.layers4.append(layer) + if self.use_v2: + layerc = UnetrBasicBlock( + spatial_dims=3, + in_channels=embed_dim * 2**i_layer, + out_channels=embed_dim * 2**i_layer, + kernel_size=3, + stride=1, + norm_name="instance", + res_block=True, + ) + if i_layer == 0: + self.layers1c.append(layerc) + elif i_layer == 1: + self.layers2c.append(layerc) + elif i_layer == 2: + self.layers3c.append(layerc) + elif i_layer == 3: + self.layers4c.append(layerc) + self.num_features = int(embed_dim * 2 ** (self.num_layers - 1)) def proj_out(self, x, normalize=False): @@ -1008,12 +1038,20 @@ def forward(self, x, normalize=True): x0 = self.patch_embed(x) x0 = self.pos_drop(x0) x0_out = self.proj_out(x0, normalize) + if self.use_v2: + x0 = self.layers1c[0](x0.contiguous()) x1 = self.layers1[0](x0.contiguous()) x1_out = self.proj_out(x1, normalize) + if self.use_v2: + x1 = self.layers2c[0](x1.contiguous()) x2 = self.layers2[0](x1.contiguous()) x2_out = self.proj_out(x2, normalize) + if self.use_v2: + x2 = self.layers3c[0](x2.contiguous()) x3 = self.layers3[0](x2.contiguous()) x3_out = self.proj_out(x3, normalize) + if self.use_v2: + x3 = self.layers4c[0](x3.contiguous()) x4 = self.layers4[0](x3.contiguous()) x4_out = self.proj_out(x4, normalize) return [x0_out, x1_out, x2_out, 
x3_out, x4_out]