Skip to content

Commit 445a568

Browse files
authored
107-zoom (#138)
* Adding zoom transform and tests. * Fix import error.
1 parent 7f29477 commit 445a568

File tree

2 files changed

+125
-0
lines changed

2 files changed

+125
-0
lines changed

monai/transforms/transforms.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,67 @@ def __call__(self, img):
177177
prefilter=self.prefilter)
178178

179179

180+
@export
181+
class Zoom:
182+
""" Zooms a nd image. Uses scipy.ndimage.zoom or cupyx.scipy.ndimage.zoom in case of gpu.
183+
For details, please see https://docs.scipy.org/doc/scipy/reference/generated/scipy.ndimage.zoom.html.
184+
185+
Args:
186+
zoom (float or sequence): The zoom factor along the axes. If a float, zoom is the same for each axis.
187+
If a sequence, zoom should contain one value for each axis.
188+
order (int): order of interpolation. Default=3.
189+
mode (str): Determines how input is extended beyond boundaries. Default is 'constant'.
190+
cval (scalar, optional): Value to fill past edges. Default is 0.
191+
use_gpu (bool): Should use cpu or gpu.
192+
keep_size (bool): Should keep original size (pad if needed).
193+
"""
194+
def __init__(self, zoom, order=3, mode='constant', cval=0, prefilter=True, use_gpu=False, keep_size=False):
195+
assert isinstance(order, int), "Order must be integer."
196+
self.zoom = zoom
197+
self.order = order
198+
self.mode = mode
199+
self.cval = cval
200+
self.prefilter = prefilter
201+
self.use_gpu = use_gpu
202+
self.keep_size = keep_size
203+
204+
def __call__(self, img):
205+
zoomed = None
206+
if self.use_gpu:
207+
try:
208+
import cupy
209+
from cupyx.scipy.ndimage import zoom as zoom_gpu
210+
211+
zoomed_gpu = zoom_gpu(cupy.array(img), zoom=self.zoom, order=self.order,
212+
mode=self.mode, cval=self.cval, prefilter=self.prefilter)
213+
zoomed = cupy.asnumpy()
214+
except ModuleNotFoundError:
215+
print('For GPU zoom, please install cupy. Defaulting to cpu.')
216+
except Exception:
217+
print('Warning: Zoom gpu failed. Defaulting to cpu.')
218+
219+
if not zoomed or not self.use_gpu:
220+
zoomed = scipy.ndimage.zoom(img, zoom=self.zoom, order=self.order,
221+
mode=self.mode, cval=self.cval, prefilter=self.prefilter)
222+
223+
# Crops to original size or pads.
224+
if self.keep_size:
225+
shape = img.shape
226+
pad_vec = [[0, 0]] * len(shape)
227+
crop_vec = list(zoomed.shape)
228+
for d in range(len(shape)):
229+
if zoomed.shape[d] > shape[d]:
230+
crop_vec[d] = shape[d]
231+
elif zoomed.shape[d] < shape[d]:
232+
# pad_vec[d] = [0, shape[d] - zoomed.shape[d]]
233+
pad_h = (float(shape[d]) - float(zoomed.shape[d])) / 2
234+
pad_vec[d] = [int(np.floor(pad_h)), int(np.ceil(pad_h))]
235+
zoomed = zoomed[0:crop_vec[0], 0:crop_vec[1], 0:crop_vec[2]]
236+
zoomed = np.pad(zoomed, pad_vec, mode='constant', constant_values=self.cval)
237+
238+
return zoomed
239+
240+
180241
@export
181242
class ToTensor:
182243
"""

tests/test_zoom.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright 2020 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import unittest
13+
14+
import numpy as np
15+
from scipy.ndimage import zoom as zoom_scipy
16+
from parameterized import parameterized
17+
18+
from monai.transforms import Zoom
19+
from tests.utils import NumpyImageTestCase2D
20+
21+
22+
class ZoomTest(NumpyImageTestCase2D):
23+
24+
@parameterized.expand([
25+
(1.1, 3, 'constant', 0, True, False, False),
26+
(0.9, 3, 'constant', 0, True, False, False),
27+
(0.8, 1, 'reflect', 0, False, False, False)
28+
])
29+
def test_correct_results(self, zoom, order, mode, cval, prefilter, use_gpu, keep_size):
30+
zoom_fn = Zoom(zoom=zoom, order=order, mode=mode, cval=cval,
31+
prefilter=prefilter, use_gpu=use_gpu, keep_size=keep_size)
32+
zoomed = zoom_fn(self.imt)
33+
expected = zoom_scipy(self.imt, zoom=zoom, mode=mode, order=order,
34+
cval=cval, prefilter=prefilter)
35+
self.assertTrue(np.allclose(expected, zoomed))
36+
37+
@parameterized.expand([
38+
("gpu_zoom", 0.6, 3, 'constant', 0, True)
39+
])
40+
def test_gpu_zoom(self, _, zoom, order, mode, cval, prefilter):
41+
zoom_fn = Zoom(zoom=zoom, order=order, mode=mode, cval=cval,
42+
prefilter=prefilter, use_gpu=True, keep_size=False)
43+
zoomed = zoom_fn(self.imt)
44+
expected = zoom_scipy(self.imt, zoom=zoom, mode=mode, order=order,
45+
cval=cval, prefilter=prefilter)
46+
self.assertTrue(np.allclose(expected, zoomed))
47+
48+
def test_keep_size(self):
49+
zoom_fn = Zoom(zoom=0.6, keep_size=True)
50+
zoomed = zoom_fn(self.imt)
51+
self.assertTrue(np.array_equal(zoomed.shape, self.imt.shape))
52+
53+
@parameterized.expand([
54+
("no_zoom", None, 1, TypeError),
55+
("invalid_order", 0.9, 's', AssertionError)
56+
])
57+
def test_invalid_inputs(self, _, zoom, order, raises):
58+
with self.assertRaises(raises):
59+
zoom_fn = Zoom(zoom=zoom, order=order)
60+
zoomed = zoom_fn(self.imt)
61+
62+
63+
if __name__ == '__main__':
64+
unittest.main()

0 commit comments

Comments
 (0)