Skip to content

Commit

Permalink
[Feature] Add Mosaic transform (open-mmlab#1093)
Browse files Browse the repository at this point in the history
* Fix typo in usage example

* original mosaic code in mmdet

* Adjust mosaic to the semantic segmentation

* Remove bbox test in test_mosaic

* Add unittests

* Fix resize mode for seg_fields

* Fix repr error

* modify Mosaic docs

* modify from Mosaic to RandomMosaic

* Add docstring

* modify Mosaic docstring

* [Docs] Add a blank line before Returns:

* add blank lines

Co-authored-by: MeowZheng <meowzheng@outlook.com>
  • Loading branch information
lkm2835 and MeowZheng authored Jan 11, 2022
1 parent 44a9635 commit f0262fa
Show file tree
Hide file tree
Showing 3 changed files with 322 additions and 3 deletions.
7 changes: 4 additions & 3 deletions mmseg/datasets/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,14 @@
from .test_time_aug import MultiScaleFlipAug
from .transforms import (CLAHE, AdjustGamma, Normalize, Pad,
PhotoMetricDistortion, RandomCrop, RandomCutOut,
RandomFlip, RandomRotate, Rerange, Resize, RGB2Gray,
SegRescale)
RandomFlip, RandomMosaic, RandomRotate, Rerange,
Resize, RGB2Gray, SegRescale)

__all__ = [
'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
'Transpose', 'Collect', 'LoadAnnotations', 'LoadImageFromFile',
'MultiScaleFlipAug', 'Resize', 'RandomFlip', 'Pad', 'RandomCrop',
'Normalize', 'SegRescale', 'PhotoMetricDistortion', 'RandomRotate',
'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', 'RandomCutOut'
'AdjustGamma', 'CLAHE', 'Rerange', 'RGB2Gray', 'RandomCutOut',
'RandomMosaic'
]
269 changes: 269 additions & 0 deletions mmseg/datasets/pipelines/transforms.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy

import mmcv
import numpy as np
from mmcv.utils import deprecated_api_warning, is_tuple_of
Expand Down Expand Up @@ -1040,3 +1042,270 @@ def __repr__(self):
repr_str += f'fill_in={self.fill_in}, '
repr_str += f'seg_fill_in={self.seg_fill_in})'
return repr_str


@PIPELINES.register_module()
class RandomMosaic(object):
"""Mosaic augmentation. Given 4 images, mosaic transform combines them into
one output image. The output image is composed of the parts from each sub-
image.
.. code:: text
mosaic transform
center_x
+------------------------------+
| pad | pad |
| +-----------+ |
| | | |
| | image1 |--------+ |
| | | | |
| | | image2 | |
center_y |----+-------------+-----------|
| | cropped | |
|pad | image3 | image4 |
| | | |
+----|-------------+-----------+
| |
+-------------+
The mosaic transform steps are as follows:
1. Choose the mosaic center as the intersections of 4 images
2. Get the left top image according to the index, and randomly
sample another 3 images from the custom dataset.
3. Sub image will be cropped if image is larger than mosaic patch
Args:
prob (float): mosaic probability.
img_scale (Sequence[int]): Image size after mosaic pipeline of
a single image. The size of the output image is four times
that of a single image. The output image comprises 4 single images.
Default: (640, 640).
center_ratio_range (Sequence[float]): Center ratio range of mosaic
output. Default: (0.5, 1.5).
pad_val (int): Pad value. Default: 0.
seg_pad_val (int): Pad value of segmentation map. Default: 255.
"""

def __init__(self,
prob,
img_scale=(640, 640),
center_ratio_range=(0.5, 1.5),
pad_val=0,
seg_pad_val=255):
assert 0 <= prob and prob <= 1
assert isinstance(img_scale, tuple)
self.prob = prob
self.img_scale = img_scale
self.center_ratio_range = center_ratio_range
self.pad_val = pad_val
self.seg_pad_val = seg_pad_val

def __call__(self, results):
"""Call function to make a mosaic of image.
Args:
results (dict): Result dict.
Returns:
dict: Result dict with mosaic transformed.
"""
mosaic = True if np.random.rand() < self.prob else False
if mosaic:
results = self._mosaic_transform_img(results)
results = self._mosaic_transform_seg(results)
return results

def get_indexes(self, dataset):
"""Call function to collect indexes.
Args:
dataset (:obj:`MultiImageMixDataset`): The dataset.
Returns:
list: indexes.
"""

