Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
strategy:
fail-fast: false
matrix:
os: [windows-latest, ubuntu-latest]
os: [ubuntu-latest]
python-version: ["3.9"]
runs-on: ${{ matrix.os }}
env:
Expand Down
21 changes: 16 additions & 5 deletions monai/networks/blocks/localnet_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def forward(self, x) -> Tuple[torch.Tensor, torch.Tensor]:

class LocalNetUpSampleBlock(nn.Module):
"""
A up-sample module that can be used for LocalNet, based on:
An up-sample module that can be used for LocalNet, based on:
`Weakly-supervised convolutional neural networks for multimodal image registration
<https://doi.org/10.1016/j.media.2018.07.002>`_.
`Label-driven weakly-supervised learning for multimodal deformable image registration
Expand All @@ -176,12 +176,21 @@ class LocalNetUpSampleBlock(nn.Module):
DeepReg (https://github.com/DeepRegNet/DeepReg)
"""

def __init__(self, spatial_dims: int, in_channels: int, out_channels: int) -> None:
def __init__(
self,
spatial_dims: int,
in_channels: int,
out_channels: int,
mode: str = "nearest",
align_corners: Optional[bool] = None,
) -> None:
"""
Args:
spatial_dims: number of spatial dimensions.
in_channels: number of input channels.
out_channels: number of output channels.
mode: interpolation mode of the additive upsampling, default to 'nearest'.
align_corners: whether to align corners for the additive upsampling, default to None.
Raises:
ValueError: when ``in_channels != 2 * out_channels``
"""
Expand All @@ -199,9 +208,11 @@ def __init__(self, spatial_dims: int, in_channels: int, out_channels: int) -> No
f"got in_channels={in_channels}, out_channels={out_channels}"
)
self.out_channels = out_channels
self.mode = mode
self.align_corners = align_corners

def addictive_upsampling(self, x, mid) -> torch.Tensor:
x = F.interpolate(x, mid.shape[2:])
def additive_upsampling(self, x, mid) -> torch.Tensor:
x = F.interpolate(x, mid.shape[2:], mode=self.mode, align_corners=self.align_corners)
# [(batch, out_channels, ...), (batch, out_channels, ...)]
x = x.split(split_size=int(self.out_channels), dim=1)
# (batch, out_channels, ...)
Expand All @@ -226,7 +237,7 @@ def forward(self, x, mid) -> torch.Tensor:
"expecting mid spatial dimensions be exactly the double of x spatial dimensions, "
f"got x of shape {x.shape}, mid of shape {mid.shape}"
)
h0 = self.deconv_block(x) + self.addictive_upsampling(x, mid)
h0 = self.deconv_block(x) + self.additive_upsampling(x, mid)
r1 = h0 + mid
r2 = self.conv_block(h0)
out: torch.Tensor = self.residual_block(r2, r1)
Expand Down
10 changes: 9 additions & 1 deletion monai/networks/blocks/regunet_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,8 @@ def __init__(
out_channels: int,
kernel_initializer: Optional[str] = "kaiming_uniform",
activation: Optional[str] = None,
mode: str = "nearest",
align_corners: Optional[bool] = None,
):
"""

Expand All @@ -211,6 +213,8 @@ def __init__(
out_channels: number of output channels
kernel_initializer: kernel initializer
activation: kernel activation function
mode: feature map interpolation mode, default to "nearest".
align_corners: whether to align corners for feature map interpolation.
"""
super().__init__()
self.extract_levels = extract_levels
Expand All @@ -228,6 +232,8 @@ def __init__(
for d in extract_levels
]
)
self.mode = mode
self.align_corners = align_corners

def forward(self, x: List[torch.Tensor], image_size: List[int]) -> torch.Tensor:
"""
Expand All @@ -240,7 +246,9 @@ def forward(self, x: List[torch.Tensor], image_size: List[int]) -> torch.Tensor:
Tensor of shape (batch, `out_channels`, size1, size2, size3), where (size1, size2, size3) = ``image_size``
"""
feature_list = [
F.interpolate(layer(x[self.max_level - level]), size=image_size)
F.interpolate(
layer(x[self.max_level - level]), size=image_size, mode=self.mode, align_corners=self.align_corners
)
for layer, level in zip(self.layers, self.extract_levels)
]
out: torch.Tensor = torch.mean(torch.stack(feature_list, dim=0), dim=0)
Expand Down
31 changes: 25 additions & 6 deletions monai/networks/nets/regunet.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,14 +337,23 @@ def build_output_block(self):


