Skip to content

Commit 2c7ba9e

Browse files
atbenmurraypre-commit-ci[bot]monai-botwyli
authored
Replacement Apply and Resample (#5436)
Signed-off-by: Ben Murray <ben.murray@gmail.com> ### Description This is part of the work towards #4855. It adds: - a lazy `apply` method - A transform-like wrapper for `apply` called `Apply` ~- `MetaMatrix` and related functionality to represent abstracted grid and matrix transforms with metadata~ - A universal `resample` function that can be used to apply grid / matrix transforms ~- Functional spatial and croppad implementations that define but don't apply transforms~ ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Ben Murray <ben.murray@gmail.com> Signed-off-by: monai-bot <monai.miccai2019@gmail.com> Signed-off-by: Wenqi Li <wenqil@nvidia.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: monai-bot <monai.miccai2019@gmail.com> Co-authored-by: Wenqi Li <wenqil@nvidia.com> Co-authored-by: Wenqi Li <831580+wyli@users.noreply.github.com>
1 parent 94dbc3c commit 2c7ba9e

File tree

9 files changed

+324
-0
lines changed

9 files changed

+324
-0
lines changed

monai/data/meta_obj.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,9 @@ def push_pending_operation(self, t: Any) -> None:
213213
def pop_pending_operation(self) -> Any:
214214
return self._pending_operations.pop()
215215

216+
def clear_pending_operations(self) -> Any:
217+
self._pending_operations = MetaObj.get_default_applied_operations()
218+
216219
@property
217220
def is_batch(self) -> bool:
218221
"""Return whether object is part of batch or not."""

monai/transforms/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,8 @@
227227
from .inverse_batch_transform import BatchInverseTransform, Decollated, DecollateD, DecollateDict
228228
from .io.array import SUPPORTED_READERS, LoadImage, SaveImage
229229
from .io.dictionary import LoadImaged, LoadImageD, LoadImageDict, SaveImaged, SaveImageD, SaveImageDict
230+
from .lazy.functional import apply
231+
from .lazy.utils import combine_transforms, resample
230232
from .meta_utility.dictionary import (
231233
FromMetaTensord,
232234
FromMetaTensorD,

monai/transforms/lazy/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.

monai/transforms/lazy/functional.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from typing import Optional, Union
13+
14+
import torch
15+
16+
from monai.data.meta_tensor import MetaTensor
17+
from monai.data.utils import to_affine_nd
18+
from monai.transforms.lazy.utils import (
19+
affine_from_pending,
20+
combine_transforms,
21+
is_compatible_apply_kwargs,
22+
kwargs_from_pending,
23+
resample,
24+
)
25+
26+
__all__ = ["apply"]
27+
28+
29+
def apply(data: Union[torch.Tensor, MetaTensor], pending: Optional[list] = None):
30+
"""
31+
This method applies pending transforms to `data` tensors.
32+
33+
Args:
34+
data: A torch Tensor or a monai MetaTensor.
35+
pending: pending transforms. This must be set if data is a Tensor, but is optional if data is a MetaTensor.
36+
"""
37+
if isinstance(data, MetaTensor) and pending is None:
38+
pending = data.pending_operations
39+
pending = [] if pending is None else pending
40+
41+
if not pending:
42+
return data
43+
44+
cumulative_xform = affine_from_pending(pending[0])
45+
cur_kwargs = kwargs_from_pending(pending[0])
46+
47+
for p in pending[1:]:
48+
new_kwargs = kwargs_from_pending(p)
49+
if not is_compatible_apply_kwargs(cur_kwargs, new_kwargs):
50+
# carry out an intermediate resample here due to incompatibility between arguments
51+
data = resample(data, cumulative_xform, cur_kwargs)
52+
next_matrix = affine_from_pending(p)
53+
cumulative_xform = combine_transforms(cumulative_xform, next_matrix)
54+
cur_kwargs.update(new_kwargs)
55+
data = resample(data, cumulative_xform, cur_kwargs)
56+
if isinstance(data, MetaTensor):
57+
data.clear_pending_operations()
58+
data.affine = data.affine @ to_affine_nd(3, cumulative_xform)
59+
for p in pending:
60+
data.push_applied_operation(p)
61+
62+
return data, pending

monai/transforms/lazy/utils.py

Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from typing import Optional
13+
14+
import numpy as np
15+
import torch
16+
17+
import monai
18+
from monai.config import NdarrayOrTensor
19+
from monai.utils import LazyAttr, convert_to_tensor
20+
21+
__all__ = ["resample", "combine_transforms"]
22+
23+
24+
class Affine:
25+
"""A class to represent an affine transform matrix."""
26+
27+
__slots__ = ("data",)
28+
29+
def __init__(self, data):
30+
self.data = data
31+
32+
@staticmethod
33+
def is_affine_shaped(data):
34+
"""Check if the data is an affine matrix."""
35+
if isinstance(data, Affine):
36+
return True
37+
if isinstance(data, DisplacementField):
38+
return False
39+
if not hasattr(data, "shape") or len(data.shape) < 2:
40+
return False
41+
return data.shape[-1] in (3, 4) and data.shape[-2] in (3, 4) and data.shape[-1] == data.shape[-2]
42+
43+
44+
class DisplacementField:
45+
"""A class to represent a dense displacement field."""
46+
47+
__slots__ = ("data",)
48+
49+
def __init__(self, data):
50+
self.data = data
51+
52+
@staticmethod
53+
def is_ddf_shaped(data):
54+
"""Check if the data is a DDF."""
55+
if isinstance(data, DisplacementField):
56+
return True
57+
if isinstance(data, Affine):
58+
return False
59+
if not hasattr(data, "shape") or len(data.shape) < 3:
60+
return False
61+
return not Affine.is_affine_shaped(data)
62+
63+
64+
def combine_transforms(left: torch.Tensor, right: torch.Tensor) -> torch.Tensor:
65+
"""Given transforms A and B to be applied to x, return the combined transform (AB), so that A(B(x)) becomes AB(x)"""
66+
if Affine.is_affine_shaped(left) and Affine.is_affine_shaped(right): # linear transforms
67+
left = convert_to_tensor(left.data if isinstance(left, Affine) else left, wrap_sequence=True)
68+
right = convert_to_tensor(right.data if isinstance(right, Affine) else right, wrap_sequence=True)
69+
return torch.matmul(left, right)
70+
if DisplacementField.is_ddf_shaped(left) and DisplacementField.is_ddf_shaped(
71+
right
72+
): # adds DDFs, do we need metadata if metatensor input?
73+
left = convert_to_tensor(left.data if isinstance(left, DisplacementField) else left, wrap_sequence=True)
74+
right = convert_to_tensor(right.data if isinstance(right, DisplacementField) else right, wrap_sequence=True)
75+
return left + right
76+
raise NotImplementedError
77+
78+
79+
def affine_from_pending(pending_item):
80+
"""Extract the affine matrix from a pending transform item."""
81+
if isinstance(pending_item, (torch.Tensor, np.ndarray)):
82+
return pending_item
83+
if isinstance(pending_item, dict):
84+
return pending_item[LazyAttr.AFFINE]
85+
return pending_item
86+
87+
88+
def kwargs_from_pending(pending_item):
89+
"""Extract kwargs from a pending transform item."""
90+
if not isinstance(pending_item, dict):
91+
return {}
92+
ret = {
93+
LazyAttr.INTERP_MODE: pending_item.get(LazyAttr.INTERP_MODE, None), # interpolation mode
94+
LazyAttr.PADDING_MODE: pending_item.get(LazyAttr.PADDING_MODE, None), # padding mode
95+
}
96+
if LazyAttr.SHAPE in pending_item:
97+
ret[LazyAttr.SHAPE] = pending_item[LazyAttr.SHAPE]
98+
if LazyAttr.DTYPE in pending_item:
99+
ret[LazyAttr.DTYPE] = pending_item[LazyAttr.DTYPE]
100+
return ret
101+
102+
103+
def is_compatible_apply_kwargs(kwargs_1, kwargs_2):
104+
"""Check if two sets of kwargs are compatible (to be combined in `apply`)."""
105+
return True
106+
107+
108+
def resample(data: torch.Tensor, matrix: NdarrayOrTensor, kwargs: Optional[dict] = None):
109+
"""
110+
This is a minimal implementation of resample that always uses Affine.
111+
"""
112+
if not Affine.is_affine_shaped(matrix):
113+
raise NotImplementedError("calling dense grid resample API not implemented")
114+
kwargs = {} if kwargs is None else kwargs
115+
init_kwargs = {
116+
"spatial_size": kwargs.pop(LazyAttr.SHAPE, data.shape)[1:],
117+
"dtype": kwargs.pop(LazyAttr.DTYPE, data.dtype),
118+
}
119+
call_kwargs = {
120+
"mode": kwargs.pop(LazyAttr.INTERP_MODE, None),
121+
"padding_mode": kwargs.pop(LazyAttr.PADDING_MODE, None),
122+
}
123+
resampler = monai.transforms.Affine(affine=matrix, image_only=True, **init_kwargs)
124+
with resampler.trace_transform(False): # don't track this transform in `data`
125+
return resampler(img=data, **call_kwargs)

monai/utils/enums.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -630,3 +630,4 @@ class LazyAttr(StrEnum):
630630
AFFINE = "lazy_affine"
631631
PADDING_MODE = "lazy_padding_mode"
632632
INTERP_MODE = "lazy_interpolation_mode"
633+
DTYPE = "lazy_dtype"

tests/test_apply.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import unittest
13+
14+
import numpy as np
15+
import torch
16+
17+
from monai.transforms.lazy.functional import apply
18+
from monai.transforms.utils import create_rotate
19+
from monai.utils import LazyAttr, convert_to_tensor
20+
from tests.utils import get_arange_img
21+
22+
23+
def single_2d_transform_cases():
24+
return [
25+
(
26+
torch.as_tensor(get_arange_img((32, 32))),
27+
[{LazyAttr.AFFINE: create_rotate(2, np.pi / 4)}, {LazyAttr.AFFINE: create_rotate(2, -np.pi / 4)}],
28+
(1, 32, 32),
29+
),
30+
(torch.as_tensor(get_arange_img((32, 32))), [create_rotate(2, np.pi / 2)], (1, 32, 32)),
31+
(
32+
torch.as_tensor(get_arange_img((16, 16))),
33+
[{LazyAttr.AFFINE: create_rotate(2, np.pi / 2), LazyAttr.SHAPE: (1, 45, 45)}],
34+
(1, 45, 45),
35+
),
36+
]
37+
38+
39+
class TestApply(unittest.TestCase):
40+
def _test_apply_impl(self, tensor, pending_transforms, expected_shape):
41+
result = apply(tensor, pending_transforms)
42+
self.assertListEqual(result[1], pending_transforms)
43+
self.assertEqual(result[0].shape, expected_shape)
44+
45+
def _test_apply_metatensor_impl(self, tensor, pending_transforms, expected_shape, pending_as_parameter):
46+
tensor_ = convert_to_tensor(tensor, track_meta=True)
47+
if pending_as_parameter:
48+
result, transforms = apply(tensor_, pending_transforms)
49+
else:
50+
for p in pending_transforms:
51+
tensor_.push_pending_operation(p)
52+
result, transforms = apply(tensor_)
53+
self.assertEqual(result.shape, expected_shape)
54+
55+
SINGLE_TRANSFORM_CASES = single_2d_transform_cases()
56+
57+
def test_apply_single_transform(self):
58+
for case in self.SINGLE_TRANSFORM_CASES:
59+
self._test_apply_impl(*case)
60+
61+
def test_apply_single_transform_metatensor(self):
62+
for case in self.SINGLE_TRANSFORM_CASES:
63+
self._test_apply_metatensor_impl(*case, False)
64+
65+
def test_apply_single_transform_metatensor_override(self):
66+
for case in self.SINGLE_TRANSFORM_CASES:
67+
self._test_apply_metatensor_impl(*case, True)
68+
69+
70+
if __name__ == "__main__":
71+
unittest.main()

tests/test_resample.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import unittest
13+
14+
import torch
15+
from parameterized import parameterized
16+
17+
from monai.transforms.lazy.functional import resample
18+
from monai.utils import convert_to_tensor
19+
from tests.utils import assert_allclose, get_arange_img
20+
21+
22+
def rotate_90_2d():
23+
t = torch.eye(3)
24+
t[:, 0] = torch.FloatTensor([0, -1, 0])
25+
t[:, 1] = torch.FloatTensor([1, 0, 0])
26+
return t
27+
28+
29+
RESAMPLE_FUNCTION_CASES = [(get_arange_img((3, 3)), rotate_90_2d(), [[2, 5, 8], [1, 4, 7], [0, 3, 6]])]
30+
31+
32+
class TestResampleFunction(unittest.TestCase):
33+
@parameterized.expand(RESAMPLE_FUNCTION_CASES)
34+
def test_resample_function_impl(self, img, matrix, expected):
35+
out = resample(convert_to_tensor(img), matrix)
36+
assert_allclose(out[0], expected, type_test=False)
37+
38+
39+
if __name__ == "__main__":
40+
unittest.main()

tests/utils.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,16 @@ def make_rand_affine(ndim: int = 3, random_state: Optional[np.random.RandomState
348348
return af
349349

350350

351+
def get_arange_img(size, dtype=np.float32, offset=0):
352+
"""
353+
Returns an image as a numpy array (complete with channel as dim 0)
354+
with contents that iterate like an arange.
355+
"""
356+
n_elem = np.prod(size)
357+
img = np.arange(offset, offset + n_elem, dtype=dtype).reshape(size)
358+
return np.expand_dims(img, 0)
359+
360+
351361
class DistTestCase(unittest.TestCase):
352362
"""
353363
testcase without _outcome, so that it's picklable.

0 commit comments

Comments
 (0)