Skip to content

Commit 68b78d2

Browse files
Support RandTorchVisiond as RandomizableTransform (#5567)
Signed-off-by: Sachidanand Alle <sachidanand.alle@gmail.com> Fixes # . Instead of simple Randomizable transform, extend RandTorchVisiond into RandomizableTransform ### Description This will help users to choose to apply the transform with some probability. For example, we want to mix training samples both with and without ColorJitter (as during validation we don't use ColorJitter) ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [ ] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. Signed-off-by: Sachidanand Alle <sachidanand.alle@gmail.com>
1 parent 9adfb4b commit 68b78d2

File tree

1 file changed

+15
-5
lines changed

1 file changed

+15
-5
lines changed

monai/transforms/utility/dictionary.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1405,7 +1405,7 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, N
14051405
return d
14061406

14071407

1408-
class RandTorchVisiond(Randomizable, MapTransform):
1408+
class RandTorchVisiond(RandomizableTransform, MapTransform):
14091409
"""
14101410
Dictionary-based wrapper of :py:class:`monai.transforms.TorchVision` for randomized transforms.
14111411
For deterministic non-randomized transforms of TorchVision use :py:class:`monai.transforms.TorchVisiond`.
@@ -1414,32 +1414,42 @@ class RandTorchVisiond(Randomizable, MapTransform):
14141414
14151415
- As most of the TorchVision transforms only work for PIL image and PyTorch Tensor, this transform expects input
14161416
data to be dict of PyTorch Tensors, users can easily call `ToTensord` transform to convert Numpy to Tensor.
1417-
- This class inherits the ``Randomizable`` purely to prevent any dataset caching to skip the transform
1417+
- This class inherits the ``RandomizableTransform`` purely to prevent any dataset caching to skip the transform
14181418
computation. If the random factor of the underlying torchvision transform is not derived from `self.R`,
1419-
the results may not be deterministic.
1420-
See Also: :py:class:`monai.transforms.Randomizable`.
1419+
the results may not be deterministic. It also provides the probability to apply this transform.
1420+
See Also: :py:class:`monai.transforms.RandomizableTransform`.
14211421
14221422
"""
14231423

14241424
backend = TorchVision.backend
14251425

1426-
def __init__(self, keys: KeysCollection, name: str, allow_missing_keys: bool = False, *args, **kwargs) -> None:
1426+
def __init__(
1427+
self, keys: KeysCollection, name: str, prob: float = 1.0, allow_missing_keys: bool = False, *args, **kwargs
1428+
) -> None:
14271429
"""
14281430
Args:
14291431
keys: keys of the corresponding items to be transformed.
14301432
See also: :py:class:`monai.transforms.compose.MapTransform`
14311433
name: The transform name in TorchVision package.
1434+
prob: Probability of applying this transform.
14321435
allow_missing_keys: don't raise exception if key is missing.
14331436
args: parameters for the TorchVision transform.
14341437
kwargs: parameters for the TorchVision transform.
14351438
14361439
"""
1440+
RandomizableTransform.__init__(self, prob=prob)
14371441
MapTransform.__init__(self, keys, allow_missing_keys)
1442+
14381443
self.name = name
14391444
self.trans = TorchVision(name, *args, **kwargs)
14401445

14411446
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> Dict[Hashable, NdarrayOrTensor]:
14421447
d = dict(data)
1448+
1449+
self.randomize(data)
1450+
if not self._do_transform:
1451+
return d
1452+
14431453
for key in self.key_iterator(d):
14441454
d[key] = self.trans(d[key])
14451455
return d

0 commit comments

Comments
 (0)