Skip to content

Type classes for lazy resampling #5418

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 12 commits into from
Oct 28, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions docs/source/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,31 @@ Generic Interfaces
:members:
:special-members: __call__

`RandomizableTrait`
^^^^^^^^^^^^^^^^^^^
.. autoclass:: RandomizableTrait
:members:

`LazyTrait`
^^^^^^^^^^^
.. autoclass:: LazyTrait
:members:

`MultiSampleTrait`
^^^^^^^^^^^^^^^^^^
.. autoclass:: MultiSampleTrait
:members:

`Randomizable`
^^^^^^^^^^^^^^
.. autoclass:: Randomizable
:members:

`LazyTransform`
^^^^^^^^^^^^^^^
.. autoclass:: LazyTransform
:members:

`RandomizableTransform`
^^^^^^^^^^^^^^^^^^^^^^^
.. autoclass:: RandomizableTransform
Expand Down
13 changes: 12 additions & 1 deletion monai/transforms/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,18 @@
ZoomD,
ZoomDict,
)
from .transform import MapTransform, Randomizable, RandomizableTransform, ThreadUnsafe, Transform, apply_transform
from .transform import (
LazyTrait,
LazyTransform,
MapTransform,
MultiSampleTrait,
Randomizable,
RandomizableTrait,
RandomizableTransform,
ThreadUnsafe,
Transform,
apply_transform,
)
from .utility.array import (
AddChannel,
AddCoordinateChannels,
Expand Down
85 changes: 83 additions & 2 deletions monai/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,18 @@
from monai.utils.enums import TransformBackends
from monai.utils.misc import MONAIEnvVars

__all__ = ["ThreadUnsafe", "apply_transform", "Randomizable", "RandomizableTransform", "Transform", "MapTransform"]
__all__ = [
"ThreadUnsafe",
"apply_transform",
"LazyTrait",
"RandomizableTrait",
"MultiSampleTrait",
"Randomizable",
"LazyTransform",
"RandomizableTransform",
"Transform",
"MapTransform",
]

ReturnType = TypeVar("ReturnType")

Expand Down Expand Up @@ -118,6 +129,56 @@ def _log_stats(data, prefix: Optional[str] = "Data"):
raise RuntimeError(f"applying transform {transform}") from e


class LazyTrait:
"""
An interface to indicate that the transform has the capability to execute using
MONAI's lazy resampling feature. In order to do this, the implementing class needs
to be able to describe its operation as an affine matrix or grid with accompanying metadata.
This interface can be extended from by people adapting transforms to the MONAI framework as
well as by implementors of MONAI transforms.
"""

@property
def lazy_evaluation(self):
"""
Get whether lazy_evaluation is enabled for this transform instance.
Returns:
True if the transform is operating in a lazy fashion, False if not.
"""
raise NotImplementedError()

@lazy_evaluation.setter
def lazy_evaluation(self, enabled: bool):
"""
Set whether lazy_evaluation is enabled for this transform instance.
Args:
enabled: True if the transform should operate in a lazy fashion, False if not.
"""
raise NotImplementedError()


class RandomizableTrait:
"""
An interface to indicate that the transform has the capability to perform
randomized transforms to the data that it is called upon. This interface
can be extended from by people adapting transforms to the MONAI framework as well as by
implementors of MONAI transforms.
"""

pass


class MultiSampleTrait:
"""
An interface to indicate that the transform has the capability to return multiple samples
given an input, such as when performing random crops of a sample. This interface can be
extended from by people adapting transforms to the MONAI framework as well as by implementors
of MONAI transforms.
"""

pass


class ThreadUnsafe:
"""
A class to denote that the transform will mutate its member variables,
Expand Down Expand Up @@ -251,7 +312,27 @@ def __call__(self, data: Any):
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")


class RandomizableTransform(Randomizable, Transform):
class LazyTransform(Transform, LazyTrait):
"""
An implementation of functionality for lazy transforms that can be subclassed by array and
dictionary transforms to simplify implementation of new lazy transforms.
"""

def __init__(self, lazy_evaluation: Optional[bool] = True):
self.lazy_evaluation = lazy_evaluation

@property
def lazy_evaluation(self):
return self.lazy_evaluation

@lazy_evaluation.setter
def lazy_evaluation(self, lazy_evaluation: bool):
if not isinstance(lazy_evaluation, bool):
raise TypeError("'lazy_evaluation must be a bool but is of " f"type {type(lazy_evaluation)}'")
self.lazy_evaluation = lazy_evaluation


class RandomizableTransform(Randomizable, Transform, RandomizableTrait):
"""
An interface for handling random state locally, currently based on a class variable `R`,
which is an instance of `np.random.RandomState`.
Expand Down
2 changes: 1 addition & 1 deletion monai/visualize/gradient_based.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ def get_grad(self, x: torch.Tensor, index: torch.Tensor | int | None, retain_gra
x.requires_grad = True

self._model(x, class_idx=index, retain_graph=retain_graph, **kwargs)
grad: torch.Tensor = x.grad.detach()
grad: torch.Tensor = x.grad.detach() # type: ignore
return grad

def __call__(self, x: torch.Tensor, index: torch.Tensor | int | None = None, **kwargs) -> torch.Tensor:
Expand Down
33 changes: 33 additions & 0 deletions tests/test_randomizable_transform_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from monai.transforms.transform import RandomizableTrait, RandomizableTransform


class InheritsInterface(RandomizableTrait):
pass


class InheritsImplementation(RandomizableTransform):
def __call__(self, data):
return data


class TestRandomizableTransformType(unittest.TestCase):
def test_is_randomizable_transform_type(self):
inst = InheritsInterface()
self.assertIsInstance(inst, RandomizableTrait)

def test_set_random_state_randomizable_transform(self):
inst = InheritsImplementation()
inst.set_random_state(0)