Skip to content

Commit 5e58297

Browse files
committed
Merge remote-tracking branch 'upstream/dev' into 4855-lazy-resampling-impl
Signed-off-by: Wenqi Li <wenqil@nvidia.com>
2 parents 7518371 + e279463 commit 5e58297

File tree

3 files changed

+64
-1
lines changed

3 files changed

+64
-1
lines changed

monai/transforms/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -486,6 +486,7 @@
486486
Lambda,
487487
MapLabelValue,
488488
RandCuCIM,
489+
RandIdentity,
489490
RandImageFilter,
490491
RandLambda,
491492
RemoveRepeatedChannel,

monai/transforms/utility/array.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
from collections.abc import Mapping, Sequence
2323
from copy import deepcopy
2424
from functools import partial
25-
from typing import Callable
25+
from typing import Any, Callable
2626

2727
import numpy as np
2828
import torch
@@ -75,6 +75,7 @@
7575

7676
__all__ = [
7777
"Identity",
78+
"RandIdentity",
7879
"AsChannelFirst",
7980
"AsChannelLast",
8081
"AddChannel",
@@ -128,6 +129,18 @@ def __call__(self, img: NdarrayOrTensor) -> NdarrayOrTensor:
128129
return img
129130

130131

132+
class RandIdentity(RandomizableTrait):
133+
"""
134+
Do nothing to the data. This transform is random, so can be used to stop the caching of any
135+
subsequent transforms.
136+
"""
137+
138+
backend = [TransformBackends.TORCH, TransformBackends.NUMPY]
139+
140+
def __call__(self, data: Any) -> Any:
141+
return data
142+
143+
131144
@deprecated(since="0.8", msg_suffix="please use MetaTensor data type and monai.transforms.EnsureChannelFirst instead.")
132145
class AsChannelFirst(Transform):
133146
"""

tests/test_randidentity.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import unittest
15+
16+
import monai.transforms as mt
17+
from monai.data import CacheDataset
18+
from tests.utils import TEST_NDARRAYS, NumpyImageTestCase2D, assert_allclose
19+
20+
21+
class T(mt.Transform):
22+
def __call__(self, x):
23+
return x * 2
24+
25+
26+
class TestIdentity(NumpyImageTestCase2D):
27+
def test_identity(self):
28+
for p in TEST_NDARRAYS:
29+
img = p(self.imt)
30+
identity = mt.RandIdentity()
31+
assert_allclose(img, identity(img))
32+
33+
def test_caching(self, init=1, expect=4, expect_pre_cache=2):
34+
# check that we get the correct result (two lots of T so should get 4)
35+
x = init
36+
transforms = mt.Compose([T(), mt.RandIdentity(), T()])
37+
self.assertEqual(transforms(x), expect)
38+
39+
# check we get correct result with CacheDataset
40+
x = [init]
41+
ds = CacheDataset(x, transforms)
42+
self.assertEqual(ds[0], expect)
43+
44+
# check that the cached value is correct
45+
self.assertEqual(ds._cache[0], expect_pre_cache)
46+
47+
48+
if __name__ == "__main__":
49+
unittest.main()

0 commit comments

Comments
 (0)