indexes = [random.randint(0, len(dataset)) for _ in range(3)]
return indexes

def _mosaic_transform_img(self, results):
"""Mosaic transform function.
Args:
results (dict): Result dict.
Returns:
dict: Updated result dict.
"""

assert 'mix_results' in results
if len(results['img'].shape) == 3:
mosaic_img = np.full(
(int(self.img_scale[0] * 2), int(self.img_scale[1] * 2), 3),
self.pad_val,
dtype=results['img'].dtype)
else:
mosaic_img = np.full(
(int(self.img_scale[0] * 2), int(self.img_scale[1] * 2)),
self.pad_val,
dtype=results['img'].dtype)

# mosaic center x, y
self.center_x = int(
random.uniform(*self.center_ratio_range) * self.img_scale[1])
self.center_y = int(
random.uniform(*self.center_ratio_range) * self.img_scale[0])
center_position = (self.center_x, self.center_y)

loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right')
for i, loc in enumerate(loc_strs):
if loc == 'top_left':
result_patch = copy.deepcopy(results)
else:
result_patch = copy.deepcopy(results['mix_results'][i - 1])

img_i = result_patch['img']
h_i, w_i = img_i.shape[:2]
# keep_ratio resize
scale_ratio_i = min(self.img_scale[0] / h_i,
self.img_scale[1] / w_i)
img_i = mmcv.imresize(
img_i, (int(w_i * scale_ratio_i), int(h_i * scale_ratio_i)))

# compute the combine parameters
paste_coord, crop_coord = self._mosaic_combine(
loc, center_position, img_i.shape[:2][::-1])
x1_p, y1_p, x2_p, y2_p = paste_coord
x1_c, y1_c, x2_c, y2_c = crop_coord

# crop and paste image
mosaic_img[y1_p:y2_p, x1_p:x2_p] = img_i[y1_c:y2_c, x1_c:x2_c]

results['img'] = mosaic_img
results['img_shape'] = mosaic_img.shape
results['ori_shape'] = mosaic_img.shape

return results

def _mosaic_transform_seg(self, results):
"""Mosaic transform function for label annotations.
Args:
results (dict): Result dict.
Returns:
dict: Updated result dict.
"""

assert 'mix_results' in results
for key in results.get('seg_fields', []):
mosaic_seg = np.full(
(int(self.img_scale[0] * 2), int(self.img_scale[1] * 2)),
self.seg_pad_val,
dtype=results[key].dtype)

# mosaic center x, y
center_position = (self.center_x, self.center_y)

loc_strs = ('top_left', 'top_right', 'bottom_left', 'bottom_right')
for i, loc in enumerate(loc_strs):
if loc == 'top_left':
result_patch = copy.deepcopy(results)
else:
result_patch = copy.deepcopy(results['mix_results'][i - 1])

gt_seg_i = result_patch[key]
h_i, w_i = gt_seg_i.shape[:2]
# keep_ratio resize
scale_ratio_i = min(self.img_scale[0] / h_i,
self.img_scale[1] / w_i)
gt_seg_i = mmcv.imresize(
gt_seg_i,
(int(w_i * scale_ratio_i), int(h_i * scale_ratio_i)),
interpolation='nearest')

# compute the combine parameters
paste_coord, crop_coord = self._mosaic_combine(
loc, center_position, gt_seg_i.shape[:2][::-1])
x1_p, y1_p, x2_p, y2_p = paste_coord
x1_c, y1_c, x2_c, y2_c = crop_coord

# crop and paste image
mosaic_seg[y1_p:y2_p, x1_p:x2_p] = gt_seg_i[y1_c:y2_c,
x1_c:x2_c]

results[key] = mosaic_seg

return results

def _mosaic_combine(self, loc, center_position_xy, img_shape_wh):
"""Calculate global coordinate of mosaic image and local coordinate of
cropped sub-image.
Args:
loc (str): Index for the sub-image, loc in ('top_left',
'top_right', 'bottom_left', 'bottom_right').
center_position_xy (Sequence[float]): Mixing center for 4 images,
(x, y).
img_shape_wh (Sequence[int]): Width and height of sub-image
Returns:
tuple[tuple[float]]: Corresponding coordinate of pasting and
cropping
- paste_coord (tuple): paste corner coordinate in mosaic image.
- crop_coord (tuple): crop corner coordinate in mosaic image.
"""

