Skip to content

Commit f201883

Browse files
authored
4922 adding a minimal lazy transform interface (#5407)
follow-up of #4922 ### Description - minimal interface to track the pending transforms via metatensor - transforms.Flip is modified as an example for discussion - discussion points: - maintaining `pending_operations` and `applied_operations` independently? - the data structure for `pending_operations` element is a python dictionary - transform "functional" refactoring ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Wenqi Li <wenqil@nvidia.com>
1 parent 350fe6e commit f201883

File tree

5 files changed

+54
-2
lines changed

5 files changed

+54
-2
lines changed

monai/data/meta_obj.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -82,6 +82,7 @@ class MetaObj:
8282
def __init__(self):
8383
self._meta: dict = MetaObj.get_default_meta()
8484
self._applied_operations: list = MetaObj.get_default_applied_operations()
85+
self._pending_operations: list = MetaObj.get_default_applied_operations() # the same default as applied_ops
8586
self._is_batch: bool = False
8687

8788
@staticmethod
@@ -199,6 +200,19 @@ def push_applied_operation(self, t: Any) -> None:
199200
def pop_applied_operation(self) -> Any:
200201
return self._applied_operations.pop()
201202

203+
@property
204+
def pending_operations(self) -> list[dict]:
205+
"""Get the pending operations. Defaults to ``[]``."""
206+
if hasattr(self, "_pending_operations"):
207+
return self._pending_operations
208+
return MetaObj.get_default_applied_operations() # the same default as applied_ops
209+
210+
def push_pending_operation(self, t: Any) -> None:
211+
self._pending_operations.append(t)
212+
213+
def pop_pending_operation(self) -> Any:
214+
return self._pending_operations.pop()
215+
202216
@property
203217
def is_batch(self) -> bool:
204218
"""Return whether object is part of batch or not."""

monai/data/meta_tensor.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,8 @@
2323
from monai.data.meta_obj import MetaObj, get_track_meta
2424
from monai.data.utils import affine_to_spacing, decollate_batch, list_data_collate, remove_extra_metadata
2525
from monai.utils import look_up_option
26-
from monai.utils.enums import MetaKeys, PostFix, SpaceKeys
27-
from monai.utils.type_conversion import convert_data_type, convert_to_tensor
26+
from monai.utils.enums import LazyAttr, MetaKeys, PostFix, SpaceKeys
27+
from monai.utils.type_conversion import convert_data_type, convert_to_numpy, convert_to_tensor
2828

2929
__all__ = ["MetaTensor"]
3030

@@ -445,6 +445,20 @@ def pixdim(self):
445445
return [affine_to_spacing(a) for a in self.affine]
446446
return affine_to_spacing(self.affine)
447447

448+
def peek_pending_shape(self):
449+
"""Get the currently expected spatial shape as if all the pending operations are executed."""
450+
res = None
451+
if self.pending_operations:
452+
res = self.pending_operations[-1].get(LazyAttr.SHAPE, None)
453+
# default to spatial shape (assuming channel-first input)
454+
return tuple(convert_to_numpy(self.shape, wrap_sequence=True).tolist()[1:]) if res is None else res
455+
456+
def peek_pending_affine(self):
457+
res = None
458+
if self.pending_operations:
459+
res = self.pending_operations[-1].get(LazyAttr.AFFINE, None)
460+
return self.affine if res is None else res
461+
448462
def new_empty(self, size, dtype=None, device=None, requires_grad=False):
449463
"""
450464
must be defined for deepcopy to work

monai/utils/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
InterpolateMode,
3535
InverseKeys,
3636
JITMetadataKeys,
37+
LazyAttr,
3738
LossReduction,
3839
MetaKeys,
3940
Method,

monai/utils/enums.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,7 @@
5454
"AlgoEnsembleKeys",
5555
"HoVerNetMode",
5656
"HoVerNetBranch",
57+
"LazyAttr",
5758
]
5859

5960

@@ -616,3 +617,16 @@ class HoVerNetBranch(StrEnum):
616617
HV = "horizontal_vertical"
617618
NP = "nucleus_prediction"
618619
NC = "type_prediction"
620+
621+
622+
class LazyAttr(StrEnum):
623+
"""
624+
MetaTensor with pending operations requires some key attributes tracked especially when the primary array
625+
is not up-to-date due to lazy evaluation.
626+
This class specifies the set of key attributes to be tracked for each MetaTensor.
627+
"""
628+
629+
SHAPE = "lazy_shape" # spatial shape
630+
AFFINE = "lazy_affine"
631+
PADDING_MODE = "lazy_padding_mode"
632+
INTERP_MODE = "lazy_interpolation_mode"

tests/test_meta_tensor.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,15 @@ def test_construct_with_pre_applied_transforms(self):
495495
m = MetaTensor(im, applied_operations=data["im"].applied_operations)
496496
self.assertEqual(len(m.applied_operations), len(tr.transforms))
497497

498+
def test_pending_ops(self):
499+
m, _ = self.get_im()
500+
self.assertEqual(m.pending_operations, [])
501+
self.assertEqual(m.peek_pending_shape(), (10, 8))
502+
self.assertIsInstance(m.peek_pending_affine(), torch.Tensor)
503+
m.push_pending_operation({})
504+
self.assertEqual(m.peek_pending_shape(), (10, 8))
505+
self.assertIsInstance(m.peek_pending_affine(), torch.Tensor)
506+
498507
@parameterized.expand(TESTS)
499508
def test_multiprocessing(self, device=None, dtype=None):
500509
"""multiprocessing sharing with 'device' and 'dtype'"""

0 commit comments

Comments
 (0)