class AdditiveUpSampleBlock(nn.Module):
def __init__(self, spatial_dims: int, in_channels: int, out_channels: int):
def __init__(
self,
spatial_dims: int,
in_channels: int,
out_channels: int,
mode: str = "nearest",
align_corners: Optional[bool] = None,
):
super().__init__()
self.deconv = get_deconv_block(spatial_dims=spatial_dims, in_channels=in_channels, out_channels=out_channels)
self.mode = mode
self.align_corners = align_corners

def forward(self, x: torch.Tensor) -> torch.Tensor:
output_size = [size * 2 for size in x.shape[2:]]
deconved = self.deconv(x)
resized = F.interpolate(x, output_size)
resized = F.interpolate(x, output_size, mode=self.mode, align_corners=self.align_corners)
resized = torch.sum(torch.stack(resized.split(split_size=resized.shape[1] // 2, dim=1), dim=-1), dim=-1)
out: torch.Tensor = deconved + resized
return out
Expand Down Expand Up @@ -372,8 +381,10 @@ def __init__(
out_activation: Optional[str] = None,
out_channels: int = 3,
pooling: bool = True,
use_addictive_sampling: bool = True,
use_additive_sampling: bool = True,
concat_skip: bool = False,
mode: str = "nearest",
align_corners: Optional[bool] = None,
):
"""
Args:
Expand All @@ -385,10 +396,14 @@ def __init__(
out_channels: number of channels for the output
extract_levels: list, which levels from net to extract. The maximum level must equal to ``depth``
pooling: for down-sampling, use non-parameterized pooling if true, otherwise use conv3d
use_addictive_sampling: whether use additive up-sampling layer for decoding.
use_additive_sampling: whether use additive up-sampling layer for decoding.
concat_skip: when up-sampling, concatenate skipped tensor if true, otherwise use addition
mode: mode for interpolation when use_additive_sampling, default is "nearest".
align_corners: align_corners for interpolation when use_additive_sampling, default is None.
"""
self.use_additive_upsampling = use_addictive_sampling
self.use_additive_upsampling = use_additive_sampling
self.mode = mode
self.align_corners = align_corners
super().__init__(
spatial_dims=spatial_dims,
in_channels=in_channels,
Expand All @@ -412,7 +427,11 @@ def build_bottom_block(self, in_channels: int, out_channels: int):
def build_up_sampling_block(self, in_channels: int, out_channels: int) -> nn.Module:
if self.use_additive_upsampling:
return AdditiveUpSampleBlock(
spatial_dims=self.spatial_dims, in_channels=in_channels, out_channels=out_channels
spatial_dims=self.spatial_dims,
in_channels=in_channels,
out_channels=out_channels,
mode=self.mode,
align_corners=self.align_corners,
)

return get_deconv_block(spatial_dims=self.spatial_dims, in_channels=in_channels, out_channels=out_channels)
2 changes: 2 additions & 0 deletions tests/test_localnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@
"extract_levels": (0, 1),
"pooling": False,
"concat_skip": True,
"mode": "bilinear",
"align_corners": True,
},
(1, 2, 16, 16),
(1, 2, 16, 16),
Expand Down
12 changes: 11 additions & 1 deletion tests/test_localnet_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,17 @@
[{"spatial_dims": spatial_dims, "in_channels": 2, "out_channels": 4, "kernel_size": 3}] for spatial_dims in [2, 3]
]

TEST_CASE_UP_SAMPLE = [[{"spatial_dims": spatial_dims, "in_channels": 4, "out_channels": 2}] for spatial_dims in [2, 3]]
TEST_CASE_UP_SAMPLE = [
[
{
"spatial_dims": spatial_dims,
"in_channels": 4,
"out_channels": 2,
"mode": "bilinear" if spatial_dims == 2 else "trilinear",
}
]
for spatial_dims in [2, 3]
]

TEST_CASE_EXTRACT = [
[{"spatial_dims": spatial_dims, "in_channels": 2, "out_channels": 3, "act": act, "initializer": initializer}]
Expand Down
1 change: 1 addition & 0 deletions tests/test_regunet_block.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
"out_channels": 1,
"kernel_initializer": "zeros",
"activation": "sigmoid",
"mode": "trilinear",
},
[(1, 3, 2, 2, 2), (1, 2, 4, 4, 4), (1, 1, 8, 8, 8)],
(3, 3, 3),
Expand Down