42 changes: 37 additions & 5 deletions monai/networks/blocks/warp.py
@@ -110,18 +110,29 @@ def get_reference_grid(self, ddf: torch.Tensor, jitter: bool = False, seed: int
self.ref_grid.requires_grad = False
return self.ref_grid

def forward(self, image: torch.Tensor, ddf: torch.Tensor):
def forward(
self, image: torch.Tensor, ddf: torch.Tensor, keypoints: torch.Tensor | None = None
) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]:
"""
Args:
image: Tensor in shape (batch, num_channels, H, W[, D])
ddf: Tensor in the same spatial size as image, in shape (batch, ``spatial_dims``, H, W[, D])
keypoints: Tensor in shape (batch, N, ``spatial_dims``), optional

Returns:
warped_image in the same shape as image (batch, num_channels, H, W[, D])
warped_keypoints in the same shape as keypoints (batch, N, ``spatial_dims``), if keypoints is not None
"""
batch_size = image.shape[0]
spatial_dims = len(image.shape) - 2
if spatial_dims not in (2, 3):
raise NotImplementedError(f"got unsupported spatial_dims={spatial_dims}, currently support 2 or 3.")
if keypoints is not None:
if keypoints.shape[-1] != spatial_dims:
raise ValueError(
f"Given input {spatial_dims}-d image, the last dimension of the input keypoints must be {spatial_dims}, "
f"got {keypoints.shape}."
)
ddf_shape = (batch_size, spatial_dims) + tuple(image.shape[2:])
if ddf.shape != ddf_shape:
raise ValueError(
@@ -136,12 +147,33 @@ def forward(self, image: torch.Tensor, ddf: torch.Tensor):
grid[..., i] = grid[..., i] * 2 / (dim - 1) - 1
index_ordering: list[int] = list(range(spatial_dims - 1, -1, -1))
grid = grid[..., index_ordering] # z, y, x -> x, y, z
return F.grid_sample(
warped_image = F.grid_sample(
image, grid, mode=self._interp_mode, padding_mode=f"{self._padding_mode}", align_corners=True
)

# using csrc resampling
return grid_pull(image, grid, bound=self._padding_mode, extrapolate=True, interpolation=self._interp_mode)
else:
# using csrc resampling
warped_image = grid_pull(
image, grid, bound=self._padding_mode, extrapolate=True, interpolation=self._interp_mode
)
if keypoints is not None:
with torch.no_grad():
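# normalize keypoint voxel coordinates to [-1, 1] and flip the axis order to grid_sample's x-first convention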
offset = (torch.as_tensor(image.shape[2:]).to(keypoints) - 1) / 2.0  # (dim - 1) / 2, matching the align_corners=True normalization used for the image grid above
offset = offset.unsqueeze(0).unsqueeze(0)
normalized_keypoints = torch.flip((keypoints - offset) / offset, (-1,))
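# sample the dense displacement field at the keypoint locations to get one displacement vector per keypoint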
ddf_keypoints = (
F.grid_sample(
ddf,
normalized_keypoints.view(batch_size, -1, *(1,) * (spatial_dims - 1), spatial_dims),
mode=self._interp_mode,
padding_mode=f"{self._padding_mode}",
align_corners=True,
)
.view(batch_size, spatial_dims, -1)
.permute((0, 2, 1))
)
warped_keypoints = keypoints + ddf_keypoints
return warped_image, warped_keypoints
return warped_image


class DVF2DDF(nn.Module):
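Taken together, the keypoint branch above normalizes voxel coordinates into grid_sample's [-1, 1] range, samples the DDF at those locations, and adds the sampled displacement to each keypoint. A minimal usage sketch, assuming this diff is applied (the shapes, values, and zero-DDF sanity check are illustrative, not part of the PR):

```python
import torch
from monai.networks.blocks.warp import Warp

warp = Warp(mode="bilinear", padding_mode="border")
image = torch.randn(1, 1, 32, 32, 32)             # (batch, num_channels, H, W, D)
ddf = torch.zeros(1, 3, 32, 32, 32)               # zero displacement field: identity transform
keypoints = torch.tensor([[[16.0, 16.0, 16.0]]])  # (batch, N, spatial_dims), voxel coordinates

# with keypoints supplied, forward returns a (warped_image, warped_keypoints) pair
warped_image, warped_keypoints = warp(image, ddf, keypoints)

# a zero DDF samples zero displacements at every keypoint, so they come back unchanged
assert torch.allclose(warped_keypoints, keypoints)
```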
41 changes: 39 additions & 2 deletions monai/networks/nets/voxelmorph.py
@@ -405,6 +405,12 @@ class serves as a wrapper that concatenates the input pair of moving and fixed images
fixed = torch.randn(1, 1, 160, 192, 224)
warped, ddf = net(moving, fixed)

# Example with optional moving_seg and fixed_keypoints
moving_seg = torch.randint(0, 4, (1, 1, 160, 192, 224)).float()
moving_seg = one_hot(moving_seg, num_classes=4)
fixed_keypoints = torch.tensor([[[80, 96, 112], [40, 48, 56]]]).float()
warped_img, warped_seg, warped_keypoints, ddf = net(moving, fixed, moving_seg=moving_seg, fixed_keypoints=fixed_keypoints)

"""

def __init__(
@@ -440,13 +446,37 @@ def __init__(
self.dvf2ddf = DVF2DDF(num_steps=self.integration_steps, mode="bilinear", padding_mode="zeros")
self.warp = Warp(mode="bilinear", padding_mode="zeros")

def forward(self, moving: torch.Tensor, fixed: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
def forward(
self,
moving: torch.Tensor,
fixed: torch.Tensor,
moving_seg: torch.Tensor | None = None,
fixed_keypoints: torch.Tensor | None = None,
) -> tuple[torch.Tensor, ...]:
if moving.shape != fixed.shape:
raise ValueError(
"The spatial shape of the moving image should be the same as the spatial shape of the fixed image."
f" Got {moving.shape} and {fixed.shape} instead."
)

if moving_seg is not None:
if moving_seg.shape[0] != moving.shape[0]:
raise ValueError(
f"Batch dimension mismatch: moving_seg={moving_seg.shape[0]}, moving={moving.shape[0]}"
)
if moving_seg.shape[2:] != moving.shape[2:]:
raise ValueError(
"The spatial shape of the moving segmentation must match the spatial shape of the moving image. "
f"Got {moving_seg.shape[2:]} vs {moving.shape[2:]}."
)

if fixed_keypoints is not None:
if fixed_keypoints.shape[-1] != self.spatial_dims:
raise ValueError(
"The last dimension of the fixed keypoints should be equal to the number of spatial dimensions."
f" Got {fixed_keypoints.shape[-1]} and {self.spatial_dims} instead."
)

x = self.backbone(torch.cat([moving, fixed], dim=1))

if x.shape[1] != self.spatial_dims:
@@ -470,7 +500,14 @@ def forward(self, moving: torch.Tensor, fixed: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
if self.half_res:
x = F.interpolate(x * 0.5, scale_factor=2.0, mode="trilinear", align_corners=True)

return self.warp(moving, x), x
if moving_seg is None and fixed_keypoints is None:
return self.warp(moving, x), x
elif moving_seg is None and fixed_keypoints is not None:
return *self.warp(moving, x, fixed_keypoints), x
elif moving_seg is not None and fixed_keypoints is None:
return self.warp(moving, x), self.warp(moving_seg, x), x
else:
return self.warp(moving, x), *self.warp(moving_seg, x, fixed_keypoints), x


voxelmorph = VoxelMorph
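The four return arities of the new `forward` can be sketched as below; constructing the wrapper with `spatial_dims=3` follows the convention of the tests in this PR, and all tensor sizes are illustrative assumptions:

```python
import torch
from monai.networks.nets import VoxelMorph

net = VoxelMorph(spatial_dims=3)
moving = torch.randn(1, 1, 96, 96, 48)
fixed = torch.randn(1, 1, 96, 96, 48)
moving_seg = torch.randn(1, 2, 96, 96, 48)
fixed_keypoints = torch.rand(1, 5, 3) * 48  # voxel coordinates inside the volume

warped, ddf = net(moving, fixed)                                              # neither optional input
warped, warped_seg, ddf = net(moving, fixed, moving_seg=moving_seg)           # segmentation only
warped, warped_kp, ddf = net(moving, fixed, fixed_keypoints=fixed_keypoints)  # keypoints only
warped, warped_seg, warped_kp, ddf = net(
    moving, fixed, moving_seg=moving_seg, fixed_keypoints=fixed_keypoints
)                                                                             # both
```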
54 changes: 54 additions & 0 deletions tests/networks/nets/test_voxelmorph.py
@@ -171,6 +171,18 @@
TEST_CASE_9,
]

TEST_CASE_SEG_0 = [
{"spatial_dims": 3},
(1, 1, 96, 96, 48), # moving image
(1, 1, 96, 96, 48), # fixed image
(1, 2, 96, 96, 48), # moving label
(1, 1, 96, 96, 48), # expected warped moving image
(1, 2, 96, 96, 48), # expected warped moving label
(1, 3, 96, 96, 48), # expected ddf
]

CASES_SEG = [TEST_CASE_SEG_0]

ILL_CASE_0 = [ # spatial_dims = 1
{
"spatial_dims": 1,
@@ -243,6 +255,15 @@

ILL_CASES_IN_SHAPE = [ILL_CASES_IN_SHAPE_0, ILL_CASES_IN_SHAPE_1]

ILL_CASE_SEG_SHAPE_0 = [  # moving_seg and moving image spatial shapes do not match
{"spatial_dims": 3},
(1, 1, 96, 96, 48),
(1, 1, 96, 96, 48),
(1, 2, 80, 96, 48),
]

ILL_CASES_SEG_SHAPE = [ILL_CASE_SEG_SHAPE_0]


class TestVOXELMORPH(unittest.TestCase):
@parameterized.expand(CASES)
@@ -252,6 +273,28 @@ def test_shape(self, input_param, input_shape, expected_shape):
result = net.forward(torch.randn(input_shape).to(device))
self.assertEqual(result.shape, expected_shape)

@parameterized.expand(CASES_SEG)
def test_shape_seg(
self,
input_param,
moving_shape,
fixed_shape,
moving_seg_shape,
expected_warped_moving_shape,
expected_warped_moving_seg_shape,
expected_ddf_shape,
):
net = VoxelMorph(**input_param).to(device)
with eval_mode(net):
warped_moving, warped_moving_seg, ddf = net.forward(
torch.randn(moving_shape).to(device),
torch.randn(fixed_shape).to(device),
torch.randn(moving_seg_shape).to(device),
)
self.assertEqual(warped_moving.shape, expected_warped_moving_shape)
self.assertEqual(warped_moving_seg.shape, expected_warped_moving_seg_shape)
self.assertEqual(ddf.shape, expected_ddf_shape)

def test_script(self):
net = VoxelMorphUNet(
spatial_dims=2,
@@ -275,6 +318,17 @@ def test_ill_input_shape(self, input_param, moving_shape, fixed_shape):
with eval_mode(net):
_ = net.forward(torch.randn(moving_shape).to(device), torch.randn(fixed_shape).to(device))

@parameterized.expand(ILL_CASES_SEG_SHAPE)
def test_ill_input_seg_shape(self, input_param, moving_shape, fixed_shape, moving_seg_shape):
with self.assertRaises((ValueError, RuntimeError)):
net = VoxelMorph(**input_param).to(device)
with eval_mode(net):
_ = net.forward(
torch.randn(moving_shape).to(device),
torch.randn(fixed_shape).to(device),
torch.randn(moving_seg_shape).to(device),
)


if __name__ == "__main__":
unittest.main()
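The new tests exercise only the segmentation path. A hypothetical companion test for the keypoint path, written in the same style and added to `TestVOXELMORPH` (not part of this diff), could look like:

```python
def test_shape_keypoints(self):
    net = VoxelMorph(spatial_dims=3).to(device)
    fixed_keypoints = (torch.rand(1, 5, 3) * 48).to(device)  # (batch, N, spatial_dims)
    with eval_mode(net):
        warped, warped_kp, ddf = net.forward(
            torch.randn(1, 1, 96, 96, 48).to(device),
            torch.randn(1, 1, 96, 96, 48).to(device),
            fixed_keypoints=fixed_keypoints,
        )
    self.assertEqual(warped_kp.shape, (1, 5, 3))
    self.assertEqual(ddf.shape, (1, 3, 96, 96, 48))
```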