Skip to content

Commit 9122a00

Browse files
wylijak0bw
authored andcommitted
support keep_size=True in lazy Zoom (Project-MONAI#6240)
adds the missing implementation here: https://github.com/Project-MONAI/MONAI/blob/795bf61adb4c25a03a355fa9c4552473eb763939/monai/transforms/spatial/functional.py#L447-L449 ### Types of changes <!--- Put an `x` in all the boxes that apply, and remove the not applicable items --> - [x] Non-breaking change (fix or new feature that would not break existing functionality). - [ ] Breaking change (fix or new feature that would cause existing functionality to change). - [x] New tests added to cover the changes. - [x] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`. - [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`. - [ ] In-line docstrings updated. - [ ] Documentation updated, tested `make html` command in the `docs/` folder. --------- Signed-off-by: Wenqi Li <wenqil@nvidia.com>
1 parent 53ae200 commit 9122a00

File tree

3 files changed

+29
-10
lines changed

3 files changed

+29
-10
lines changed

monai/transforms/spatial/functional.py

Lines changed: 14 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from monai.transforms.utils import create_rotate, create_translate, scale_affine
3636
from monai.transforms.utils_pytorch_numpy_unification import allclose
3737
from monai.utils import (
38+
LazyAttr,
3839
TraceKeys,
3940
convert_to_dst_type,
4041
convert_to_numpy,
@@ -432,10 +433,7 @@ def zoom(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype,
432433
433434
"""
434435
im_shape = img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:]
435-
output_size = [
436-
int(math.floor(float(i) * z))
437-
for i, z in zip(img.peek_pending_shape() if isinstance(img, MetaTensor) else img.shape[1:], scale_factor)
438-
]
436+
output_size = [int(math.floor(float(i) * z)) for i, z in zip(im_shape, scale_factor)]
439437
xform = scale_affine(im_shape, output_size)
440438
extra_info = {
441439
"mode": mode,
@@ -445,9 +443,18 @@ def zoom(img, scale_factor, keep_size, mode, padding_mode, align_corners, dtype,
445443
"padcrop": {},
446444
}
447445
if keep_size:
448-
if transform_info.get(TraceKeys.LAZY_EVALUATION, False):
449-
raise NotImplementedError("keep_size=True is not supported for lazy evaluation.")
450-
output_size = [int(i) for i in img.shape[1:]]
446+
do_pad_crop = not np.allclose(output_size, im_shape)
447+
if do_pad_crop and transform_info.get(TraceKeys.LAZY_EVALUATION, False): # update for lazy evaluation
448+
_pad_crop = ResizeWithPadOrCrop(spatial_size=im_shape, mode=padding_mode)
449+
_pad_crop.lazy_evaluation = True
450+
_tmp_img = MetaTensor([], affine=torch.eye(len(output_size) + 1))
451+
_tmp_img.push_pending_operation({LazyAttr.SHAPE: list(output_size), LazyAttr.AFFINE: xform})
452+
lazy_cropped = _pad_crop(_tmp_img)
453+
if isinstance(lazy_cropped, MetaTensor):
454+
xform = lazy_cropped.peek_pending_affine()
455+
extra_info["padcrop"] = lazy_cropped.pending_operations[-1]
456+
extra_info["do_padcrop"] = do_pad_crop
457+
output_size = [int(i) for i in im_shape]
451458
meta_info = TraceableTransform.track_transform_meta(
452459
img,
453460
sp_size=output_size,

tests/test_integration_lazy_samples.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ def run_training_test(root_dir, device="cuda:0", cachedataset=0, readers=(None,
5959
keys=["img", "seg"], label_key="seg", spatial_size=[76, 82, 80], pos=1, neg=1, num_samples=4
6060
),
6161
mt.RandRotate90d(keys=["img", "seg"], prob=0.8, spatial_axes=(0, 2)),
62+
mt.RandZoomd(
63+
keys=["img", "seg"], prob=1.0, min_zoom=1.0, max_zoom=1.0, mode=("trilinear", 0), keep_size=True
64+
),
6265
mt.ResizeWithPadOrCropD(keys=["img", "seg"], spatial_size=[80, 72, 80]),
6366
mt.Rotated(keys=["img", "seg"], angle=[np.pi / 2, np.pi / 2, 0], mode="nearest", keep_size=False),
6467
],

tests/test_zoom.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,25 @@
2929
test_local_inversion,
3030
)
3131

32-
VALID_CASES = [(1.5, "nearest", True), (1.5, "nearest", False), (0.8, "bilinear"), (0.8, "area")]
32+
VALID_CASES = [
33+
(1.5, "nearest", True),
34+
(1.5, "nearest", False),
35+
(0.8, "bilinear"),
36+
(0.8, "area"),
37+
(1.5, "nearest", False, True),
38+
(0.8, "area", False, True),
39+
]
3340

3441
INVALID_CASES = [((None, None), "bilinear", TypeError), ((0.9, 0.9), "s", ValueError)]
3542

3643

3744
class TestZoom(NumpyImageTestCase2D):
3845
@parameterized.expand(VALID_CASES)
39-
def test_pending_ops(self, zoom, mode, align_corners=False):
46+
def test_pending_ops(self, zoom, mode, align_corners=False, keep_size=False):
4047
im = MetaTensor(self.imt[0], meta={"a": "b", "affine": DEFAULT_TEST_AFFINE})
41-
zoom_fn = Zoom(zoom=zoom, mode="bilinear", keep_size=False, dtype=torch.float64, align_corners=align_corners)
48+
zoom_fn = Zoom(
49+
zoom=zoom, mode="bilinear", keep_size=keep_size, dtype=torch.float64, align_corners=align_corners
50+
)
4251
# non-lazy
4352
expected = zoom_fn(im)
4453
self.assertIsInstance(expected, MetaTensor)

0 commit comments

Comments
 (0)