-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Replacement Apply and Resample #5436
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
21 commits
Select commit
Hold shift + click to select a range
d557549
Commit of functionality on replacement lr_apply branch due to issues …
atbenmurray 49e7ed6
adding functional tests for flip, resize and spacing
atbenmurray 61dbe11
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 9c11d4d
Adding additional spatial functional unit tests
atbenmurray 77d7c31
Work on apply and matmul tests
atbenmurray 828a7d8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] 23a05fa
Adding tests for matmul and matrix_matrix
atbenmurray a67ef17
Removing spatial functional transforms from this PR
atbenmurray b25a6cb
Removing croppad functional transforms from this PR
atbenmurray 42f33a2
Removing spatial functional tests as they are no longer part of this PR
atbenmurray 1177263
[MONAI] code formatting
monai-bot 1492f40
minor updates
wyli 2d4122b
apply/matmul/resample MVP
wyli 6c4db35
rearrange modules, adds back matrix/grid types
wyli 9eb5279
matmul -> combine_transforms
wyli 17574cb
remove unused
wyli e067b6e
adds docstrings
wyli c5aa26c
adds testing
wyli 3c2d1ec
DDF -> DisplacementField
wyli 3f7e679
Merge remote-tracking branch 'upstream/dev' into lr_apply_2
wyli 1a124ba
Merge branch 'dev' into lr_apply_2
wyli File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
# Copyright (c) MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,62 @@ | ||
# Copyright (c) MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Optional, Union | ||
|
||
import torch | ||
|
||
from monai.data.meta_tensor import MetaTensor | ||
from monai.data.utils import to_affine_nd | ||
from monai.transforms.lazy.utils import ( | ||
affine_from_pending, | ||
combine_transforms, | ||
is_compatible_apply_kwargs, | ||
kwargs_from_pending, | ||
resample, | ||
) | ||
|
||
__all__ = ["apply"] | ||
|
||
|
||
def apply(data: Union[torch.Tensor, MetaTensor], pending: Optional[list] = None): | ||
""" | ||
This method applies pending transforms to `data` tensors. | ||
|
||
Args: | ||
data: A torch Tensor or a monai MetaTensor. | ||
pending: pending transforms. This must be set if data is a Tensor, but is optional if data is a MetaTensor. | ||
""" | ||
if isinstance(data, MetaTensor) and pending is None: | ||
pending = data.pending_operations | ||
pending = [] if pending is None else pending | ||
|
||
if not pending: | ||
return data | ||
|
||
cumulative_xform = affine_from_pending(pending[0]) | ||
cur_kwargs = kwargs_from_pending(pending[0]) | ||
|
||
for p in pending[1:]: | ||
new_kwargs = kwargs_from_pending(p) | ||
if not is_compatible_apply_kwargs(cur_kwargs, new_kwargs): | ||
# carry out an intermediate resample here due to incompatibility between arguments | ||
data = resample(data, cumulative_xform, cur_kwargs) | ||
next_matrix = affine_from_pending(p) | ||
cumulative_xform = combine_transforms(cumulative_xform, next_matrix) | ||
cur_kwargs.update(new_kwargs) | ||
data = resample(data, cumulative_xform, cur_kwargs) | ||
if isinstance(data, MetaTensor): | ||
data.clear_pending_operations() | ||
data.affine = data.affine @ to_affine_nd(3, cumulative_xform) | ||
for p in pending: | ||
data.push_applied_operation(p) | ||
|
||
return data, pending |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,125 @@ | ||
# Copyright (c) MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
from typing import Optional | ||
|
||
import numpy as np | ||
import torch | ||
|
||
import monai | ||
from monai.config import NdarrayOrTensor | ||
from monai.utils import LazyAttr, convert_to_tensor | ||
|
||
__all__ = ["resample", "combine_transforms"] | ||
|
||
|
||
class Affine: | ||
"""A class to represent an affine transform matrix.""" | ||
|
||
__slots__ = ("data",) | ||
|
||
def __init__(self, data): | ||
self.data = data | ||
|
||
@staticmethod | ||
def is_affine_shaped(data): | ||
"""Check if the data is an affine matrix.""" | ||
if isinstance(data, Affine): | ||
return True | ||
if isinstance(data, DisplacementField): | ||
return False | ||
if not hasattr(data, "shape") or len(data.shape) < 2: | ||
return False | ||
return data.shape[-1] in (3, 4) and data.shape[-2] in (3, 4) and data.shape[-1] == data.shape[-2] | ||
|
||
|
||
class DisplacementField: | ||
"""A class to represent a dense displacement field.""" | ||
|
||
__slots__ = ("data",) | ||
|
||
def __init__(self, data): | ||
self.data = data | ||
|
||
@staticmethod | ||
def is_ddf_shaped(data): | ||
"""Check if the data is a DDF.""" | ||
if isinstance(data, DisplacementField): | ||
return True | ||
if isinstance(data, Affine): | ||
return False | ||
if not hasattr(data, "shape") or len(data.shape) < 3: | ||
return False | ||
return not Affine.is_affine_shaped(data) | ||
|
||
|
||
def combine_transforms(left: torch.Tensor, right: torch.Tensor) -> torch.Tensor: | ||
"""Given transforms A and B to be applied to x, return the combined transform (AB), so that A(B(x)) becomes AB(x)""" | ||
if Affine.is_affine_shaped(left) and Affine.is_affine_shaped(right): # linear transforms | ||
left = convert_to_tensor(left.data if isinstance(left, Affine) else left, wrap_sequence=True) | ||
right = convert_to_tensor(right.data if isinstance(right, Affine) else right, wrap_sequence=True) | ||
return torch.matmul(left, right) | ||
if DisplacementField.is_ddf_shaped(left) and DisplacementField.is_ddf_shaped( | ||
right | ||
): # adds DDFs, do we need metadata if metatensor input? | ||
left = convert_to_tensor(left.data if isinstance(left, DisplacementField) else left, wrap_sequence=True) | ||
right = convert_to_tensor(right.data if isinstance(right, DisplacementField) else right, wrap_sequence=True) | ||
return left + right | ||
raise NotImplementedError | ||
|
||
|
||
def affine_from_pending(pending_item): | ||
"""Extract the affine matrix from a pending transform item.""" | ||
if isinstance(pending_item, (torch.Tensor, np.ndarray)): | ||
return pending_item | ||
if isinstance(pending_item, dict): | ||
return pending_item[LazyAttr.AFFINE] | ||
return pending_item | ||
|
||
|
||
def kwargs_from_pending(pending_item): | ||
"""Extract kwargs from a pending transform item.""" | ||
if not isinstance(pending_item, dict): | ||
return {} | ||
ret = { | ||
LazyAttr.INTERP_MODE: pending_item.get(LazyAttr.INTERP_MODE, None), # interpolation mode | ||
LazyAttr.PADDING_MODE: pending_item.get(LazyAttr.PADDING_MODE, None), # padding mode | ||
} | ||
if LazyAttr.SHAPE in pending_item: | ||
ret[LazyAttr.SHAPE] = pending_item[LazyAttr.SHAPE] | ||
if LazyAttr.DTYPE in pending_item: | ||
ret[LazyAttr.DTYPE] = pending_item[LazyAttr.DTYPE] | ||
return ret | ||
|
||
|
||
def is_compatible_apply_kwargs(kwargs_1, kwargs_2): | ||
"""Check if two sets of kwargs are compatible (to be combined in `apply`).""" | ||
return True | ||
|
||
|
||
def resample(data: torch.Tensor, matrix: NdarrayOrTensor, kwargs: Optional[dict] = None): | ||
""" | ||
This is a minimal implementation of resample that always uses Affine. | ||
""" | ||
if not Affine.is_affine_shaped(matrix): | ||
raise NotImplementedError("calling dense grid resample API not implemented") | ||
kwargs = {} if kwargs is None else kwargs | ||
init_kwargs = { | ||
"spatial_size": kwargs.pop(LazyAttr.SHAPE, data.shape)[1:], | ||
"dtype": kwargs.pop(LazyAttr.DTYPE, data.dtype), | ||
} | ||
call_kwargs = { | ||
"mode": kwargs.pop(LazyAttr.INTERP_MODE, None), | ||
"padding_mode": kwargs.pop(LazyAttr.PADDING_MODE, None), | ||
} | ||
resampler = monai.transforms.Affine(affine=matrix, image_only=True, **init_kwargs) | ||
with resampler.trace_transform(False): # don't track this transform in `data` | ||
return resampler(img=data, **call_kwargs) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
# Copyright (c) MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import unittest | ||
|
||
import numpy as np | ||
import torch | ||
|
||
from monai.transforms.lazy.functional import apply | ||
from monai.transforms.utils import create_rotate | ||
from monai.utils import LazyAttr, convert_to_tensor | ||
from tests.utils import get_arange_img | ||
|
||
|
||
def single_2d_transform_cases(): | ||
return [ | ||
( | ||
torch.as_tensor(get_arange_img((32, 32))), | ||
[{LazyAttr.AFFINE: create_rotate(2, np.pi / 4)}, {LazyAttr.AFFINE: create_rotate(2, -np.pi / 4)}], | ||
(1, 32, 32), | ||
), | ||
(torch.as_tensor(get_arange_img((32, 32))), [create_rotate(2, np.pi / 2)], (1, 32, 32)), | ||
( | ||
torch.as_tensor(get_arange_img((16, 16))), | ||
[{LazyAttr.AFFINE: create_rotate(2, np.pi / 2), LazyAttr.SHAPE: (1, 45, 45)}], | ||
(1, 45, 45), | ||
), | ||
] | ||
|
||
|
||
class TestApply(unittest.TestCase): | ||
def _test_apply_impl(self, tensor, pending_transforms, expected_shape): | ||
result = apply(tensor, pending_transforms) | ||
self.assertListEqual(result[1], pending_transforms) | ||
self.assertEqual(result[0].shape, expected_shape) | ||
|
||
def _test_apply_metatensor_impl(self, tensor, pending_transforms, expected_shape, pending_as_parameter): | ||
tensor_ = convert_to_tensor(tensor, track_meta=True) | ||
if pending_as_parameter: | ||
result, transforms = apply(tensor_, pending_transforms) | ||
else: | ||
for p in pending_transforms: | ||
tensor_.push_pending_operation(p) | ||
result, transforms = apply(tensor_) | ||
self.assertEqual(result.shape, expected_shape) | ||
|
||
SINGLE_TRANSFORM_CASES = single_2d_transform_cases() | ||
|
||
def test_apply_single_transform(self): | ||
for case in self.SINGLE_TRANSFORM_CASES: | ||
self._test_apply_impl(*case) | ||
|
||
def test_apply_single_transform_metatensor(self): | ||
for case in self.SINGLE_TRANSFORM_CASES: | ||
self._test_apply_metatensor_impl(*case, False) | ||
|
||
def test_apply_single_transform_metatensor_override(self): | ||
for case in self.SINGLE_TRANSFORM_CASES: | ||
self._test_apply_metatensor_impl(*case, True) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# Copyright (c) MONAI Consortium | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import unittest | ||
|
||
import torch | ||
from parameterized import parameterized | ||
|
||
from monai.transforms.lazy.functional import resample | ||
from monai.utils import convert_to_tensor | ||
from tests.utils import assert_allclose, get_arange_img | ||
|
||
|
||
def rotate_90_2d(): | ||
t = torch.eye(3) | ||
t[:, 0] = torch.FloatTensor([0, -1, 0]) | ||
t[:, 1] = torch.FloatTensor([1, 0, 0]) | ||
return t | ||
|
||
|
||
RESAMPLE_FUNCTION_CASES = [(get_arange_img((3, 3)), rotate_90_2d(), [[2, 5, 8], [1, 4, 7], [0, 3, 6]])] | ||
|
||
|
||
class TestResampleFunction(unittest.TestCase): | ||
@parameterized.expand(RESAMPLE_FUNCTION_CASES) | ||
def test_resample_function_impl(self, img, matrix, expected): | ||
out = resample(convert_to_tensor(img), matrix) | ||
assert_allclose(out[0], expected, type_test=False) | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.