Skip to content

Commit d6bafc9

Browse files
authored
Added RandSimulateLowResolution(d) array and dictionary transforms and corresponding unit tests (#6806)
Fixes #3781. ### Description Random simulation of low resolution corresponding to nnU-Net's (https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/transforms/resample_transforms.py#L23). First, the array/tensor is resampled at lower resolution as determined by the zoom_factor which is uniformly sampled from the `zoom_range`. Then, the array/tensor is resampled at the original resolution. MONAI's `Resize` transform is used for the resampling operations. ### Types of changes - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [x] In-line docstrings updated. - [x] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Aaron Kujawa <askujawa@gmail.com>
1 parent 4c22a27 commit d6bafc9

File tree

6 files changed

+360
-1
lines changed

6 files changed

+360
-1
lines changed

docs/source/transforms.rst

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -925,6 +925,12 @@ Spatial
925925
:members:
926926
:special-members: __call__
927927

928+
`RandSimulateLowResolution`
929+
"""""""""""""""""""""""""""
930+
.. autoclass:: RandSimulateLowResolution
931+
:members:
932+
:special-members: __call__
933+
928934

929935
Smooth Field
930936
^^^^^^^^^^^^
@@ -1886,6 +1892,13 @@ Spatial (Dict)
18861892
:members:
18871893
:special-members: __call__
18881894

1895+
`RandSimulateLowResolutiond`
1896+
""""""""""""""""""""""""""""
1897+
.. autoclass:: RandSimulateLowResolutiond
1898+
:members:
1899+
:special-members: __call__
1900+
1901+
18891902
Smooth Field (Dict)
18901903
^^^^^^^^^^^^^^^^^^^
18911904

monai/transforms/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,7 @@
381381
RandGridPatch,
382382
RandRotate,
383383
RandRotate90,
384+
RandSimulateLowResolution,
384385
RandZoom,
385386
Resample,
386387
ResampleToMatch,
@@ -437,6 +438,9 @@
437438
RandRotated,
438439
RandRotateD,
439440
RandRotateDict,
441+
RandSimulateLowResolutiond,
442+
RandSimulateLowResolutionD,
443+
RandSimulateLowResolutionDict,
440444
RandZoomd,
441445
RandZoomD,
442446
RandZoomDict,

monai/transforms/spatial/array.py

Lines changed: 94 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
from monai.config import USE_COMPILED, DtypeLike
2727
from monai.config.type_definitions import NdarrayOrTensor
28-
from monai.data.meta_obj import get_track_meta
28+
from monai.data.meta_obj import get_track_meta, set_track_meta
2929
from monai.data.meta_tensor import MetaTensor
3030
from monai.data.utils import AFFINE_TOL, affine_to_spacing, compute_shape_offset, iter_patch, to_affine_nd, zoom_affine
3131
from monai.networks.layers import AffineTransform, GaussianFilter, grid_pull
@@ -111,6 +111,7 @@
111111
"RandAffine",
112112
"Rand2DElastic",
113113
"Rand3DElastic",
114+
"RandSimulateLowResolution",
114115
]
115116

