Skip to content

Commit 9fd6d4c

Browse files
authored
improve SpacingD output shape compute stability (#6126)
For slightly different affines (differences within atol < 1e-3), the computed output shape might be different, because the extents fall on opposite sides of a rounding boundary (here they round to 356 and 355 respectively):

```
sum([1.49999991e+00 1.44128689e-04] * [237. 144.]) = sum([3.55499980e+02 2.07545313e-02]) = 355.52073414989167
sum([ 1.49999989e+00 -6.16167492e-10] * [237. 144.]) = sum([ 3.55499975e+02 -8.87281188e-08]) = 355.4999748785736
```

This commit:

- increases the stability of the affine inverse computation
- checks that the affine is a floating point type
- ensures the same output shape when there is only a tiny difference in pixdim

### Types of changes

<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [x] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [x] In-line docstrings updated.
- [x] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Wenqi Li <wenqil@nvidia.com>
1 parent 93a77b7 commit 9fd6d4c

File tree

5 files changed

+69
-28
lines changed

5 files changed

+69
-28
lines changed

monai/data/utils.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -875,15 +875,14 @@ def compute_shape_offset(
875875
in_coords = [(-0.5, dim - 0.5) if scale_extent else (0.0, dim - 1.0) for dim in shape]
876876
corners: np.ndarray = np.asarray(np.meshgrid(*in_coords, indexing="ij")).reshape((len(shape), -1))
877877
corners = np.concatenate((corners, np.ones_like(corners[:1])))
878-
corners = in_affine_ @ corners
879878
try:
880-
inv_mat = np.linalg.inv(out_affine_)
879+
corners_out = np.linalg.solve(out_affine_, in_affine_) @ corners
881880
except np.linalg.LinAlgError as e:
882881
raise ValueError(f"Affine {out_affine_} is not invertible") from e
883-
corners_out = inv_mat @ corners
882+
corners = in_affine_ @ corners
883+
all_dist = corners_out[:-1].copy()
884884
corners_out = corners_out[:-1] / corners_out[-1]
885885
out_shape = np.round(corners_out.ptp(axis=1)) if scale_extent else np.round(corners_out.ptp(axis=1) + 1.0)
886-
all_dist = inv_mat[:-1, :-1] @ corners[:-1, :]
887886
offset = None
888887
for i in range(corners.shape[1]):
889888
min_corner = np.min(all_dist - all_dist[:, i : i + 1], 1)

monai/networks/layers/spatial_transforms.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -537,6 +537,8 @@ def forward(
537537
theta = torch.cat([theta, pad_affine], dim=1)
538538
if tuple(theta.shape[1:]) not in ((3, 3), (4, 4)):
539539
raise ValueError(f"theta must be Nx3x3 or Nx4x4, got {theta.shape}.")
540+
if not torch.is_floating_point(theta):
541+
raise ValueError(f"theta must be floating point data, got {theta.dtype}")
540542

541543
# validate `src`
542544
if not isinstance(src, torch.Tensor):

monai/transforms/spatial/dictionary.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,7 @@ def __init__(
339339
recompute_affine: bool = False,
340340
min_pixdim: Sequence[float] | float | None = None,
341341
max_pixdim: Sequence[float] | float | None = None,
342+
ensure_same_shape: bool = True,
342343
allow_missing_keys: bool = False,
343344
) -> None:
344345
"""
@@ -396,6 +397,8 @@ def __init__(
396397
max_pixdim: maximal input spacing to be resampled. If provided, input image with a smaller spacing than this
397398
value will be kept in its original spacing (not be resampled to `pixdim`). Set it to `None` to use the
398399
value of `pixdim`. Default to `None`.
400+
ensure_same_shape: when the inputs have the same spatial shape, and almost the same pixdim,
401+
whether to ensure exactly the same output spatial shape. Default to True.
399402
allow_missing_keys: don't raise exception if key is missing.
400403
401404
"""
@@ -408,6 +411,7 @@ def __init__(
408411
self.align_corners = ensure_tuple_rep(align_corners, len(self.keys))
409412
self.dtype = ensure_tuple_rep(dtype, len(self.keys))
410413
self.scale_extent = ensure_tuple_rep(scale_extent, len(self.keys))
414+
self.ensure_same_shape = ensure_same_shape
411415

412416
@LazyTransform.lazy_evaluation.setter # type: ignore
413417
def lazy_evaluation(self, val: bool) -> None:
@@ -416,18 +420,30 @@ def lazy_evaluation(self, val: bool) -> None:
416420

417421
def __call__(self, data: Mapping[Hashable, torch.Tensor]) -> dict[Hashable, torch.Tensor]:
418422
d: dict = dict(data)
423+
424+
_init_shape, _pixdim, should_match = None, None, False
425+
output_shape_k = None # tracking output shape
426+
419427
for key, mode, padding_mode, align_corners, dtype, scale_extent in self.key_iterator(
420428
d, self.mode, self.padding_mode, self.align_corners, self.dtype, self.scale_extent
421429
):
422-
# resample array of each corresponding key
430+
if self.ensure_same_shape and isinstance(d[key], MetaTensor):
431+
if _init_shape is None and _pixdim is None:
432+
_init_shape, _pixdim = d[key].peek_pending_shape(), d[key].pixdim
433+
else:
434+
should_match = np.allclose(_init_shape, d[key].peek_pending_shape()) and np.allclose(
435+
_pixdim, d[key].pixdim, atol=1e-3
436+
)
423437
d[key] = self.spacing_transform(
424438
data_array=d[key],
425439
mode=mode,
426440
padding_mode=padding_mode,
427441
align_corners=align_corners,
428442
dtype=dtype,
429443
scale_extent=scale_extent,
444+
output_spatial_shape=output_shape_k if should_match else None,
430445
)
446+
output_shape_k = d[key].peek_pending_shape() if isinstance(d[key], MetaTensor) else d[key].shape[1:]
431447
return d
432448

433449
def inverse(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:

tests/test_global_mutual_information_loss.py

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -26,28 +26,28 @@
2626

2727
EXPECTED_VALUE = {
2828
"xyz_translation": [
29-
-1.5860259532928467,
30-
-0.5957175493240356,
31-
-0.3855515122413635,
32-
-0.28728482127189636,
33-
-0.23416118323802948,
34-
-0.19534644484519958,
35-
-0.17001715302467346,
36-
-0.15043553709983826,
37-
-0.1366637945175171,
38-
-0.12534910440444946,
29+
-1.5860257,
30+
-0.62433463,
31+
-0.38217825,
32+
-0.2905613,
33+
-0.23233329,
34+
-0.1961407,
35+
-0.16905619,
36+
-0.15100679,
37+
-0.13666219,
38+
-0.12635908,
3939
],
4040
"xyz_rotation": [
41-
-1.5860259532928467,
42-
-0.29977330565452576,
43-
-0.18411292135715485,
44-
-0.1582011878490448,
45-
-0.16107326745986938,
46-
-0.165723517537117,
47-
-0.1970357596874237,
48-
-0.1755618453025818,
49-
-0.17100191116333008,
50-
-0.17264796793460846,
41+
-1.5860257,
42+
-0.30265224,
43+
-0.18666176,
44+
-0.15887907,
45+
-0.1625064,
46+
-0.16603896,
47+
-0.19222091,
48+
-0.18158069,
49+
-0.167644,
50+
-0.16698098,
5151
],
5252
}
5353

@@ -84,7 +84,7 @@ def transformation(translate_params=(0.0, 0.0, 0.0), rotate_params=(0.0, 0.0, 0.
8484
numpy array of shape HWD
8585
"""
8686
transform_list = [
87-
transforms.LoadImaged(keys="img"),
87+
transforms.LoadImaged(keys="img", image_only=True),
8888
transforms.Affined(
8989
keys="img", translate_params=translate_params, rotate_params=rotate_params, device=None
9090
),
@@ -94,7 +94,7 @@ def transformation(translate_params=(0.0, 0.0, 0.0), rotate_params=(0.0, 0.0, 0.
9494
return transformation({"img": FILE_PATH})["img"]
9595

9696
a1 = transformation()
97-
a1 = torch.tensor(a1).unsqueeze(0).unsqueeze(0).to(device)
97+
a1 = a1.clone().unsqueeze(0).unsqueeze(0).to(device)
9898

9999
for mode in transform_params_dict:
100100
transform_params_list = transform_params_dict[mode]
@@ -104,7 +104,7 @@ def transformation(translate_params=(0.0, 0.0, 0.0), rotate_params=(0.0, 0.0, 0.
104104
translate_params=transform_params if "translation" in mode else (0.0, 0.0, 0.0),
105105
rotate_params=transform_params if "rotation" in mode else (0.0, 0.0, 0.0),
106106
)
107-
a2 = torch.tensor(a2).unsqueeze(0).unsqueeze(0).to(device)
107+
a2 = a2.clone().unsqueeze(0).unsqueeze(0).to(device)
108108
result = loss_fn(a2, a1).detach().cpu().numpy()
109109
np.testing.assert_allclose(result, expected_value, rtol=1e-3, atol=5e-3)
110110

tests/test_spacingd.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,30 @@ def test_orntd_torch(self, init_param, img: torch.Tensor, track_meta: bool, devi
134134
self.assertNotIsInstance(res, MetaTensor)
135135
self.assertNotEqual(img.shape, res.shape)
136136

137+
def test_space_same_shape(self):
138+
affine_1 = np.array(
139+
[
140+
[1.499277e00, 2.699563e-02, 3.805804e-02, -1.948635e02],
141+
[-2.685805e-02, 1.499757e00, -2.635604e-12, 4.438188e01],
142+
[-3.805194e-02, -5.999028e-04, 1.499517e00, 4.036536e01],
143+
[0.000000e00, 0.000000e00, 0.000000e00, 1.000000e00],
144+
]
145+
)
146+
affine_2 = np.array(
147+
[
148+
[1.499275e00, 2.692252e-02, 3.805728e-02, -1.948635e02],
149+
[-2.693010e-02, 1.499758e00, -4.260525e-05, 4.438188e01],
150+
[-3.805190e-02, -6.406730e-04, 1.499517e00, 4.036536e01],
151+
[0.000000e00, 0.000000e00, 0.000000e00, 1.000000e00],
152+
]
153+
)
154+
img_1 = MetaTensor(np.zeros((1, 238, 145, 315)), affine=affine_1)
155+
img_2 = MetaTensor(np.zeros((1, 238, 145, 315)), affine=affine_2)
156+
out = Spacingd(("img_1", "img_2"), pixdim=1)({"img_1": img_1, "img_2": img_2})
157+
self.assertEqual(out["img_1"].shape, out["img_2"].shape) # ensure_same_shape True
158+
out = Spacingd(("img_1", "img_2"), pixdim=1, ensure_same_shape=False)({"img_1": img_1, "img_2": img_2})
159+
self.assertNotEqual(out["img_1"].shape, out["img_2"].shape) # ensure_same_shape False
160+
137161

138162
if __name__ == "__main__":
139163
unittest.main()

0 commit comments

Comments
 (0)