Skip to content

Commit 8b1f0c3

Browse files
authored
Type classes for lazy resampling (#5418)
Signed-off-by: Ben Murray <ben.murray@gmail.com> Fixes # . ### Description A few sentences describing the changes proposed in this pull request. ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Ben Murray <ben.murray@gmail.com>
1 parent 3f82016 commit 8b1f0c3

File tree

5 files changed

+149
-4
lines changed

5 files changed

+149
-4
lines changed

docs/source/transforms.rst

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,11 +22,31 @@ Generic Interfaces
2222
:members:
2323
:special-members: __call__
2424

25+
`RandomizableTrait`
26+
^^^^^^^^^^^^^^^^^^^
27+
.. autoclass:: RandomizableTrait
28+
:members:
29+
30+
`LazyTrait`
31+
^^^^^^^^^^^
32+
.. autoclass:: LazyTrait
33+
:members:
34+
35+
`MultiSampleTrait`
36+
^^^^^^^^^^^^^^^^^^
37+
.. autoclass:: MultiSampleTrait
38+
:members:
39+
2540
`Randomizable`
2641
^^^^^^^^^^^^^^
2742
.. autoclass:: Randomizable
2843
:members:
2944

45+
`LazyTransform`
46+
^^^^^^^^^^^^^^^
47+
.. autoclass:: LazyTransform
48+
:members:
49+
3050
`RandomizableTransform`
3151
^^^^^^^^^^^^^^^^^^^^^^^
3252
.. autoclass:: RandomizableTransform

monai/transforms/__init__.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -449,7 +449,18 @@
449449
ZoomD,
450450
ZoomDict,
451451
)
452-
from .transform import MapTransform, Randomizable, RandomizableTransform, ThreadUnsafe, Transform, apply_transform
452+
from .transform import (
453+
LazyTrait,
454+
LazyTransform,
455+
MapTransform,
456+
MultiSampleTrait,
457+
Randomizable,
458+
RandomizableTrait,
459+
RandomizableTransform,
460+
ThreadUnsafe,
461+
Transform,
462+
apply_transform,
463+
)
453464
from .utility.array import (
454465
AddChannel,
455466
AddCoordinateChannels,

monai/transforms/transform.py

Lines changed: 83 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,18 @@
2626
from monai.utils.enums import TransformBackends
2727
from monai.utils.misc import MONAIEnvVars
2828

29-
__all__ = ["ThreadUnsafe", "apply_transform", "Randomizable", "RandomizableTransform", "Transform", "MapTransform"]
29+
__all__ = [
30+
"ThreadUnsafe",
31+
"apply_transform",
32+
"LazyTrait",
33+
"RandomizableTrait",
34+
"MultiSampleTrait",
35+
"Randomizable",
36+
"LazyTransform",
37+
"RandomizableTransform",
38+
"Transform",
39+
"MapTransform",
40+
]
3041

3142
ReturnType = TypeVar("ReturnType")
3243

@@ -118,6 +129,56 @@ def _log_stats(data, prefix: Optional[str] = "Data"):
118129
raise RuntimeError(f"applying transform {transform}") from e
119130

120131

132+
class LazyTrait:
133+
"""
134+
An interface to indicate that the transform has the capability to execute using
135+
MONAI's lazy resampling feature. In order to do this, the implementing class needs
136+
to be able to describe its operation as an affine matrix or grid with accompanying metadata.
137+
This interface can be extended from by people adapting transforms to the MONAI framework as
138+
well as by implementors of MONAI transforms.
139+
"""
140+
141+
@property
142+
def lazy_evaluation(self):
143+
"""
144+
Get whether lazy_evaluation is enabled for this transform instance.
145+
Returns:
146+
True if the transform is operating in a lazy fashion, False if not.
147+
"""
148+
raise NotImplementedError()
149+
150+
@lazy_evaluation.setter
151+
def lazy_evaluation(self, enabled: bool):
152+
"""
153+
Set whether lazy_evaluation is enabled for this transform instance.
154+
Args:
155+
enabled: True if the transform should operate in a lazy fashion, False if not.
156+
"""
157+
raise NotImplementedError()
158+
159+
160+
class RandomizableTrait:
161+
"""
162+
An interface to indicate that the transform has the capability to perform
163+
randomized transforms to the data that it is called upon. This interface
164+
can be extended from by people adapting transforms to the MONAI framework as well as by
165+
implementors of MONAI transforms.
166+
"""
167+
168+
pass
169+
170+
171+
class MultiSampleTrait:
172+
"""
173+
An interface to indicate that the transform has the capability to return multiple samples
174+
given an input, such as when performing random crops of a sample. This interface can be
175+
extended from by people adapting transforms to the MONAI framework as well as by implementors
176+
of MONAI transforms.
177+
"""
178+
179+
pass
180+
181+
121182
class ThreadUnsafe:
122183
"""
123184
A class to denote that the transform will mutate its member variables,
@@ -251,7 +312,27 @@ def __call__(self, data: Any):
251312
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
252313

253314

254-
class RandomizableTransform(Randomizable, Transform):
315+
class LazyTransform(Transform, LazyTrait):
316+
"""
317+
An implementation of functionality for lazy transforms that can be subclassed by array and
318+
dictionary transforms to simplify implementation of new lazy transforms.
319+
"""
320+
321+
def __init__(self, lazy_evaluation: Optional[bool] = True):
322+
self.lazy_evaluation = lazy_evaluation
323+
324+
@property
325+
def lazy_evaluation(self):
326+
return self.lazy_evaluation
327+
328+
@lazy_evaluation.setter
329+
def lazy_evaluation(self, lazy_evaluation: bool):
330+
if not isinstance(lazy_evaluation, bool):
331+
raise TypeError("'lazy_evaluation must be a bool but is of " f"type {type(lazy_evaluation)}'")
332+
self.lazy_evaluation = lazy_evaluation
333+
334+
335+
class RandomizableTransform(Randomizable, Transform, RandomizableTrait):
255336
"""
256337
An interface for handling random state locally, currently based on a class variable `R`,
257338
which is an instance of `np.random.RandomState`.

monai/visualize/gradient_based.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def get_grad(self, x: torch.Tensor, index: torch.Tensor | int | None, retain_gra
9090
x.requires_grad = True
9191

9292
self._model(x, class_idx=index, retain_graph=retain_graph, **kwargs)
93-
grad: torch.Tensor = x.grad.detach()
93+
grad: torch.Tensor = x.grad.detach() # type: ignore
9494
return grad
9595

9696
def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None, **kwargs) -> torch.Tensor:
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import unittest
13+
14+
from monai.transforms.transform import RandomizableTrait, RandomizableTransform
15+
16+
17+
class InheritsInterface(RandomizableTrait):
18+
pass
19+
20+
21+
class InheritsImplementation(RandomizableTransform):
22+
def __call__(self, data):
23+
return data
24+
25+
26+
class TestRandomizableTransformType(unittest.TestCase):
27+
def test_is_randomizable_transform_type(self):
28+
inst = InheritsInterface()
29+
self.assertIsInstance(inst, RandomizableTrait)
30+
31+
def test_set_random_state_randomizable_transform(self):
32+
inst = InheritsImplementation()
33+
inst.set_random_state(0)

0 commit comments

Comments
 (0)