Skip to content

Commit

Permalink
152-random-rotate (#155)
Browse files Browse the repository at this point in the history
* Add RandomRotate.
  • Loading branch information
madil90 authored Mar 10, 2020
1 parent 5c49f8f commit d166f6a
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 1 deletion.
1 change: 0 additions & 1 deletion monai/transforms/composables.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ def __init__(self, keys, affine_key, pixdim, interp_order=2, keep_shape=False, o
interp_order = ensure_tuple(interp_order)
self.interp_order = interp_order \
if len(interp_order) == len(self.keys) else interp_order * len(self.keys)
print(self.interp_order)
self.output_key = output_key

def __call__(self, data):
Expand Down
50 changes: 50 additions & 0 deletions monai/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,6 +631,56 @@ def __call__(self, img):
return data


@export
class RandRotate(Randomizable):
"""Randomly rotates the input arrays.
Args:
prob (float): Probability of rotation.
degrees (tuple of float or float): Range of rotation in degrees. If single number,
angle is picked from (-degrees, degrees).
axes (tuple of 2 ints): Axes of rotation. Default: (1, 2). This is the first two
axis in spatial dimensions according to MONAI channel first shape assumption.
reshape (bool): If true, output shape is made same as input. Default: True.
order (int): Order of spline interpolation. Range 0-5. Default: 1. This is
different from scipy where default interpolation is 3.
mode (str): Points outside boundary filled according to this mode. Options are
'constant', 'nearest', 'reflect', 'wrap'. Default: 'constant'.
cval (scalar): Value to fill outside boundary. Default: 0.
prefiter (bool): Apply spline_filter before interpolation. Default: True.
"""

def __init__(self, degrees, prob=0.1, axes=(1, 2), reshape=True, order=1,
mode='constant', cval=0, prefilter=True):
self.prob = prob
self.degrees = degrees
self.reshape = reshape
self.order = order
self.mode = mode
self.cval = cval
self.prefilter = prefilter
self.axes = axes

if not hasattr(self.degrees, '__iter__'):
self.degrees = (-self.degrees, self.degrees)
assert len(self.degrees) == 2, "degrees should be a number or pair of numbers."

self._do_transform = False
self.angle = None

def randomize(self):
self._do_transform = self.R.random_sample() < self.prob
self.angle = self.R.uniform(low=self.degrees[0], high=self.degrees[1])

def __call__(self, img):
self.randomize()
if not self._do_transform:
return img
rotator = Rotate(self.angle, self.axes, self.reshape, self.order,
self.mode, self.cval, self.prefilter)
return rotator(img)


@export
class RandomFlip(Randomizable):
"""Randomly flips the image along axes.
Expand Down
43 changes: 43 additions & 0 deletions tests/test_random_rotate.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import numpy as np

import scipy.ndimage
from parameterized import parameterized

from monai.transforms import RandRotate
from tests.utils import NumpyImageTestCase2D


class RandomRotateTest(NumpyImageTestCase2D):

@parameterized.expand([
(90, (1, 2), True, 1, 'reflect', 0, True),
((-45, 45), (2, 1), True, 3, 'constant', 0, True),
(180, (2, 3), False, 2, 'constant', 4, False),
])
def test_correct_results(self, degrees, axes, reshape,
order, mode, cval, prefilter):
rotate_fn = RandRotate(degrees, prob=1.0, axes=axes, reshape=reshape,
order=order, mode=mode, cval=cval, prefilter=prefilter)
rotate_fn.set_random_state(243)
rotated = rotate_fn(self.imt)

angle = rotate_fn.angle
expected = scipy.ndimage.rotate(self.imt, angle, axes, reshape, order=order,
mode=mode, cval=cval, prefilter=prefilter)
self.assertTrue(np.allclose(expected, rotated))


if __name__ == '__main__':
unittest.main()

0 comments on commit d166f6a

Please sign in to comment.