Skip to content
Open
Changes from 10 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
e0cda55
added list extend to MultiSampleTrait
lukas-folle-snkeos Aug 8, 2025
1ad24af
DCO Remediation Commit for Lukas Folle <lukas.folle@snke.com>
lukas-folle-snkeos Aug 8, 2025
35658f2
Merge branch 'dev' of github.com:lukas-folle-snkeos/MONAI into dev
lukas-folle-snkeos Aug 8, 2025
eeb7e12
fixed type errors
lukas-folle-snkeos Aug 8, 2025
c011103
DCO Remediation Commit for Lukas Folle <lukas.folle@snke.com>
lukas-folle-snkeos Aug 8, 2025
6bb6110
Merge branch 'dev' of github.com:lukas-folle-snkeos/MONAI into dev
lukas-folle-snkeos Aug 8, 2025
e7a9185
DCO Remediation Commit for Lukas Folle <lukas.folle@snke.com>
lukas-folle-snkeos Aug 8, 2025
a5d2261
Merge branch 'dev' of github.com:lukas-folle-snkeos/MONAI into dev
lukas-folle-snkeos Aug 8, 2025
b0dd089
DCO Remediation Commit for Lukas Folle <lukas.folle@snke.com>
lukas-folle-snkeos Aug 8, 2025
7560a37
Merge branch 'dev' of github.com:lukas-folle-snkeos/MONAI into dev
lukas-folle-snkeos Aug 8, 2025
77c138d
DCO Remediation Commit for Lukas Folle <lukas.folle@snke.com>
lukas-folle-snkeos Aug 8, 2025
7df8cb9
avoided breaking map_item functionality
lukas-folle-snkeos Aug 8, 2025
be46018
fixed wrong type annotation
lukas-folle-snkeos Aug 8, 2025
3aa1288
Merge branch 'dev' into dev
ericspod Aug 11, 2025
2d58774
added test for many multisample transforms; refactored code
lukas-folle-snkeos Sep 16, 2025
ee74761
DCO Remediation Commit for Lukas Folle <lukas.folle@snke.com>
lukas-folle-snkeos Sep 16, 2025
2c18f36
Merge branch 'dev' of github.com:lukas-folle-snkeos/MONAI into dev
lukas-folle-snkeos Sep 16, 2025
7ae8a26
DCO Remediation Commit for Lukas Folle <lukas.folle@snke.com>
lukas-folle-snkeos Sep 16, 2025
1d04028
DCO Remediation Commit for Lukas Folle <lukas.folle@snke.com>
lukas-folle-snkeos Sep 16, 2025
9377b63
added slight cleanup and additional test
lukas-folle-snkeos Sep 16, 2025
a56c0c3
Merge branch 'dev' into dev
lukas-folle-snkeos Sep 16, 2025
a8c9e24
Merge branch 'dev' into dev
lukas-folle-snkeos Oct 10, 2025
5135fb4
changed compose to explicit flattening
lukas-folle-snkeos Oct 10, 2025
fee6cd3
added documentation
lukas-folle-snkeos Oct 10, 2025
416584d
fixed doc build; fixed isort
lukas-folle-snkeos Oct 10, 2025
a8f3fe9
added type hints and fixed potential bug
lukas-folle-snkeos Oct 10, 2025
c707a2c
formatted
lukas-folle-snkeos Oct 10, 2025
2644f47
ignored mypy error
lukas-folle-snkeos Oct 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 77 additions & 25 deletions monai/transforms/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,12 +90,22 @@ def _apply_transform(
"""
from monai.transforms.lazy.functional import apply_pending_transforms_in_order

data = apply_pending_transforms_in_order(transform, data, lazy, overrides, logger_name)
data = apply_pending_transforms_in_order(
transform, data, lazy, overrides, logger_name
)

if isinstance(data, tuple) and unpack_parameters:
return transform(*data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(*data)
return (
transform(*data, lazy=lazy)
if isinstance(transform, LazyTrait)
else transform(*data)
)

return transform(data, lazy=lazy) if isinstance(transform, LazyTrait) else transform(data)
return (
transform(data, lazy=lazy)
if isinstance(transform, LazyTrait)
else transform(data)
)


def apply_transform(
Expand Down Expand Up @@ -143,31 +153,49 @@ def apply_transform(
try:
map_items_ = int(map_items) if isinstance(map_items, bool) else map_items
if isinstance(data, (list, tuple)) and map_items_ > 0:
return [
apply_transform(transform, item, map_items_ - 1, unpack_items, log_stats, lazy, overrides)
for item in data
]
return _apply_transform(transform, data, unpack_items, lazy, overrides, log_stats)
res: list[ReturnType] = []
for item in data:
res_item = _apply_transform(
transform, item, unpack_items, lazy, overrides, log_stats
)
if isinstance(res_item, (list, tuple)):
res.extend(res_item)
else:
res.append(res_item)
return res
return _apply_transform(
transform, data, unpack_items, lazy, overrides, log_stats
)
except Exception as e:
# if in debug mode, don't swallow exception so that the breakpoint
# appears where the exception was raised.
if MONAIEnvVars.debug():
raise
if log_stats is not False and not isinstance(transform, transforms.compose.Compose):
if log_stats is not False and not isinstance(
transform, transforms.compose.Compose
):
# log the input data information of exact transform in the transform chain
if isinstance(log_stats, str):
datastats = transforms.utility.array.DataStats(data_shape=False, value_range=False, name=log_stats)
datastats = transforms.utility.array.DataStats(
data_shape=False, value_range=False, name=log_stats
)
else:
datastats = transforms.utility.array.DataStats(data_shape=False, value_range=False)
datastats = transforms.utility.array.DataStats(
data_shape=False, value_range=False
)
logger = logging.getLogger(datastats._logger_name)
logger.error(f"\n=== Transform input info -- {type(transform).__name__} ===")
logger.error(
f"\n=== Transform input info -- {type(transform).__name__} ==="
)
if isinstance(data, (list, tuple)):
data = data[0]

def _log_stats(data, prefix: str | None = "Data"):
if isinstance(data, (np.ndarray, torch.Tensor)):
# log data type, shape, range for array
datastats(img=data, data_shape=True, value_range=True, prefix=prefix)
datastats(
img=data, data_shape=True, value_range=True, prefix=prefix
)
else:
# log data type and value for other metadata
datastats(img=data, data_value=True, prefix=prefix)
Expand All @@ -194,7 +222,9 @@ class Randomizable(ThreadUnsafe, RandomizableTrait):

R: np.random.RandomState = np.random.RandomState()

def set_random_state(self, seed: int | None = None, state: np.random.RandomState | None = None) -> Randomizable:
def set_random_state(
self, seed: int | None = None, state: np.random.RandomState | None = None
) -> Randomizable:
"""
Set the random state locally, to control the randomness, the derived
classes should use :py:attr:`self.R` instead of `np.random` to introduce random
Expand All @@ -212,14 +242,20 @@ def set_random_state(self, seed: int | None = None, state: np.random.RandomState

"""
if seed is not None:
_seed = np.int64(id(seed) if not isinstance(seed, (int, np.integer)) else seed)
_seed = _seed % MAX_SEED # need to account for Numpy2.0 which doesn't silently convert to int64
_seed = np.int64(
id(seed) if not isinstance(seed, (int, np.integer)) else seed
)
_seed = (
_seed % MAX_SEED
) # need to account for Numpy2.0 which doesn't silently convert to int64
self.R = np.random.RandomState(_seed)
return self

if state is not None:
if not isinstance(state, np.random.RandomState):
raise TypeError(f"state must be None or a np.random.RandomState but is {type(state).__name__}.")
raise TypeError(
f"state must be None or a np.random.RandomState but is {type(state).__name__}."
)
self.R = state
return self

Expand All @@ -238,7 +274,9 @@ def randomize(self, data: Any) -> None:
Raises:
NotImplementedError: When the subclass does not override this method.
"""
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
raise NotImplementedError(
f"Subclass {self.__class__.__name__} must implement this method."
)


class Transform(ABC):
Expand Down Expand Up @@ -294,7 +332,9 @@ def __call__(self, data: Any):
NotImplementedError: When the subclass does not override this method.

"""
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
raise NotImplementedError(
f"Subclass {self.__class__.__name__} must implement this method."
)


class LazyTransform(Transform, LazyTrait):
Expand Down Expand Up @@ -397,11 +437,15 @@ def __call__(self, data):
def __new__(cls, *args, **kwargs):
if config.USE_META_DICT:
# call_update after MapTransform.__call__
cls.__call__ = transforms.attach_hook(cls.__call__, MapTransform.call_update, "post") # type: ignore
cls.__call__ = transforms.attach_hook(
cls.__call__, MapTransform.call_update, "post"
) # type: ignore

if hasattr(cls, "inverse"):
# inverse_update before InvertibleTransform.inverse
cls.inverse: Any = transforms.attach_hook(cls.inverse, transforms.InvertibleTransform.inverse_update)
cls.inverse: Any = transforms.attach_hook(
cls.inverse, transforms.InvertibleTransform.inverse_update
)
return Transform.__new__(cls)

def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> None:
Expand All @@ -412,7 +456,9 @@ def __init__(self, keys: KeysCollection, allow_missing_keys: bool = False) -> No
raise ValueError("keys must be non empty.")
for key in self.keys:
if not isinstance(key, Hashable):
raise TypeError(f"keys must be one of (Hashable, Iterable[Hashable]) but is {type(keys).__name__}.")
raise TypeError(
f"keys must be one of (Hashable, Iterable[Hashable]) but is {type(keys).__name__}."
)

def call_update(self, data):
"""
Expand All @@ -432,7 +478,9 @@ def call_update(self, data):
for k in dict_i:
if not isinstance(dict_i[k], MetaTensor):
continue
list_d[idx] = transforms.sync_meta_info(k, dict_i, t=not isinstance(self, transforms.InvertD))
list_d[idx] = transforms.sync_meta_info(
k, dict_i, t=not isinstance(self, transforms.InvertD)
)
return list_d[0] if is_dict else list_d

@abstractmethod
Expand Down Expand Up @@ -460,9 +508,13 @@ def __call__(self, data):
An updated dictionary version of ``data`` by applying the transform.

"""
raise NotImplementedError(f"Subclass {self.__class__.__name__} must implement this method.")
raise NotImplementedError(
f"Subclass {self.__class__.__name__} must implement this method."
)

def key_iterator(self, data: Mapping[Hashable, Any], *extra_iterables: Iterable | None) -> Generator:
def key_iterator(
self, data: Mapping[Hashable, Any], *extra_iterables: Iterable | None
) -> Generator:
"""
Iterate across keys and optionally extra iterables. If key is missing, exception is raised if
`allow_missing_keys==False` (default). If `allow_missing_keys==True`, key is skipped.
Expand Down
Loading