Skip to content

fixes #100 Transform synchronization #128

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Mar 4, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions monai/data/nifti_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from torch.utils.data._utils.collate import np_str_obj_array_pattern

from monai.utils.module import export
from monai.transforms.compose import Randomizable


def load_nifti(filename_or_obj, as_closest_canonical=False, image_only=True, dtype=None):
Expand Down Expand Up @@ -106,19 +107,17 @@ def __getitem__(self, index):
image_only=self.image_only, dtype=self.dtype)
seg = load_nifti(self.seg_files[index])

# https://github.com/pytorch/vision/issues/9#issuecomment-304224800
seed = np.random.randint(2147483647)

if self.transform is not None:
np.random.seed(seed)
if isinstance(self.transform, Randomizable):
self.transform.set_random_state(seed=seed)
img = self.transform(img)
random_sync_test = np.random.randint(2147483647)

if self.seg_transform is not None:
np.random.seed(seed) # ensure randomized transforms roll the same values for segmentations as images
if isinstance(self.seg_transform, Randomizable):
self.seg_transform.set_random_state(seed=seed)
seg = self.seg_transform(seg)
seg_seed = np.random.randint(2147483647)
assert(random_sync_test == seg_seed)

if self.image_only or meta_data is None:
return img, seg
Expand Down
6 changes: 4 additions & 2 deletions monai/data/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from monai.transforms.utils import ensure_tuple_size


def get_random_patch(dims, patch_size):
def get_random_patch(dims, patch_size, rand_state=None):
"""
Returns a tuple of slices to define a random patch in an array of shape `dims` with size `patch_size` or the as
close to it as possible within the given dimension. It is expected that `patch_size` is a valid patch for a source
Expand All @@ -25,13 +25,15 @@ def get_random_patch(dims, patch_size):
Args:
dims (tuple of int): shape of source array
patch_size (tuple of int): shape of patch size to generate
rand_state (np.random.RandomState): a random state object to generate random numbers from

Returns:
(tuple of slice): a tuple of slice objects defining the patch
"""

# choose the minimal corner of the patch
min_corner = tuple(np.random.randint(0, ms - ps) if ms > ps else 0 for ms, ps in zip(dims, patch_size))
rand_int = np.random.randint if rand_state is None else rand_state.randint
min_corner = tuple(rand_int(0, ms - ps) if ms > ps else 0 for ms, ps in zip(dims, patch_size))

# create the slices for each dimension which define the patch in the source array
return tuple(slice(mc, mc + ps) for mc, ps in zip(min_corner, patch_size))
Expand Down
36 changes: 32 additions & 4 deletions monai/transforms/composables.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from collections.abc import Hashable

import monai
from monai.data.utils import get_random_patch, get_valid_patch_size
from monai.transforms.compose import Randomizable, Transform
from monai.transforms.transforms import Rotate90, SpatialCrop
from monai.utils.misc import ensure_tuple
Expand Down Expand Up @@ -77,6 +78,33 @@ def __call__(self, data):
return d


@export
class UniformRandomPatchd(Randomizable, MapTransform):
"""
Selects a patch of the given size chosen at a uniformly random position in the image.
"""

def __init__(self, keys, patch_size):
MapTransform.__init__(self, keys)

self.patch_size = (None,) + tuple(patch_size)

self._slices = None

def randomize(self, image_shape, patch_shape):
self._slices = get_random_patch(image_shape, patch_shape, self.R)

def __call__(self, data):
d = dict(data)

image_shape = d[self.keys[0]].shape # image shape from the first data key
patch_size = get_valid_patch_size(image_shape, self.patch_size)
self.randomize(image_shape, patch_size)
for key in self.keys:
d[key] = d[key][self._slices]
return d


@export
class RandRotate90d(Randomizable, MapTransform):
"""
Expand Down Expand Up @@ -105,12 +133,12 @@ def __init__(self, keys, prob=0.1, max_k=3, axes=(1, 2)):
self._do_transform = False
self._rand_k = 0

def randomise(self):
def randomize(self):
self._rand_k = self.R.randint(self.max_k) + 1
self._do_transform = self.R.random() < self.prob

def __call__(self, data):
self.randomise()
self.randomize()
if not self._do_transform:
return data

Expand Down Expand Up @@ -155,13 +183,13 @@ def __init__(self, keys, label_key, size, pos=1, neg=1, num_samples=1):
self.num_samples = num_samples
self.centers = None

def randomise(self, label):
def randomize(self, label):
self.centers = generate_pos_neg_label_crop_centers(label, self.size, self.num_samples, self.pos_ratio, self.R)

def __call__(self, data):
d = dict(data)
label = d[self.label_key]
self.randomise(label)
self.randomize(label)
results = [dict() for _ in range(self.num_samples)]
for key in data.keys():
if key in self.keys:
Expand Down
105 changes: 63 additions & 42 deletions monai/transforms/compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings

import numpy as np


Expand Down Expand Up @@ -36,7 +38,50 @@ def __call__(self, data):
raise NotImplementedError


class Compose:
class Randomizable:
"""
An interface for handling local numpy random state.
this is mainly for randomized data augmentation transforms.
"""
R = np.random.RandomState()

def set_random_state(self, seed=None, state=None):
"""
Set the random state locally, to control the randomness, the derived
classes should use `self.R` instead of `np.random` to introduce random
factors.

Args:
seed (int): set the random state with an integer seed.
state (np.random.RandomState): set the random state with a `np.random.RandomState` object.

