Merged
Changes from 13 commits
33 commits
1759f96
Initial commit to resolve #6223
atbenmurray Mar 22, 2023
ba4115d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Mar 22, 2023
632119e
[MONAI] code formatting
monai-bot Mar 22, 2023
50a9d06
Initial commit to resolve #6223
atbenmurray Mar 22, 2023
dd5ad30
Merging into autoformat fixes
atbenmurray Mar 22, 2023
10edabd
DCO Remediation Commit for Ben Murray <ben.murray@gmail.com>
atbenmurray Mar 22, 2023
5a8ccaf
DCO Remediation Commit for Ben Murray <ben.murray@gmail.com>
atbenmurray Mar 22, 2023
918cbf6
Merge branch 'compose_refactor' of github.com:project-monai/monai int…
atbenmurray Mar 22, 2023
c3bd8c0
DCO Remediation Commit for Ben Murray <ben.murray@gmail.com>
atbenmurray Mar 22, 2023
0c0baa4
Fixes to make test_cachedataset, test_persistentdataset and test_cach…
atbenmurray Mar 22, 2023
c5a73f6
Documentation for Compose.execute
atbenmurray Mar 22, 2023
41a156a
style/docs
wyli Mar 22, 2023
eda9b4a
Merge remote-tracking branch 'upstream/dev' into compose_refactor
wyli Mar 22, 2023
61f67e4
Added tests; updated documentation
atbenmurray Mar 24, 2023
bd591b7
Merge branch 'compose_refactor' of github.com:project-monai/monai int…
atbenmurray Mar 24, 2023
7f37392
DCO Remediation Commit for Ben Murray <ben.murray@gmail.com>
atbenmurray Mar 24, 2023
5301af1
Honoring the self.copy_cache flag
atbenmurray Mar 24, 2023
5bab340
Merge branch 'dev' into compose_refactor
atbenmurray Mar 24, 2023
6bdedac
Updating for lazy resampling
atbenmurray Mar 24, 2023
d99b2b9
Autoformatting
atbenmurray Mar 24, 2023
0223e62
Merge branch 'dev' into compose_refactor
wyli Mar 27, 2023
d939395
Moving Compose.execute to execute_compose as per @ericspod's request.…
atbenmurray Mar 29, 2023
4ecf2c3
Test fix: missed Compose.execute to execute_compose changes
atbenmurray Mar 29, 2023
948444f
DCO Remediation Commit for Ben Murray <ben.murray@gmail.com>
atbenmurray Mar 29, 2023
9097e07
Bug fix for SomeOff; generate list of transforms in execution order
atbenmurray Mar 29, 2023
d701642
Documentation for Compose.get_index_of_first
atbenmurray Mar 29, 2023
77e465d
Slight documentation reformatting for Compose.get_index_of_first
atbenmurray Mar 29, 2023
5605e71
Updated docstrings for execute_compose. Renamed input_ to data for
atbenmurray Mar 29, 2023
c06e014
Fixing errors reported by flake8-py3 (mypy) output
atbenmurray Mar 29, 2023
191506d
Had to go back to lazy_evaluation default of None for now but this is a
atbenmurray Mar 29, 2023
ec71402
execute_compose type ignore as it can't be fixed without polluting more
atbenmurray Mar 29, 2023
85b18de
type: ignore suppression as this is being addressed separately
atbenmurray Mar 29, 2023
7684edf
Merge branch 'dev' into compose_refactor
wyli Mar 30, 2023
74 changes: 33 additions & 41 deletions monai/data/dataset.py
@@ -39,7 +39,6 @@
Compose,
Randomizable,
RandomizableTrait,
ThreadUnsafe,
Transform,
apply_transform,
convert_to_contiguous,
@@ -316,13 +315,15 @@ def _pre_transform(self, item_transformed):
random transform object

"""
for _transform in self.transform.transforms:
# execute all the deterministic transforms
if isinstance(_transform, RandomizableTrait) or not isinstance(_transform, Transform):
break
# this is to be consistent with CacheDataset even though it's not in a multi-thread situation.
_xform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
item_transformed = apply_transform(_xform, item_transformed)
if not isinstance(self.transform, Compose):
raise ValueError("transform must be an instance of monai.transforms.Compose.")

first_random = self.transform.get_index_of_first(
lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
)

item_transformed = self.transform(item_transformed, end=first_random, threading=True)

if self.reset_ops_id:
reset_ops_id(item_transformed)
return item_transformed
@@ -340,15 +341,12 @@ def _post_transform(self, item_transformed):
"""
if not isinstance(self.transform, Compose):
raise ValueError("transform must be an instance of monai.transforms.Compose.")
start_post_randomize_run = False
for _transform in self.transform.transforms:
if (
start_post_randomize_run
or isinstance(_transform, RandomizableTrait)
or not isinstance(_transform, Transform)
):
start_post_randomize_run = True
item_transformed = apply_transform(_transform, item_transformed)

first_random = self.transform.get_index_of_first(
lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
)
if first_random is not None:
item_transformed = self.transform(item_transformed, start=first_random)
return item_transformed

def _cachecheck(self, item_transformed):
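Note: the refactor above replaces the manual loops over ``self.transform.transforms`` with ``get_index_of_first`` plus a ranged ``Compose`` call. A minimal sketch of the split, assuming the ranged ``__call__`` introduced in this PR and an illustrative two-stage dictionary pipeline:

import torch
from monai.transforms import Compose, RandFlipd, RandomizableTrait, ScaleIntensityd, Transform

pipeline = Compose([ScaleIntensityd(keys="img"), RandFlipd(keys="img", prob=0.5)])
data = {"img": torch.rand(1, 8, 8, 8)}

# position of the first random (or non-Transform) entry; None if the pipeline is fully deterministic
first_random = pipeline.get_index_of_first(
    lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
)

cached = pipeline(data, end=first_random, threading=True)  # deterministic prefix, safe to cache
if first_random is not None:
    result = pipeline(cached, start=first_random)  # random tail, re-run every epoch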
@@ -492,11 +490,9 @@ def _pre_transform(self, item_transformed):
"""
if not isinstance(self.transform, Compose):
raise ValueError("transform must be an instance of monai.transforms.Compose.")
for i, _transform in enumerate(self.transform.transforms):
if i == self.cache_n_trans:
break
_xform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
item_transformed = apply_transform(_xform, item_transformed)

item_transformed = self.transform(item_transformed, end=self.cache_n_trans, threading=True)

reset_ops_id(item_transformed)
return item_transformed

@@ -512,10 +508,8 @@ def _post_transform(self, item_transformed):
"""
if not isinstance(self.transform, Compose):
raise ValueError("transform must be an instance of monai.transforms.Compose.")
for i, _transform in enumerate(self.transform.transforms):
if i >= self.cache_n_trans:
item_transformed = apply_transform(_transform, item_transformed)
return item_transformed

return self.transform(item_transformed, start=self.cache_n_trans)


class LMDBDataset(PersistentDataset):
@@ -879,12 +873,12 @@ def _load_cache_item(self, idx: int):
idx: the index of the input data sequence.
"""
item = self.data[idx]
for _transform in self.transform.transforms:
# execute all the deterministic transforms
if isinstance(_transform, RandomizableTrait) or not isinstance(_transform, Transform):
break
_xform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
item = apply_transform(_xform, item)

first_random = self.transform.get_index_of_first(
lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
)
item = self.transform(item, end=first_random, threading=True)

if self.as_contiguous:
item = convert_to_contiguous(item, memory_format=torch.contiguous_format)
return item
@@ -911,17 +905,15 @@ def _transform(self, index: int):
data = self._cache[cache_index] = self._load_cache_item(cache_index)

# load data from cache and execute from the first random transform
start_run = False
if not isinstance(self.transform, Compose):
raise ValueError("transform must be an instance of monai.transforms.Compose.")
for _transform in self.transform.transforms:
if start_run or isinstance(_transform, RandomizableTrait) or not isinstance(_transform, Transform):
# only need to deep copy data on first non-deterministic transform
if not start_run:
start_run = True
if self.copy_cache:
data = deepcopy(data)
data = apply_transform(_transform, data)

first_random = self.transform.get_index_of_first(
lambda t: isinstance(t, RandomizableTrait) or not isinstance(t, Transform)
)
if first_random is not None:
data = self.transform(deepcopy(data) if self.copy_cache else data, start=first_random)

return data
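The deep copy before running the random tail keeps in-place transforms from corrupting the in-memory cache; a minimal illustration of the intent, reusing ``pipeline`` and ``first_random`` from the sketch above and assuming ``copy_cache=True``:

from copy import deepcopy

cache_item = cached  # an entry held in self._cache and reused across epochs
epoch_input = deepcopy(cache_item)  # work on a copy so the cached entry stays pristine
epoch_output = pipeline(epoch_input, start=first_random)  # only the random tail runs per epoch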


102 changes: 94 additions & 8 deletions monai/transforms/compose.py
@@ -16,12 +16,15 @@

import warnings
from collections.abc import Callable, Mapping, Sequence
from copy import deepcopy
from typing import Any

import numpy as np

import monai
from monai.config import NdarrayOrTensor
from monai.transforms.inverse import InvertibleTransform
from monai.transforms.traits import ThreadUnsafe

# For backwards compatibility (so this still works: from monai.transforms.compose import MapTransform)
from monai.transforms.transform import ( # noqa: F401
@@ -152,6 +155,12 @@ def randomize(self, data: Any | None = None) -> None:
f'Transform "{tfm_name}" in Compose not randomized\n{tfm_name}.{type_error}.', RuntimeWarning
)

def get_index_of_first(self, predicate):
"""
Return the index of the first transform in ``self.transforms`` for which ``predicate``
returns True, or ``None`` if no transform satisfies the predicate.
"""
for i in range(len(self.transforms)):
if predicate(self.transforms[i]):
return i
return None
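A brief usage sketch (the pipeline and predicate are illustrative; ``RandFlipd`` satisfies ``RandomizableTrait``):

from monai.transforms import Compose, RandFlipd, RandomizableTrait, ScaleIntensityd

pipeline = Compose([ScaleIntensityd(keys="img"), RandFlipd(keys="img", prob=0.5)])
pipeline.get_index_of_first(lambda t: isinstance(t, RandomizableTrait))  # -> 1
Compose([ScaleIntensityd(keys="img")]).get_index_of_first(lambda t: isinstance(t, RandomizableTrait))  # -> None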

def flatten(self):
"""Return a Composition with a simple list of transforms, as opposed to any nested Compositions.

@@ -172,11 +181,69 @@ def __len__(self):
"""Return number of transformations."""
return len(self.flatten().transforms)

def __call__(self, input_):
for _transform in self.transforms:
input_ = apply_transform(_transform, input_, self.map_items, self.unpack_items, self.log_stats)
@classmethod
def execute(
cls,
input_: NdarrayOrTensor,
transforms: Sequence[Any],
map_items: bool = True,
unpack_items: bool = False,
log_stats: bool = False,
start: int = 0,
end: int | None = None,
threading: bool = False,
) -> NdarrayOrTensor:
"""
``execute`` provides the implementation that Compose uses to execute a sequence
of transforms. As well as being used by Compose, it can be used by subclasses of
Compose and by code that doesn't have a Compose instance but needs to execute a
sequence of transforms is if it were executed by Compose. For the most part, it
is recommended to use Compose instances, however.
Args:
`input_`: a tensor-like object to be transformed
transforms: a sequence of transforms to be carried out
map_items: whether to apply the transform to each item in ``data```.
Defaults to True if not set.
unpack_items: whether to unpack parameters using '*'. Defaults to False if not set
log_stats: whether to log detailed information about the application of ``transforms``
to ``input_``. For NumPy ndarrays and PyTorch tensors, log only the data shape and
value range. Defaults to False if not set.
start:
end:
threading:

Returns:

"""
end_ = len(transforms) if end is None else end
if start is None:
raise ValueError(f"'start' ({start}) cannot be None")
if start > end_:
raise ValueError(f"'start' ({start}) must be less than 'end' ({end_})")
if end_ > len(transforms):
raise ValueError(f"'end' ({end_}) must be less than or equal to the transform count ({len(transforms)}")

# no-op if the range is empty
if start == end:
return input_

for _transform in transforms[start:end]:
if threading:
_transform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
input_ = apply_transform(_transform, input_, map_items, unpack_items, log_stats) # type: ignore
return input_

def __call__(self, input_, start=0, end=None, threading=False):
return Compose.execute(
input_,
self.transforms,
map_items=self.map_items,
unpack_items=self.unpack_items,
log_stats=self.log_stats,
start=start,
end=end,
threading=threading,
)
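For code that holds a bare sequence of transforms rather than a ``Compose`` instance, ``execute`` can be called directly; a minimal sketch (the array transforms are chosen for illustration only):

import torch
from monai.transforms import Compose, RandGaussianNoise, ScaleIntensity

transforms = [ScaleIntensity(), RandGaussianNoise(prob=1.0)]
img = torch.rand(1, 8, 8)

# the instance call and the classmethod give the same behaviour
out1 = Compose(transforms)(img)
out2 = Compose.execute(img, transforms, map_items=True, unpack_items=False)

# partial execution: run the deterministic head now, resume from index 1 later
head = Compose.execute(img, transforms, end=1)
tail = Compose.execute(head, transforms, start=1)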

def inverse(self, data):
invertible_transforms = [t for t in self.flatten().transforms if isinstance(t, InvertibleTransform)]
if not invertible_transforms:
@@ -254,12 +321,23 @@ def flatten(self):
weights.append(w)
return OneOf(transforms, weights, self.map_items, self.unpack_items)

def __call__(self, data):
def __call__(self, data, start=0, end=None, threading=False):
if len(self.transforms) == 0:
return data

index = self.R.multinomial(1, self.weights).argmax()
_transform = self.transforms[index]
data = apply_transform(_transform, data, self.map_items, self.unpack_items, self.log_stats)

data = Compose.execute(
data,
[_transform],
map_items=self.map_items,
unpack_items=self.unpack_items,
log_stats=self.log_stats,
start=start,
end=end,
threading=threading,
)

# if the data is a mapping (dictionary), append the OneOf transform to the end
if isinstance(data, monai.data.MetaTensor):
self.push_transform(data, extra_info={"index": index})
@@ -318,14 +396,22 @@ def __init__(
) -> None:
super().__init__(transforms, map_items, unpack_items, log_stats)

def __call__(self, input_):
def __call__(self, input_, start=0, end=None, threading=False):
if len(self.transforms) == 0:
return input_
num = len(self.transforms)
applied_order = self.R.permutation(range(num))

for index in applied_order:
input_ = apply_transform(self.transforms[index], input_, self.map_items, self.unpack_items, self.log_stats)
input_ = Compose.execute(
input_,
[self.transforms[ind] for ind in applied_order],
map_items=self.map_items,
unpack_items=self.unpack_items,
log_stats=self.log_stats,
start=start,
end=end,
threading=threading,
)

# if the data is a mapping (dictionary), append the RandomOrder transform to the end
if isinstance(input_, monai.data.MetaTensor):
self.push_transform(input_, extra_info={"applied_order": applied_order})