116117
RandRange = Optional[Union[Sequence[Union[Tuple[float, float], float]], float]]
@@ -3456,3 +3457,95 @@ def __call__(self, array: NdarrayOrTensor, randomize: bool = True):
34563457
if randomize:
34573458
self.randomize(array)
34583459
return super().__call__(array)
3460+
3461+
3462+
class RandSimulateLowResolution(RandomizableTransform):
3463+
"""
3464+
Random simulation of low resolution corresponding to nnU-Net's SimulateLowResolutionTransform
3465+
(https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/transforms/resample_transforms.py#L23)
3466+
First, the array/tensor is resampled at lower resolution as determined by the zoom_factor which is uniformly sampled
3467+
from the `zoom_range`. Then, the array/tensor is resampled at the original resolution.
3468+
"""
3469+
3470+
backend = Affine.backend
3471+
3472+
def __init__(
3473+
self,
3474+
prob: float = 0.1,
3475+
downsample_mode: InterpolateMode | str = InterpolateMode.NEAREST,
3476+
upsample_mode: InterpolateMode | str = InterpolateMode.TRILINEAR,
3477+
zoom_range: Sequence[float] = (0.5, 1.0),
3478+
align_corners=False,
3479+
device: torch.device | None = None,
3480+
) -> None:
3481+
"""
3482+
Args:
3483+
prob: probability of performing this augmentation
3484+
downsample_mode: interpolation mode for downsampling operation
3485+
upsample_mode: interpolation mode for upsampling operation
3486+
zoom_range: range from which the random zoom factor for the downsampling and upsampling operation is
3487+
sampled. It determines the shape of the downsampled tensor.
3488+
align_corners: This only has an effect when downsample_mode or upsample_mode is 'linear', 'bilinear',
3489+
'bicubic' or 'trilinear'. Default: False
3490+
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
3491+
device: device on which the tensor will be allocated.
3492+
3493+
"""
3494+
RandomizableTransform.__init__(self, prob)
3495+
3496+
self.downsample_mode = downsample_mode
3497+
self.upsample_mode = upsample_mode
3498+
self.zoom_range = zoom_range
3499+
self.align_corners = align_corners
3500+
self.device = device
3501+
self.zoom_factor = 1.0
3502+
3503+
def randomize(self, data: Any | None = None) -> None:
3504+
super().randomize(None)
3505+
self.zoom_factor = self.R.uniform(self.zoom_range[0], self.zoom_range[1])
3506+
if not self._do_transform:
3507+
return None
3508+
3509+
def __call__(self, img: torch.Tensor, randomize: bool = True) -> torch.Tensor:
3510+
"""
3511+
Args:
3512+
img: shape must be (num_channels, H, W[, D]),
3513+
randomize: whether to execute `randomize()` function first, defaults to True.
3514+
"""
3515+
if randomize:
3516+
self.randomize()
3517+
3518+
if self._do_transform:
3519+
input_shape = img.shape[1:]
3520+
target_shape = np.round(np.array(input_shape) * self.zoom_factor).astype(np.int_)
3521+
3522+
resize_tfm_downsample = Resize(
3523+
spatial_size=target_shape, size_mode="all", mode=self.downsample_mode, anti_aliasing=False
3524+
)
3525+
3526+
resize_tfm_upsample = Resize(
3527+
spatial_size=input_shape,
3528+
size_mode="all",
3529+
mode=self.upsample_mode,
3530+
anti_aliasing=False,
3531+
align_corners=self.align_corners,
3532+
)
3533+
# temporarily disable metadata tracking, since we do not want to invert the two Resize functions during
3534+
# post-processing
3535+
original_tack_meta_value = get_track_meta()
3536+
set_track_meta(False)
3537+
3538+
img_downsampled = resize_tfm_downsample(img)
3539+
img_upsampled = resize_tfm_upsample(img_downsampled)
3540+
3541+
# reset metadata tracking to original value
3542+
set_track_meta(original_tack_meta_value)
3543+
3544+
# copy metadata from original image to down-and-upsampled image
3545+
img_upsampled = MetaTensor(img_upsampled)
3546+
img_upsampled.copy_meta_from(img)
3547+
3548+
return img_upsampled
3549+
3550+
else:
3551+
return img

monai/transforms/spatial/dictionary.py

Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@
4545
RandGridDistortion,
4646
RandGridPatch,
4747
RandRotate,
48+
RandSimulateLowResolution,
4849
RandZoom,
4950
ResampleToMatch,
5051
Resize,
@@ -140,6 +141,9 @@
140141
"RandGridPatchd",
141142
"RandGridPatchD",
142143
"RandGridPatchDict",
144+
"RandSimulateLowResolutiond",
145+
"RandSimulateLowResolutionD",
146+
"RandSimulateLowResolutionDict",
143147
]
144148

