Skip to content

Commit 3873d23

Browse files
authored
Make RandTorchVisiond and RandCuCIMd Consistent (#5569)
### Description Both RandTorchVisiond and RandCuCIMd are dealing with the third-party libraries and they should have the same API, especially on using `apply_prob` instead of `prob` since it is different than the `prob` of the underlying randomized tranform. Signed-off-by: Behrooz <3968947+drbeh@users.noreply.github.com> Signed-off-by: monai-bot <monai.miccai2019@gmail.com>
1 parent 75dffd7 commit 3873d23

File tree

11 files changed

+65
-133
lines changed

11 files changed

+65
-133
lines changed

monai/data/dataset.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from monai.transforms import (
3636
Compose,
3737
Randomizable,
38+
RandomizableTrait,
3839
ThreadUnsafe,
3940
Transform,
4041
apply_transform,
@@ -277,7 +278,7 @@ def set_transform_hash(self, hash_xform_func):
277278
inherit from MONAI's `Transform` class."""
278279
hashable_transforms = []
279280
for _tr in self.transform.flatten().transforms:
280-
if isinstance(_tr, Randomizable) or not isinstance(_tr, Transform):
281+
if isinstance(_tr, RandomizableTrait) or not isinstance(_tr, Transform):
281282
break
282283
hashable_transforms.append(_tr)
283284
# Try to hash. Fall back to a hash of their names
@@ -314,7 +315,7 @@ def _pre_transform(self, item_transformed):
314315
"""
315316
for _transform in self.transform.transforms:
316317
# execute all the deterministic transforms
317-
if isinstance(_transform, Randomizable) or not isinstance(_transform, Transform):
318+
if isinstance(_transform, RandomizableTrait) or not isinstance(_transform, Transform):
318319
break
319320
# this is to be consistent with CacheDataset even though it's not in a multi-thread situation.
320321
_xform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
@@ -340,7 +341,7 @@ def _post_transform(self, item_transformed):
340341
for _transform in self.transform.transforms:
341342
if (
342343
start_post_randomize_run
343-
or isinstance(_transform, Randomizable)
344+
or isinstance(_transform, RandomizableTrait)
344345
or not isinstance(_transform, Transform)
345346
):
346347
start_post_randomize_run = True
@@ -877,7 +878,7 @@ def _load_cache_item(self, idx: int):
877878
item = self.data[idx]
878879
for _transform in self.transform.transforms: # type:ignore
879880
# execute all the deterministic transforms
880-
if isinstance(_transform, Randomizable) or not isinstance(_transform, Transform):
881+
if isinstance(_transform, RandomizableTrait) or not isinstance(_transform, Transform):
881882
break
882883
_xform = deepcopy(_transform) if isinstance(_transform, ThreadUnsafe) else _transform
883884
item = apply_transform(_xform, item)
@@ -911,7 +912,7 @@ def _transform(self, index: int):
911912
if not isinstance(self.transform, Compose):
912913
raise ValueError("transform must be an instance of monai.transforms.Compose.")
913914
for _transform in self.transform.transforms:
914-
if start_run or isinstance(_transform, Randomizable) or not isinstance(_transform, Transform):
915+
if start_run or isinstance(_transform, RandomizableTrait) or not isinstance(_transform, Transform):
915916
# only need to deep copy data on first non-deterministic transform
916917
if not start_run:
917918
start_run = True

monai/data/test_time_augmentation.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -48,9 +48,9 @@ class TestTimeAugmentation:
4848
"""
4949
Class for performing test time augmentations. This will pass the same image through the network multiple times.
5050
51-
The user passes transform(s) to be applied to each realisation, and provided that at least one of those transforms
51+
The user passes transform(s) to be applied to each realization, and provided that at least one of those transforms
5252
is random, the network's output will vary. Provided that inverse transformations exist for all supplied spatial
53-
transforms, the inverse can be applied to each realisation of the network's output. Once in the same spatial
53+
transforms, the inverse can be applied to each realization of the network's output. Once in the same spatial
5454
reference, the results can then be combined and metrics computed.
5555
5656
Test time augmentations are a useful feature for computing network uncertainty, as well as observing the network's
@@ -63,9 +63,10 @@ class TestTimeAugmentation:
6363
https://doi.org/10.1016/j.neucom.2019.01.103
6464
6565
Args:
66-
transform: transform (or composed) to be applied to each realisation. At least one transform must be of type
67-
`Randomizable`. All random transforms must be of type `InvertibleTransform`.
68-
batch_size: number of realisations to infer at once.
66+
transform: transform (or composed) to be applied to each realization. At least one transform must be of type
67+
`RandomizableTrait` (i.e. `Randomizable`, `RandomizableTransform`, or `RandomizableTrait`).
68+
. All random transforms must be of type `InvertibleTransform`.
69+
batch_size: number of realizations to infer at once.
6970
num_workers: how many subprocesses to use for data.
7071
inferrer_fn: function to use to perform inference.
7172
device: device on which to perform inference.
@@ -167,7 +168,7 @@ def __call__(
167168
"""
168169
Args:
169170
data: dictionary data to be processed.
170-
num_examples: number of realisations to be processed and results combined.
171+
num_examples: number of realizations to be processed and results combined.
171172
172173
Returns:
173174
- if `return_full_data==False`: mode, mean, std, vvc. The mode, mean and standard deviation are

monai/transforms/nvtx.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
Wrapper around NVIDIA Tools Extension for profiling MONAI transformations
1313
"""
1414

15-
from monai.transforms.transform import RandomizableTransform, Transform
15+
from monai.transforms.transform import RandomizableTrait, Transform
1616
from monai.utils import optional_import
1717

1818
_nvtx, _ = optional_import("torch._C._nvtx", descriptor="NVTX is not installed. Are you sure you have a CUDA build?")
@@ -63,9 +63,9 @@ def __call__(self, data):
6363
return data
6464

6565

66-
class RandRangePush(RangePush, RandomizableTransform):
66+
class RandRangePush(RangePush, RandomizableTrait):
6767
"""
68-
Pushes a range onto a stack of nested range span (RandomizableTransform).
68+
Pushes a range onto a stack of nested range span (for randomizable transforms).
6969
Stores zero-based depth of the range that is started.
7070
7171
Args:
@@ -84,9 +84,9 @@ def __call__(self, data):
8484
return data
8585

8686

87-
class RandRangePop(RangePop, RandomizableTransform):
87+
class RandRangePop(RangePop, RandomizableTrait):
8888
"""
89-
Pops a range off of a stack of nested range spans (RandomizableTransform).
89+
Pops a range off of a stack of nested range spans (for randomizable transforms).
9090
Stores zero-based depth of the range that is ended.
9191
"""
9292

@@ -107,10 +107,9 @@ def __call__(self, data):
107107
return data
108108

109109

110-
class RandMark(Mark, RandomizableTransform):
110+
class RandMark(Mark, RandomizableTrait):
111111
"""
112-
Mark an instantaneous event that occurred at some point.
113-
(RandomizableTransform)
112+
Mark an instantaneous event that occurred at some point (for randomizable transforms).
114113
115114
Args:
116115
msg: ASCII message to associate with the event.

monai/transforms/transform.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -192,7 +192,7 @@ class ThreadUnsafe:
192192
pass
193193

194194

195-
class Randomizable(ThreadUnsafe):
195+
class Randomizable(ThreadUnsafe, RandomizableTrait):
196196
"""
197197
An interface for handling random state locally, currently based on a class
198198
variable `R`, which is an instance of `np.random.RandomState`. This
@@ -332,7 +332,7 @@ def lazy_evaluation(self, lazy_evaluation: bool):
332332
self.lazy_evaluation = lazy_evaluation
333333

334334

335-
class RandomizableTransform(Randomizable, Transform, RandomizableTrait):
335+
class RandomizableTransform(Randomizable, Transform):
336336
"""
337337
An interface for handling random state locally, currently based on a class variable `R`,
338338
which is an instance of `np.random.RandomState`.

monai/transforms/utility/array.py

Lines changed: 5 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@
2929
from monai.data.meta_tensor import MetaTensor
3030
from monai.data.utils import no_collation
3131
from monai.transforms.inverse import InvertibleTransform
32-
from monai.transforms.transform import Randomizable, RandomizableTransform, Transform
32+
from monai.transforms.transform import Randomizable, RandomizableTrait, RandomizableTransform, Transform
3333
from monai.transforms.utils import (
3434
extreme_points_to_image,
3535
get_extreme_points,
@@ -1330,7 +1330,7 @@ def __call__(self, img: torch.Tensor):
13301330
class CuCIM(Transform):
13311331
"""
13321332
Wrap a non-randomized cuCIM transform, defined based on the transform name and args.
1333-
For randomized transforms (or randomly applying a transform) use :py:class:`monai.transforms.RandCuCIM`.
1333+
For randomized transforms use :py:class:`monai.transforms.RandCuCIM`.
13341334
13351335
Args:
13361336
name: the transform name in CuCIM package
@@ -1361,46 +1361,25 @@ def __call__(self, data):
13611361
return self.transform(data, *self.args, **self.kwargs)
13621362

13631363

1364-
class RandCuCIM(CuCIM, RandomizableTransform):
1364+
class RandCuCIM(CuCIM, RandomizableTrait):
13651365
"""
1366-
Wrap a randomized cuCIM transform, defined based on the transform name and args,
1367-
or randomly apply a non-randomized transform.
1366+
Wrap a randomized cuCIM transform, defined based on the transform name and args
13681367
For deterministic non-randomized transforms use :py:class:`monai.transforms.CuCIM`.
13691368
13701369
Args:
13711370
name: the transform name in CuCIM package.
1372-
apply_prob: the probability to apply the transform (default=1.0)
13731371
args: parameters for the CuCIM transform.
13741372
kwargs: parameters for the CuCIM transform.
13751373
13761374
Note:
13771375
- CuCIM transform only work with CuPy arrays, so this transform expects input data to be `cupy.ndarray`.
13781376
Users can call `ToCuPy` transform to convert a numpy array or torch tensor to cupy array.
1379-
- If the cuCIM transform is already randomized the `apply_prob` argument has nothing to do with
1380-
the randomness of the underlying cuCIM transform. `apply_prob` defines if the transform (either randomized
1381-
or non-randomized) being applied randomly, so it can apply non-randomized transforms randomly but be careful
1382-
with setting `apply_prob` to anything than 1.0 when using along with cuCIM's randomized transforms.
13831377
- If the random factor of the underlying cuCIM transform is not derived from `self.R`,
13841378
the results may not be deterministic. See Also: :py:class:`monai.transforms.Randomizable`.
13851379
"""
13861380

1387-
def __init__(self, name: str, apply_prob: float = 1.0, *args, **kwargs) -> None:
1381+
def __init__(self, name: str, *args, **kwargs) -> None:
13881382
CuCIM.__init__(self, name, *args, **kwargs)
1389-
RandomizableTransform.__init__(self, prob=apply_prob)
1390-
1391-
def __call__(self, data):
1392-
"""
1393-
Args:
1394-
data: a CuPy array (`cupy.ndarray`) for the cuCIM transform
1395-
1396-
Returns:
1397-
`cupy.ndarray`
1398-
1399-
"""
1400-
self.randomize(data)
1401-
if not self._do_transform:
1402-
return data
1403-
return super().__call__(data)
14041383

14051384

14061385
class AddCoordinateChannels(Transform):

monai/transforms/utility/dictionary.py

Lines changed: 26 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
from monai.data.meta_tensor import MetaObj, MetaTensor
2828
from monai.data.utils import no_collation
2929
from monai.transforms.inverse import InvertibleTransform
30-
from monai.transforms.transform import MapTransform, Randomizable, RandomizableTransform
30+
from monai.transforms.transform import MapTransform, Randomizable, RandomizableTrait, RandomizableTransform
3131
from monai.transforms.utility.array import (
3232
AddChannel,
3333
AddCoordinateChannels,
@@ -1457,51 +1457,38 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
14571457
return d
14581458

14591459

1460-
class RandTorchVisiond(RandomizableTransform, MapTransform):
1460+
class RandTorchVisiond(MapTransform, RandomizableTrait):
14611461
"""
14621462
Dictionary-based wrapper of :py:class:`monai.transforms.TorchVision` for randomized transforms.
14631463
For deterministic non-randomized transforms of TorchVision use :py:class:`monai.transforms.TorchVisiond`.
14641464
1465+
Args:
1466+
keys: keys of the corresponding items to be transformed.
1467+
See also: :py:class:`monai.transforms.compose.MapTransform`
1468+
name: The transform name in TorchVision package.
1469+
allow_missing_keys: don't raise exception if key is missing.
1470+
args: parameters for the TorchVision transform.
1471+
kwargs: parameters for the TorchVision transform.
1472+
14651473
Note:
14661474
14671475
- As most of the TorchVision transforms only work for PIL image and PyTorch Tensor, this transform expects input
1468-
data to be dict of PyTorch Tensors, users can easily call `ToTensord` transform to convert Numpy to Tensor.
1469-
- This class inherits the ``RandomizableTransform`` purely to prevent any dataset caching to skip the transform
1476+
data to be dict of PyTorch Tensors. Users should call `ToTensord` transform first to convert Numpy to Tensor.
1477+
- This class inherits the ``Randomizable`` purely to prevent any dataset caching to skip the transform
14701478
computation. If the random factor of the underlying torchvision transform is not derived from `self.R`,
1471-
the results may not be deterministic. It also provides the probability to apply this transform.
1472-
See Also: :py:class:`monai.transforms.RandomizableTransform`.
1479+
the results may not be deterministic. See Also: :py:class:`monai.transforms.Randomizable`.
14731480
14741481
"""
14751482

14761483
backend = TorchVision.backend
14771484

1478-
def __init__(
1479-
self, keys: KeysCollection, name: str, prob: float = 1.0, allow_missing_keys: bool = False, *args, **kwargs
1480-
) -> None:
1481-
"""
1482-
Args:
1483-
keys: keys of the corresponding items to be transformed.
1484-
See also: :py:class:`monai.transforms.compose.MapTransform`
1485-
name: The transform name in TorchVision package.
1486-
prob: Probability of applying this transform.
1487-
allow_missing_keys: don't raise exception if key is missing.
1488-
args: parameters for the TorchVision transform.
1489-
kwargs: parameters for the TorchVision transform.
1490-
1491-
"""
1492-
RandomizableTransform.__init__(self, prob=prob)
1485+
def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None:
14931486
MapTransform.__init__(self, keys, allow_missing_keys)
1494-
14951487
self.name = name
14961488
self.trans = TorchVision(name, *args, **kwargs)
14971489

14981490
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
14991491
d = dict(data)
1500-
1501-
self.randomize(data)
1502-
if not self._do_transform:
1503-
return d
1504-
15051492
for key in self.key_iterator(d):
15061493
d[key] = self.trans(d[key])
15071494
return d
@@ -1677,7 +1664,7 @@ def __call__(self, data):
16771664
return d
16781665

16791666

1680-
class RandCuCIMd(CuCIMd, RandomizableTransform):
1667+
class RandCuCIMd(MapTransform, RandomizableTrait):
16811668
"""
16821669
Dictionary-based wrapper of :py:class:`monai.transforms.CuCIM` for randomized transforms.
16831670
For deterministic non-randomized transforms of CuCIM use :py:class:`monai.transforms.CuCIMd`.
@@ -1686,25 +1673,22 @@ class RandCuCIMd(CuCIMd, RandomizableTransform):
16861673
keys: keys of the corresponding items to be transformed.
16871674
See also: :py:class:`monai.transforms.compose.MapTransform`
16881675
name: The transform name in CuCIM package.
1689-
apply_prob: the probability to apply the transform (default=1.0)
16901676
allow_missing_keys: don't raise exception if key is missing.
16911677
args: parameters for the CuCIM transform.
16921678
kwargs: parameters for the CuCIM transform.
16931679
16941680
Note:
16951681
- CuCIM transform only work with CuPy arrays, so this transform expects input data to be `cupy.ndarray`.
1696-
Users can call `ToCuPy` transform to convert a numpy array or torch tensor to cupy array.
1697-
- If the cuCIM transform is already randomized the `apply_prob` argument has nothing to do with
1698-
the randomness of the underlying cuCIM transform. `apply_prob` defines if the transform (either randomized
1699-
or non-randomized) being applied randomly, so it can apply non-randomized transforms randomly but be careful
1700-
with setting `apply_prob` to anything than 1.0 when using along with cuCIM's randomized transforms.
1701-
- If the random factor of the underlying cuCIM transform is not derived from `self.R`,
1682+
Users should call `ToCuPy` transform first to convert a numpy array or torch tensor to cupy array.
1683+
- This class inherits the ``Randomizable`` purely to prevent any dataset caching to skip the transform
1684+
computation. If the random factor of the underlying cuCIM transform is not derived from `self.R`,
17021685
the results may not be deterministic. See Also: :py:class:`monai.transforms.Randomizable`.
17031686
"""
17041687

1705-
def __init__(self, apply_prob: float = 1.0, *args, **kwargs) -> None:
1706-
CuCIMd.__init__(self, *args, **kwargs)
1707-
RandomizableTransform.__init__(self, prob=apply_prob)
1688+
def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None:
1689+
MapTransform.__init__(self, keys, allow_missing_keys)
1690+
self.name = name
1691+
self.trans = CuCIM(name, *args, **kwargs)
17081692

17091693
def __call__(self, data):
17101694
"""
@@ -1715,10 +1699,10 @@ def __call__(self, data):
17151699
Dict[Hashable, `cupy.ndarray`]
17161700
17171701
"""
1718-
self.randomize(data)
1719-
if not self._do_transform:
1720-
return dict(data)
1721-
return super().__call__(data)
1702+
d = dict(data)
1703+
for key in self.key_iterator(d):
1704+
d[key] = self.trans(d[key])
1705+
return d
17221706

17231707

17241708
class AddCoordinateChannelsd(MapTransform):

tests/test_nvtx_decorator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
RandAdjustContrast,
2525
RandCuCIM,
2626
RandFlip,
27-
Randomizable,
27+
RandomizableTrait,
2828
Rotate90,
2929
ToCupy,
3030
ToNumpy,
@@ -217,7 +217,7 @@ def test_tranform_randomized(self, input):
217217

218218
# Check if the first randomized is RandAdjustContrast
219219
for tran in transforms.transforms:
220-
if isinstance(tran, Randomizable):
220+
if isinstance(tran, RandomizableTrait):
221221
self.assertIsInstance(tran, RandAdjustContrast)
222222
break
223223

tests/test_nvtx_transform.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import torch
1616
from parameterized import parameterized
1717

18-
from monai.transforms import Compose, Flip, RandFlip, RandFlipD, Randomizable, ToTensor, ToTensorD
18+
from monai.transforms import Compose, Flip, RandFlip, RandFlipD, RandomizableTrait, ToTensor, ToTensorD
1919
from monai.transforms.nvtx import (
2020
Mark,
2121
MarkD,
@@ -59,7 +59,7 @@ def test_nvtx_transfroms_alone(self, input):
5959

6060
# Check if chain of randomizable/non-randomizable transforms is not broken
6161
for tran in transforms.transforms:
62-
if isinstance(tran, Randomizable):
62+
if isinstance(tran, RandomizableTrait):
6363
self.assertIsInstance(tran, RangePush)
6464
break
6565

0 commit comments

Comments
 (0)