diff --git a/changelogs/master/improved/20200315_segment_replacement.md b/changelogs/master/improved/20200315_segment_replacement.md new file mode 100644 index 000000000..434e79856 --- /dev/null +++ b/changelogs/master/improved/20200315_segment_replacement.md @@ -0,0 +1,25 @@ +# Improved Performance of Segment Replacement #640 + +This patch improves the performance of segment +replacement (by average colors within the segments), +used in `Superpixels` and `segment_voronoi()`. +The new method is up to around 7x faster, more for +smaller images and more segments. It can be slightly +slower in some cases for large images (512x512 and +larger). + +This change seems to improve the overall performance +of `Superpixels` by a factor of around 1.1x to 1.4x +(more for smaller images). +It improves the overall performance of +`segment_voronoi()` by about 1.1x to 2.0x and can +reach much higher improvements in the case of very few +segments that have to be replaced. + +Note that `segment_voronoi()` is used in `Voronoi`. + +Added functions: +* `imgaug.augmenters.segmentation.replace_segments_` + +Added classes: +* `imgaug.testutils.temporary_constants` (context) diff --git a/checks/check_voronoi.py b/checks/check_voronoi.py index 93be666f7..1ec866e73 100644 --- a/checks/check_voronoi.py +++ b/checks/check_voronoi.py @@ -15,19 +15,21 @@ def main(): ) uniform_sampler = iaa.UniformPointsSampler(50*50) - augs = [ - iaa.Voronoi(points_sampler=reggrid_sampler, p_replace=1.0, - max_size=128), - iaa.Voronoi(points_sampler=uniform_sampler, p_replace=1.0, - max_size=128), - iaa.UniformVoronoi(50*50, p_replace=1.0, max_size=128), - iaa.RegularGridVoronoi(50, 50, p_drop_points=0.4, p_replace=1.0, - max_size=128), - ] - - images = [aug(image=image) for aug in augs] - - ia.imshow(np.hstack(images)) + for p_replace in [1.0, 0.5, 0.1, 0.0]: + augs = [ + iaa.Voronoi(points_sampler=reggrid_sampler, p_replace=p_replace, + max_size=128), + iaa.Voronoi(points_sampler=uniform_sampler, p_replace=p_replace, + max_size=128), + iaa.UniformVoronoi(50*50, p_replace=p_replace, max_size=128), + iaa.RegularGridVoronoi(50, 50, p_drop_points=0.4, + p_replace=p_replace, max_size=128), + iaa.RelativeRegularGridVoronoi(p_replace=p_replace, max_size=128) + ] + + images = [aug(image=image) for aug in augs] + + ia.imshow(np.hstack(images)) if __name__ == "__main__": diff --git a/imgaug/augmenters/segmentation.py b/imgaug/augmenters/segmentation.py index e808f7a83..6b9cf0fcd 100644 --- a/imgaug/augmenters/segmentation.py +++ b/imgaug/augmenters/segmentation.py @@ -20,6 +20,7 @@ # with skimage.segmentation for whatever reason import skimage.segmentation import skimage.measure +import scipy.ndimage as ndimage import six import six.moves as sm @@ -30,6 +31,10 @@ from .. import dtypes as iadt +_REPLACE_SEGMENTS_NP_BELOW_AREA = 64 * 64 +_REPLACE_SEGMENTS_NP_BELOW_NSEG = 25 + + # TODO merge this into imresize? def _ensure_image_max_size(image, max_size, interpolation): """Ensure that images do not exceed a required maximum sidelength. @@ -271,7 +276,9 @@ def _augment_batch_(self, batch, random_state, parents, hooks): segments = skimage.segmentation.slic( image, n_segments=n_segments_samples[i], compactness=10) - image_aug = self._replace_segments(image, segments, replace_samples) + image_aug = replace_segments_( + image, segments, replace_samples > 0.5 + ) if orig_shape != image_aug.shape: image_aug = ia.imresize_single_image( @@ -282,52 +289,183 @@ def _augment_batch_(self, batch, random_state, parents, hooks): batch.images[i] = image_aug return batch - @classmethod - def _replace_segments(cls, image, segments, replace_samples): - min_value, _center_value, max_value = \ - iadt.get_value_range_of_dtype(image.dtype) - image_sp = np.copy(image) - - nb_channels = image.shape[2] - for c in sm.xrange(nb_channels): - # segments+1 here because otherwise regionprops always - # misses the last label - regions = skimage.measure.regionprops( - segments+1, intensity_image=image[..., c]) - for ridx, region in enumerate(regions): - # with mod here, because slic can sometimes create more - # superpixel than requested. replace_samples then does not - # have enough values, so we just start over with the first one - # again. - if replace_samples[ridx % len(replace_samples)] > 0.5: - mean_intensity = region.mean_intensity - image_sp_c = image_sp[..., c] - - if image_sp_c.dtype.kind in ["i", "u", "b"]: - # After rounding the value can end up slightly outside - # of the value_range. Hence, we need to clip. We do - # clip via min(max(...)) instead of np.clip because - # the latter one does not seem to keep dtypes for - # dtypes with large itemsizes (e.g. uint64). - value = int(np.round(mean_intensity)) - value = min(max(value, min_value), max_value) - else: - value = mean_intensity - - image_sp_c[segments == ridx] = value - - return image_sp - def get_parameters(self): """See :func:`~imgaug.augmenters.meta.Augmenter.get_parameters`.""" return [self.p_replace, self.n_segments, self.max_size, self.interpolation] +# TODO add the old skimage method here for 512x512+ images as it starts to +# be faster for these areas +# TODO incorporate this dtype support in the dtype sections of docstrings for +# Superpixels and segment_voronoi() +def replace_segments_(image, segments, replace_flags): + """Replace segments in images by their average colors in-place. + + This expects an image ``(H,W,[C])`` and an integer segmentation + map ``(H,W)``. The segmentation map must contain the same id for pixels + that are supposed to be replaced by the same color ("segments"). + For each segement, the average color is computed and used as the + replacement. + + Added in 0.5.0. + + **Supported dtypes**: + + * ``uint8``: yes; indirectly tested + * ``uint16``: yes; indirectly tested + * ``uint32``: yes; indirectly tested + * ``uint64``: no; not tested + * ``int8``: yes; indirectly tested + * ``int16``: yes; indirectly tested + * ``int32``: yes; indirectly tested + * ``int64``: no; not tested + * ``float16``: ?; not tested + * ``float32``: ?; not tested + * ``float64``: ?; not tested + * ``float128``: ?; not tested + * ``bool``: yes; indirectly tested + + Parameters + ---------- + image : ndarray + An image of shape ``(H,W,[C])``. + This image may be changed in-place. + The function is currently not tested for float dtypes. + + segments : ndarray + A ``(H,W)`` integer array containing the same ids for pixels belonging + to the same segment. + + replace_flags : ndarray or None + A boolean array containing at the ``i`` th index a flag denoting + whether the segment with id ``i`` should be replaced by its average + color. If the flag is ``False``, the original image pixels will be + kept unchanged for that flag. + If this is ``None``, all segments will be replaced. + + Returns + ------- + ndarray + The image with replaced pixels. + Might be the same image as was provided via `image`. + + """ + assert replace_flags is None or replace_flags.dtype.kind == "b" + + input_shape = image.shape + if 0 in image.shape: + return image + + if len(input_shape) == 2: + image = image[:, :, np.newaxis] + + nb_segments = None + func = _replace_segments_scipy_ + bad_dtype = image.dtype.name not in ["uint8", "int8"] + area = image.shape[0] * image.shape[1] + if bad_dtype or area < _REPLACE_SEGMENTS_NP_BELOW_AREA: + func = _replace_segments_np_ + else: + max_id = np.max(segments) + nb_segments = 1 + max_id + if nb_segments < _REPLACE_SEGMENTS_NP_BELOW_NSEG: + func = _replace_segments_np_ + + result = func(image, segments, replace_flags, nb_segments) + + if len(input_shape) == 2: + return result[:, :, 0] + return result + + +# Added in 0.5.0. +def _replace_segments_np_(image, segments, replace_flags, _nb_segments): + seg_ids = np.unique(segments) + if replace_flags is None: + replace_flags = [True] * len(seg_ids) + for i, seg_id in enumerate(seg_ids): + if replace_flags[i % len(replace_flags)]: + mask = (segments == seg_id) + mean_color = np.average(image[mask, :], axis=(0,)) + image[mask] = mean_color + return image + + +# Added in 0.5.0. +def _replace_segments_scipy_(image, segments, replace_flags, nb_segments): + # Generate segment ids of the segments to actually replace. + # Use "...[0:nb_segments]" here, because we can sample more flags than + # segments. + seg_ids = np.arange(nb_segments) + if replace_flags is not None: + replace_flags = np.resize(replace_flags, (nb_segments,)) + seg_ids = seg_ids[replace_flags] + if len(seg_ids) == 0: + return image + + if len(seg_ids) == nb_segments: + mask = np.full(segments.shape, True, dtype=np.bool) + segments_to_replace = segments.flat + image_to_replace = image.reshape((-1, image.shape[-1])) + else: + mask = np.isin(segments, seg_ids) + segments_to_replace = segments[mask] + image_to_replace = image[mask, :] + + seg_id_to_intensity = np.full((nb_segments,), 0, dtype=np.uint8) + + for c in sm.xrange(image.shape[2]): + # This returns a new array of same length as "seg_ids". Each value is + # the mean intensity of that segment in "image[..., c]". + labelwise_intensities = ndimage.labeled_comprehension( + image_to_replace[:, c], + segments_to_replace, + seg_ids, + np.mean, + np.uint8, + 0 + ) + + # we could call "seg_id_to_intensity *= 0" here, but that isn't really + # necessary as we set the values of all segments that we actually use + seg_id_to_intensity[seg_ids] = labelwise_intensities + + # doesn't seem to work here to use `image_to_replace[:, c] = ...` + # instead + image[mask, c] = seg_id_to_intensity[segments_to_replace] + return image + + # TODO don't average the alpha channel for RGBA? def segment_voronoi(image, cell_coordinates, replace_mask=None): """Average colors within voronoi cells of an image. + **Supported dtypes**: + + if (image size <= max_size): + + * ``uint8``: yes; fully tested + * ``uint16``: no; not tested + * ``uint32``: no; not tested + * ``uint64``: no; not tested + * ``int8``: no; not tested + * ``int16``: no; not tested + * ``int32``: no; not tested + * ``int64``: no; not tested + * ``float16``: no; not tested + * ``float32``: no; not tested + * ``float64``: no; not tested + * ``float128``: no; not tested + * ``bool``: no; not tested + + if (image size > max_size): + + minimum of ( + ``imgaug.augmenters.segmentation.Voronoi(image size <= max_size)``, + :func:`~imgaug.augmenters.segmentation._ensure_image_max_size` + ) + Parameters ---------- image : ndarray @@ -364,14 +502,13 @@ def segment_voronoi(image, cell_coordinates, replace_mask=None): return image height, width = image.shape[0:2] - pixel_coords, ids_of_nearest_cells = \ + ids_of_nearest_cells = \ _match_pixels_with_voronoi_cells(height, width, cell_coordinates) - cell_colors = _compute_avg_segment_colors( - image, pixel_coords, ids_of_nearest_cells, - len(cell_coordinates)) - - image_aug = _render_segments(image, ids_of_nearest_cells, cell_colors, - replace_mask) + image_aug = replace_segments_( + image, + ids_of_nearest_cells.reshape(image.shape[0:2]), + replace_mask + ) if input_dims == 2: return image_aug[..., 0] @@ -385,7 +522,7 @@ def _match_pixels_with_voronoi_cells(height, width, cell_coordinates): pixel_coords = _generate_pixel_coords(height, width) pixel_coords_subpixel = pixel_coords.astype(np.float32) + 0.5 ids_of_nearest_cells = tree.query(pixel_coords_subpixel)[1] - return pixel_coords, ids_of_nearest_cells + return ids_of_nearest_cells def _generate_pixel_coords(height, width): @@ -393,54 +530,6 @@ def _generate_pixel_coords(height, width): return np.c_[xx.ravel(), yy.ravel()] -def _compute_avg_segment_colors(image, pixel_coords, ids_of_nearest_segments, - nb_segments): - nb_channels = image.shape[2] - cell_colors = np.zeros((nb_segments, nb_channels), dtype=np.float64) - cell_counters = np.zeros((nb_segments,), dtype=np.uint32) - - # TODO vectorize - for pixel_coord, id_of_nearest_cell in zip(pixel_coords, - ids_of_nearest_segments): - # pixel_coord is (x,y), so we have to swap it to access the HxW image - pixel_coord_yx = pixel_coord[::-1] - cell_colors[id_of_nearest_cell] += image[tuple(pixel_coord_yx)] - cell_counters[id_of_nearest_cell] += 1 - - # cells without associated pixels can have a count of 0, we clip - # here to 1 as the result for these cells doesn't matter - cell_counters = np.clip(cell_counters, 1, None) - - cell_colors = cell_colors / cell_counters[:, np.newaxis] - - return cell_colors.astype(np.uint8) - - -def _render_segments(image, ids_of_nearest_segments, avg_segment_colors, - replace_mask): - ids_of_nearest_segments = np.copy(ids_of_nearest_segments) - height, width, nb_channels = image.shape - - # without replace_mask we could reduce this down to: - # data = cell_colors[ids_of_nearest_cells, :].reshape( - # (width, height, 3)) - # data = np.transpose(data, (1, 0, 2)) - - keep_mask = (~replace_mask) if replace_mask is not None else None - if keep_mask is None or not np.any(keep_mask): - data = avg_segment_colors[ids_of_nearest_segments, :] - else: - ids_to_keep = np.nonzero(keep_mask)[0] - indices_to_keep = np.where( - np.isin(ids_of_nearest_segments, ids_to_keep))[0] - data = avg_segment_colors[ids_of_nearest_segments, :] - - image_data = image.reshape((height*width, -1)) - data[indices_to_keep] = image_data[indices_to_keep, :] - data = data.reshape((height, width, nb_channels)) - return data - - # TODO this can be reduced down to a similar problem as Superpixels: # generate an integer-based class id map of segments, then replace all # segments with the same class id by the average color within that @@ -467,28 +556,7 @@ class Voronoi(meta.Augmenter): **Supported dtypes**: - if (image size <= max_size): - - * ``uint8``: yes; fully tested - * ``uint16``: no; not tested - * ``uint32``: no; not tested - * ``uint64``: no; not tested - * ``int8``: no; not tested - * ``int16``: no; not tested - * ``int32``: no; not tested - * ``int64``: no; not tested - * ``float16``: no; not tested - * ``float32``: no; not tested - * ``float64``: no; not tested - * ``float128``: no; not tested - * ``bool``: no; not tested - - if (image size > max_size): - - minimum of ( - ``imgaug.augmenters.segmentation.Voronoi(image size <= max_size)``, - :func:`~imgaug.augmenters.segmentation._ensure_image_max_size` - ) + See :func:`imgaug.augmenters.segmentation.segment_voronoi`. Parameters ---------- diff --git a/imgaug/testutils.py b/imgaug/testutils.py index 817a7013e..3ff996b29 100644 --- a/imgaug/testutils.py +++ b/imgaug/testutils.py @@ -12,6 +12,7 @@ import shutil import re import sys +import importlib import numpy as np import six.moves as sm @@ -360,3 +361,40 @@ def assertWarns(testcase, expected_warning, *args, **kwargs): # pylint: disable=invalid-name context = _AssertWarnsContext(expected_warning, testcase) return context.handle("assertWarns", args, kwargs) + + +class temporary_constants(object): + """Context to temporarily change the value of one or more constants. + + Added in 0.5.0. + + """ + + # pylint: disable=invalid-name + + UNCHANGED = object() + + def __init__(self, paths, values): + if ia.is_string(paths): + paths = [paths] + values = [values] + + self.paths = [".".join(path_i.split(".")[:-1]) for path_i in paths] + self.cnames = [path_i.split(".")[-1] for path_i in paths] + self.values = values + self.old_values = None + + def __enter__(self): + old_values = [] + for path, cname, value in zip(self.paths, self.cnames, self.values): + module = importlib.import_module(path) + old_values.append(getattr(module, cname)) + if value is not temporary_constants.UNCHANGED: + setattr(module, cname, value) + self.old_values = old_values + + def __exit__(self, exc_type, exc_val, exc_tb): + gen = zip(self.paths, self.cnames, self.old_values) + for path, cname, old_value in gen: + module = importlib.import_module(path) + setattr(module, cname, old_value) diff --git a/test/augmenters/test_segmentation.py b/test/augmenters/test_segmentation.py index b03a31803..c72e26fbe 100644 --- a/test/augmenters/test_segmentation.py +++ b/test/augmenters/test_segmentation.py @@ -2,6 +2,7 @@ import sys import warnings +import itertools # unittest only added in 3.4 self.subTest() if sys.version_info[0] < 3 or sys.version_info[1] < 4: import unittest2 as unittest @@ -20,7 +21,27 @@ from imgaug import parameters as iap from imgaug import dtypes as iadt from imgaug import random as iarandom -from imgaug.testutils import reseed, runtest_pickleable_uint8_img +from imgaug.testutils import ( + reseed, runtest_pickleable_uint8_img, temporary_constants +) + + +def _create_replace_np_context(use_np_replace): + module_name = "imgaug.augmenters.segmentation." + cnames = [ + module_name + "_REPLACE_SEGMENTS_NP_BELOW_AREA", + module_name + "_REPLACE_SEGMENTS_NP_BELOW_NSEG" + ] + if use_np_replace is True: + values = [10000 * 10000, 10000] + elif use_np_replace is False: + values = [0, 0] + else: + assert use_np_replace == "auto" + values = [temporary_constants.UNCHANGED, + temporary_constants.UNCHANGED] + + return temporary_constants(cnames, values) class TestSuperpixels(unittest.TestCase): @@ -70,56 +91,76 @@ def base_img_superpixels_right(self): return base_img_superpixels_right def test_p_replace_0_n_segments_2(self): - aug = iaa.Superpixels(p_replace=0, n_segments=2) - observed = aug.augment_image(self.base_img) - expected = self.base_img - assert np.allclose(observed, expected) + for use_np_replace in [True, False]: + with self.subTest(use_np_replace=use_np_replace): + with _create_replace_np_context(use_np_replace): + aug = iaa.Superpixels(p_replace=0, n_segments=2) + observed = aug.augment_image(self.base_img) + expected = self.base_img + assert np.allclose(observed, expected) def test_p_replace_1_n_segments_2(self): - aug = iaa.Superpixels(p_replace=1.0, n_segments=2) - observed = aug.augment_image(self.base_img) - expected = self.base_img_superpixels - assert self._array_equals_tolerant(observed, expected, 2) + for use_np_replace in [True, False]: + with self.subTest(use_np_replace=use_np_replace): + with _create_replace_np_context(use_np_replace): + aug = iaa.Superpixels(p_replace=1.0, n_segments=2) + observed = aug.augment_image(self.base_img) + expected = self.base_img_superpixels + assert self._array_equals_tolerant(observed, expected, 2) def test_p_replace_1_n_segments_stochastic_parameter(self): - aug = iaa.Superpixels(p_replace=1.0, n_segments=iap.Deterministic(2)) - observed = aug.augment_image(self.base_img) - expected = self.base_img_superpixels - assert self._array_equals_tolerant(observed, expected, 2) + for use_np_replace in [True, False]: + with self.subTest(use_np_replace=use_np_replace): + with _create_replace_np_context(use_np_replace): + aug = iaa.Superpixels( + p_replace=1.0, n_segments=iap.Deterministic(2) + ) + observed = aug.augment_image(self.base_img) + expected = self.base_img_superpixels + assert self._array_equals_tolerant(observed, expected, 2) def test_p_replace_stochastic_parameter_n_segments_2(self): - aug = iaa.Superpixels( - p_replace=iap.Binomial(iap.Choice([0.0, 1.0])), n_segments=2) - observed = aug.augment_image(self.base_img) - assert ( - np.allclose(observed, self.base_img) - or self._array_equals_tolerant( - observed, self.base_img_superpixels, 2) - ) + for use_np_replace in [True, False]: + with self.subTest(use_np_replace=use_np_replace): + with _create_replace_np_context(use_np_replace): + aug = iaa.Superpixels( + p_replace=iap.Binomial(iap.Choice([0.0, 1.0])), + n_segments=2 + ) + observed = aug.augment_image(self.base_img) + assert ( + np.allclose(observed, self.base_img) + or self._array_equals_tolerant( + observed, self.base_img_superpixels, 2) + ) def test_p_replace_050_n_segments_2(self): - aug = iaa.Superpixels(p_replace=0.5, n_segments=2) - seen = {"none": False, "left": False, "right": False, "both": False} - for _ in sm.xrange(100): - observed = aug.augment_image(self.base_img) - if self._array_equals_tolerant(observed, self.base_img, 2): - seen["none"] = True - elif self._array_equals_tolerant( - observed, self.base_img_superpixels_left, 2): - seen["left"] = True - elif self._array_equals_tolerant( - observed, self.base_img_superpixels_right, 2): - seen["right"] = True - elif self._array_equals_tolerant( - observed, self.base_img_superpixels, 2): - seen["both"] = True - else: - raise Exception( - "Generated superpixels image does not match any " - "expected image.") - if np.all(seen.values()): - break - assert np.all(seen.values()) + _eq = self._array_equals_tolerant + + for use_np_replace in [True, False]: + with self.subTest(use_np_replace=use_np_replace): + with _create_replace_np_context(use_np_replace): + aug = iaa.Superpixels(p_replace=0.5, n_segments=2) + seen = {"none": False, "left": False, "right": False, + "both": False} + for _ in sm.xrange(100): + observed = aug.augment_image(self.base_img) + if _eq(observed, self.base_img, 2): + seen["none"] = True + elif _eq(observed, self.base_img_superpixels_left, 2): + seen["left"] = True + elif _eq(observed, self.base_img_superpixels_right, 2): + seen["right"] = True + elif _eq(observed, self.base_img_superpixels, 2): + seen["both"] = True + else: + raise Exception( + "Generated superpixels image does not match " + "any expected image." + ) + if np.all(seen.values()): + break + assert np.all(seen.values()) def test_failure_on_invalid_datatype_for_p_replace(self): # note that assertRaisesRegex does not exist in 2.7 @@ -152,15 +193,16 @@ def test_zero_sized_axes(self): (1, 0, 1) ] - for shape in shapes: - with self.subTest(shape=shape): - image = np.full(shape, 128, dtype=np.uint8) - aug = iaa.Superpixels(p_replace=1.0, n_segments=10) + for shape, use_np_replace in itertools.product(shapes, [True, False]): + with self.subTest(shape=shape, use_np_replace=use_np_replace): + with _create_replace_np_context(use_np_replace): + image = np.full(shape, 128, dtype=np.uint8) + aug = iaa.Superpixels(p_replace=1.0, n_segments=10) - image_aug = aug(image=image) + image_aug = aug(image=image) - assert image_aug.dtype.name == "uint8" - assert image_aug.shape == shape + assert image_aug.dtype.name == "uint8" + assert image_aug.shape == shape def test_unusual_channel_numbers(self): shapes = [ @@ -170,15 +212,16 @@ def test_unusual_channel_numbers(self): (1, 1, 513) ] - for shape in shapes: - with self.subTest(shape=shape): - image = np.full(shape, 128, dtype=np.uint8) - aug = iaa.Superpixels(p_replace=1.0, n_segments=10) + for shape, use_np_replace in itertools.product(shapes, [True, False]): + with self.subTest(shape=shape, use_np_replace=use_np_replace): + with _create_replace_np_context(use_np_replace): + image = np.full(shape, 128, dtype=np.uint8) + aug = iaa.Superpixels(p_replace=1.0, n_segments=10) - image_aug = aug(image=image) + image_aug = aug(image=image) - assert image_aug.dtype.name == "uint8" - assert image_aug.shape == shape + assert image_aug.dtype.name == "uint8" + assert image_aug.shape == shape def test_get_parameters(self): aug = iaa.Superpixels( @@ -193,62 +236,73 @@ def test_get_parameters(self): assert params[3] == "nearest" def test_other_dtypes_bool(self): - aug = iaa.Superpixels(p_replace=1.0, n_segments=2) - img = np.array([ - [False, False, True, True], - [False, False, True, True] - ], dtype=bool) - img_aug = aug.augment_image(img) - assert img_aug.dtype == img.dtype - assert np.all(img_aug == img) - - aug = iaa.Superpixels(p_replace=1.0, n_segments=1) - img = np.array([ - [True, True, True, True], - [False, True, True, True] - ], dtype=bool) - img_aug = aug.augment_image(img) - assert img_aug.dtype == img.dtype - assert np.all(img_aug) + for use_np_replace in [True, False]: + with self.subTest(use_np_replace=use_np_replace): + with _create_replace_np_context(use_np_replace): + aug = iaa.Superpixels(p_replace=1.0, n_segments=2) + img = np.array([ + [False, False, True, True], + [False, False, True, True] + ], dtype=bool) + img_aug = aug.augment_image(img) + assert img_aug.dtype == img.dtype + assert np.all(img_aug == img) + + aug = iaa.Superpixels(p_replace=1.0, n_segments=1) + img = np.array([ + [True, True, True, True], + [False, True, True, True] + ], dtype=bool) + img_aug = aug.augment_image(img) + assert img_aug.dtype == img.dtype + assert np.all(img_aug) def test_other_dtypes_uint_int(self): - for dtype in [np.uint8, np.uint16, np.uint32, - np.int8, np.int16, np.int32]: - min_value, center_value, max_value = \ - iadt.get_value_range_of_dtype(dtype) - - if np.dtype(dtype).kind == "i": - values = [int(center_value), int(0.1 * max_value), - int(0.2 * max_value), int(0.5 * max_value), - max_value-100] - values = [((-1)*value, value) for value in values] - else: - values = [(0, int(center_value)), - (10, int(0.1 * max_value)), - (10, int(0.2 * max_value)), - (10, int(0.5 * max_value)), - (0, max_value), - (int(center_value), - max_value)] - - for v1, v2 in values: - aug = iaa.Superpixels(p_replace=1.0, n_segments=2) - img = np.array([ - [v1, v1, v2, v2], - [v1, v1, v2, v2] - ], dtype=dtype) - img_aug = aug.augment_image(img) - assert img_aug.dtype == np.dtype(dtype) - assert np.array_equal(img_aug, img) - - aug = iaa.Superpixels(p_replace=1.0, n_segments=1) - img = np.array([ - [v2, v2, v2, v2], - [v1, v2, v2, v2] - ], dtype=dtype) - img_aug = aug.augment_image(img) - assert img_aug.dtype == np.dtype(dtype) - assert np.all(img_aug == int(np.round((7/8)*v2 + (1/8)*v1))) + dtypes = ["uint8", "uint16", "uint32", + "int8", "int16", "int32"] + for dtype in dtypes: + for use_np_replace in [True, False]: + with self.subTest(dtype=dtype, use_np_replace=use_np_replace): + with _create_replace_np_context(use_np_replace): + dtype = np.dtype(dtype) + + min_value, center_value, max_value = \ + iadt.get_value_range_of_dtype(dtype) + + if np.dtype(dtype).kind == "i": + values = [ + int(center_value), int(0.1 * max_value), + int(0.2 * max_value), int(0.5 * max_value), + max_value-100 + ] + values = [((-1)*value, value) for value in values] + else: + values = [(0, int(center_value)), + (10, int(0.1 * max_value)), + (10, int(0.2 * max_value)), + (10, int(0.5 * max_value)), + (0, max_value), + (int(center_value), + max_value)] + + for v1, v2 in values: + aug = iaa.Superpixels(p_replace=1.0, n_segments=2) + img = np.array([ + [v1, v1, v2, v2], + [v1, v1, v2, v2] + ], dtype=dtype) + img_aug = aug.augment_image(img) + assert img_aug.dtype.name == dtype.name + assert np.array_equal(img_aug, img) + + aug = iaa.Superpixels(p_replace=1.0, n_segments=1) + img = np.array([ + [v2, v2, v2, v2], + [v1, v2, v2, v2] + ], dtype=dtype) + img_aug = aug.augment_image(img) + assert img_aug.dtype.name == dtype.name + assert np.all(img_aug == int((7/8)*v2 + (1/8)*v1)) def test_other_dtypes_float(self): # currently, no float dtype is actually accepted