Returns:
a Randomizable instance.
Note:
thread safety
"""
if seed is not None:
_seed = id(seed) if not isinstance(seed, int) else seed
self.R = np.random.RandomState(_seed)
return self

if state is not None:
if not isinstance(state, np.random.RandomState):
raise ValueError('`state` must be a `np.random.RandomState`, got {}'.format(type(state)))
self.R = state
return self

self.R = np.random.RandomState()
return self

def randomize(self):
"""
all self.R calls happen here so that we have a better chance to identify errors of sync the random state.
"""
raise NotImplementedError


class Compose(Randomizable):
"""
`Compose` provides the ability to chain a series of calls together in a
sequence. Each transform in the sequence must take a single argument and
Expand Down Expand Up @@ -97,6 +142,23 @@ def __init__(self, transforms=None):
raise ValueError("Parameters 'transforms' must be a list or tuple")
self.transforms = transforms

def set_random_state(self, seed=None, state=None):
for _transform in self.transforms:
if not isinstance(_transform, Randomizable):
continue
_transform.set_random_state(seed, state)

def randomize(self):
for _transform in self.transforms:
if not isinstance(_transform, Randomizable):
continue
try:
_transform.randomize()
except TypeError as type_error:
warnings.warn(
'Transform "{0}" in Compose not randomized\n{0}.{1}.'.format(type(_transform).__name__, type_error),
RuntimeWarning)

def __call__(self, input_):
for transform in self.transforms:
# if some transform generated batch list of data in the transform chain,
Expand All @@ -107,44 +169,3 @@ def __call__(self, input_):
else:
input_ = transform(input_)
return input_


class Randomizable:
"""
An interface for handling local numpy random state.
this is mainly for randomized data augmentation transforms.
"""
R = np.random.RandomState()

def set_random_state(self, seed=None, state=None):
"""
Set the random state locally, to control the randomness, the derived
classes should use `self.R` instead of `np.random` to introduce random
factors.

Args:
seed (int): set the random state with an integer seed.
state (np.random.RandomState): set the random state with a `np.random.RandomState` object.

Note:
thread safety
"""
if seed is not None:
_seed = id(seed) if not isinstance(seed, int) else seed
self.R = np.random.RandomState(_seed)
return

if state is not None:
if not isinstance(state, np.random.RandomState):
raise ValueError('`state` must be a `np.random.RandomState`, got {}'.format(type(state)))
self.R = state
return

self.R = np.random.RandomState()
return

def randomise(self):
"""
all self.R calls happen here so that we have a better chance to identify errors of sync the random state.
"""
raise NotImplementedError
16 changes: 10 additions & 6 deletions monai/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,19 +91,23 @@ def __call__(self, img):


@export
class UniformRandomPatch:
class UniformRandomPatch(Randomizable):
"""
Selects a patch of the given size chosen at a uniformly random position in the image.
"""

def __init__(self, patch_size):
self.patch_size = (None,) + tuple(patch_size)

self._slices = None

def randomize(self, image_shape, patch_shape):
self._slices = get_random_patch(image_shape, patch_shape, self.R)

def __call__(self, img):
patch_size = get_valid_patch_size(img.shape, self.patch_size)
slices = get_random_patch(img.shape, patch_size)

return img[slices]
self.randomize(img.shape, patch_size)
return img[self._slices]


@export
Expand Down Expand Up @@ -212,12 +216,12 @@ def __init__(self, prob=0.1, max_k=3, axes=(1, 2)):
self._do_transform = False
self._rand_k = 0

def randomise(self):
def randomize(self):
self._rand_k = self.R.randint(self.max_k) + 1
self._do_transform = self.R.random() < self.prob

def __call__(self, img):
self.randomise()
self.randomize()
if not self._do_transform:
return img
rotator = Rotate90(self._rand_k, self.axes)
Expand Down
33 changes: 32 additions & 1 deletion tests/test_compose.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import unittest

from monai.transforms.compose import Compose
from monai.transforms.compose import Compose, Randomizable


class TestCompose(unittest.TestCase):
Expand Down Expand Up @@ -67,6 +67,37 @@ def c(d): # transform to handle dict data
for item in value:
self.assertDictEqual(item, {'a': 2, 'b': 1, 'c': 2})

def test_random_compose(self):

class _Acc(Randomizable):
self.rand = 0.0

def randomize(self):
self.rand = self.R.rand()

def __call__(self, data):
self.randomize()
return self.rand + data

c = Compose([_Acc(), _Acc()])
self.assertNotAlmostEqual(c(0), c(0))
c.set_random_state(123)
self.assertAlmostEqual(c(1), 2.39293837)
c.set_random_state(223)
c.randomize()
self.assertAlmostEqual(c(1), 2.57673391)

def test_randomize_warn(self):

class _RandomClass(Randomizable):

def randomize(self, foo):
pass

c = Compose([_RandomClass(), _RandomClass()])
with self.assertWarns(Warning):
c.randomize()


if __name__ == '__main__':
unittest.main()
30 changes: 30 additions & 0 deletions tests/test_uniform_rand_patch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import numpy as np

from monai.transforms.transforms import UniformRandomPatch
from tests.utils import NumpyImageTestCase2D


class UniformRandomPatchTest(NumpyImageTestCase2D):

def test_2d(self):
patch_size = (1, 10, 10)
patch_transform = UniformRandomPatch(patch_size=patch_size)
patch = patch_transform(self.imt)
self.assertTrue(np.allclose(patch.shape[:-2], patch_size[:-2]))


if __name__ == '__main__':
unittest.main()
Loading