assert loc in ('top_left', 'top_right', 'bottom_left', 'bottom_right')
if loc == 'top_left':
# index0 to top left part of image
x1, y1, x2, y2 = max(center_position_xy[0] - img_shape_wh[0], 0), \
max(center_position_xy[1] - img_shape_wh[1], 0), \
center_position_xy[0], \
center_position_xy[1]
crop_coord = img_shape_wh[0] - (x2 - x1), img_shape_wh[1] - (
y2 - y1), img_shape_wh[0], img_shape_wh[1]

elif loc == 'top_right':
# index1 to top right part of image
x1, y1, x2, y2 = center_position_xy[0], \
max(center_position_xy[1] - img_shape_wh[1], 0), \
min(center_position_xy[0] + img_shape_wh[0],
self.img_scale[1] * 2), \
center_position_xy[1]
crop_coord = 0, img_shape_wh[1] - (y2 - y1), min(
img_shape_wh[0], x2 - x1), img_shape_wh[1]

elif loc == 'bottom_left':
# index2 to bottom left part of image
x1, y1, x2, y2 = max(center_position_xy[0] - img_shape_wh[0], 0), \
center_position_xy[1], \
center_position_xy[0], \
min(self.img_scale[0] * 2, center_position_xy[1] +
img_shape_wh[1])
crop_coord = img_shape_wh[0] - (x2 - x1), 0, img_shape_wh[0], min(
y2 - y1, img_shape_wh[1])

else:
# index3 to bottom right part of image
x1, y1, x2, y2 = center_position_xy[0], \
center_position_xy[1], \
min(center_position_xy[0] + img_shape_wh[0],
self.img_scale[1] * 2), \
min(self.img_scale[0] * 2, center_position_xy[1] +
img_shape_wh[1])
crop_coord = 0, 0, min(img_shape_wh[0],
x2 - x1), min(y2 - y1, img_shape_wh[1])

paste_coord = x1, y1, x2, y2
return paste_coord, crop_coord

def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(prob={self.prob}, '
repr_str += f'img_scale={self.img_scale}, '
repr_str += f'center_ratio_range={self.center_ratio_range}, '
repr_str += f'pad_val={self.pad_val}, '
repr_str += f'seg_pad_val={self.pad_val})'
return repr_str
49 changes: 49 additions & 0 deletions tests/test_data/test_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,3 +614,52 @@ def test_cutout():
cutout_result = cutout_module(copy.deepcopy(results))
assert cutout_result['img'].sum() > img.sum()
assert cutout_result['gt_semantic_seg'].sum() > seg.sum()


def test_mosaic():
# test prob
with pytest.raises(AssertionError):
transform = dict(type='RandomMosaic', prob=1.5)
build_from_cfg(transform, PIPELINES)
# test assertion for invalid img_scale
with pytest.raises(AssertionError):
transform = dict(type='RandomMosaic', prob=1, img_scale=640)
build_from_cfg(transform, PIPELINES)

results = dict()
img = mmcv.imread(
osp.join(osp.dirname(__file__), '../data/color.jpg'), 'color')
seg = np.array(
Image.open(osp.join(osp.dirname(__file__), '../data/seg.png')))

results['img'] = img
results['gt_semantic_seg'] = seg
results['seg_fields'] = ['gt_semantic_seg']

transform = dict(type='RandomMosaic', prob=1, img_scale=(10, 12))
mosaic_module = build_from_cfg(transform, PIPELINES)
assert 'Mosaic' in repr(mosaic_module)

# test assertion for invalid mix_results
with pytest.raises(AssertionError):
mosaic_module(results)

results['mix_results'] = [copy.deepcopy(results)] * 3
results = mosaic_module(results)
assert results['img'].shape[:2] == (20, 24)

results = dict()
results['img'] = img[:, :, 0]
results['gt_semantic_seg'] = seg
results['seg_fields'] = ['gt_semantic_seg']

transform = dict(type='RandomMosaic', prob=0, img_scale=(10, 12))
mosaic_module = build_from_cfg(transform, PIPELINES)
results['mix_results'] = [copy.deepcopy(results)] * 3
results = mosaic_module(results)
assert results['img'].shape[:2] == img.shape[:2]

transform = dict(type='RandomMosaic', prob=1, img_scale=(10, 12))
mosaic_module = build_from_cfg(transform, PIPELINES)
results = mosaic_module(results)
assert results['img'].shape[:2] == (20, 24)

0 comments on commit f0262fa

Please sign in to comment.