145149

@@ -2518,6 +2522,94 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
25182522
return d
25192523

25202524

2525+
class RandSimulateLowResolutiond(RandomizableTransform, MapTransform):
2526+
"""
2527+
Dictionary-based wrapper of :py:class:`monai.transforms.RandSimulateLowResolution`.
2528+
Random simulation of low resolution corresponding to nnU-Net's SimulateLowResolutionTransform
2529+
(https://github.com/MIC-DKFZ/batchgenerators/blob/7651ece69faf55263dd582a9f5cbd149ed9c3ad0/batchgenerators/transforms/resample_transforms.py#L23)
2530+
First, the array/tensor is resampled at lower resolution as determined by the zoom_factor which is uniformly sampled
2531+
from the `zoom_range`. Then, the array/tensor is resampled at the original resolution.
2532+
"""
2533+
2534+
backend = RandAffine.backend
2535+
2536+
def __init__(
2537+
self,
2538+
keys: KeysCollection,
2539+
prob: float = 0.1,
2540+
downsample_mode: InterpolateMode | str = InterpolateMode.NEAREST,
2541+
upsample_mode: InterpolateMode | str = InterpolateMode.TRILINEAR,
2542+
zoom_range=(0.5, 1.0),
2543+
align_corners=False,
2544+
allow_missing_keys: bool = False,
2545+
device: torch.device | None = None,
2546+
) -> None:
2547+
"""
2548+
Args:
2549+
keys: keys of the corresponding items to be transformed.
2550+
prob: probability of performing this augmentation
2551+
downsample_mode: interpolation mode for downsampling operation
2552+
upsample_mode: interpolation mode for upsampling operation
2553+
zoom_range: range from which the random zoom factor for the downsampling and upsampling operation is
2554+
sampled. It determines the shape of the downsampled tensor.
2555+
align_corners: This only has an effect when downsample_mode or upsample_mode is 'linear', 'bilinear',
2556+
'bicubic' or 'trilinear'. Default: False
2557+
See also: https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html
2558+
allow_missing_keys: don't raise exception if key is missing.
2559+
device: device on which the tensor will be allocated.
2560+
2561+
See also:
2562+
- :py:class:`monai.transforms.compose.MapTransform`
2563+
2564+
"""
2565+
MapTransform.__init__(self, keys, allow_missing_keys)
2566+
RandomizableTransform.__init__(self, prob)
2567+
2568+
self.downsample_mode = downsample_mode
2569+
self.upsample_mode = upsample_mode
2570+
self.zoom_range = zoom_range
2571+
self.align_corners = align_corners
2572+
self.device = device
2573+
2574+
self.sim_lowres_tfm = RandSimulateLowResolution(
2575+
prob=1.0, # probability is handled by dictionary class
2576+
downsample_mode=self.downsample_mode,
2577+
upsample_mode=self.upsample_mode,
2578+
zoom_range=self.zoom_range,
2579+
align_corners=self.align_corners,
2580+
device=self.device,
2581+
)
2582+
2583+
def set_random_state(
2584+
self, seed: int | None = None, state: np.random.RandomState | None = None
2585+
) -> RandSimulateLowResolutiond:
2586+
super().set_random_state(seed, state)
2587+
return self
2588+
2589+
def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, NdarrayOrTensor]:
2590+
"""
2591+
Args:
2592+
data: a dictionary containing the tensor-like data to be transformed. The ``keys`` specified
2593+
in this dictionary must be tensor like arrays that are channel first and have at most
2594+
three spatial dimensions
2595+
"""
2596+
d = dict(data)
2597+
first_key: Hashable = self.first_key(d)
2598+
if first_key == ():
2599+
out: dict[Hashable, NdarrayOrTensor] = convert_to_tensor(d, track_meta=get_track_meta())
2600+
return out
2601+
2602+
self.randomize(None)
2603+
2604+
for key in self.key_iterator(d):
2605+
# do the transform
2606+
if self._do_transform:
2607+
d[key] = self.sim_lowres_tfm(d[key]) # type: ignore
2608+
else:
2609+
d[key] = convert_to_tensor(d[key], track_meta=get_track_meta(), dtype=torch.float32)
2610+
return d
2611+
2612+
25212613
SpatialResampleD = SpatialResampleDict = SpatialResampled
25222614
ResampleToMatchD = ResampleToMatchDict = ResampleToMatchd
25232615
SpacingD = SpacingDict = Spacingd
@@ -2541,3 +2633,4 @@ def __call__(self, data: Mapping[Hashable, NdarrayOrTensor]) -> dict[Hashable, N
25412633
GridSplitD = GridSplitDict = GridSplitd
25422634
GridPatchD = GridPatchDict = GridPatchd
25432635
RandGridPatchD = RandGridPatchDict = RandGridPatchd
2636+
RandSimulateLowResolutionD = RandSimulateLowResolutionDict = RandSimulateLowResolutiond
Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,83 @@
1+
# Copyright (c) MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
from __future__ import annotations
13+
14+
import unittest
15+
16+
import numpy as np
17+
from parameterized import parameterized
18+
19+
from monai.transforms import RandSimulateLowResolution
20+
from tests.utils import TEST_NDARRAYS, assert_allclose
21+
22+
TESTS = []
23+
for p in TEST_NDARRAYS:
24+
TESTS.append(
25+
[
26+
dict(prob=1.0, zoom_range=(0.8, 0.81)),
27+
p(
28+
np.array(
29+
[
30+
[
31+
[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]],
32+
[[16, 17, 18, 19], [20, 21, 22, 23], [24, 25, 26, 27], [28, 29, 30, 31]],
33+
[[32, 33, 34, 35], [36, 37, 38, 39], [40, 41, 42, 43], [44, 45, 46, 47]],
34+
[[48, 49, 50, 51], [52, 53, 54, 55], [56, 57, 58, 59], [60, 61, 62, 63]],
35+
]
36+
]
37+
)
38+
),
39+
np.array(
40+
[
41+
[
42+
[
43+
[0.0000, 0.6250, 1.3750, 2.0000],
44+
[2.5000, 3.1250, 3.8750, 4.5000],
45+
[5.5000, 6.1250, 6.8750, 7.5000],
46+
[8.0000, 8.6250, 9.3750, 10.0000],
47+
],
48+
[
49+
[10.0000, 10.6250, 11.3750, 12.0000],
50+
[12.5000, 13.1250, 13.8750, 14.5000],
51+
[15.5000, 16.1250, 16.8750, 17.5000],
52+
[18.0000, 18.6250, 19.3750, 20.0000],
53+
],
54+
[
55+
[22.0000, 22.6250, 23.3750, 24.0000],
56+
[24.5000, 25.1250, 25.8750, 26.5000],
57+
[27.5000, 28.1250, 28.8750, 29.5000],
58+
[30.0000, 30.6250, 31.3750, 32.0000],
59+
],
60+
[
61+
[32.0000, 32.6250, 33.3750, 34.0000],
62+
[34.5000, 35.1250, 35.8750, 36.5000],
63+
[37.5000, 38.1250, 38.8750, 39.5000],
64+
[40.0000, 40.6250, 41.3750, 42.0000],
65+
],
66+
]
67+
]
68+
),
69+
]
70+
)
71+
72+
73+
class TestRandGaussianSmooth(unittest.TestCase):
74+
@parameterized.expand(TESTS)
75+
def test_value(self, arguments, image, expected_data):
76+
randsimlowres = RandSimulateLowResolution(**arguments)
77+
randsimlowres.set_random_state(seed=0)
78+
result = randsimlowres(image)
79+
assert_allclose(result, expected_data, rtol=1e-4, type_test="tensor")
80+
81+
82+
if __name__ == "__main__":
83+
unittest.main()

0 commit comments

Comments
 (0)