Skip to content

refactor data augmentation pipeline #105

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Mar 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 17 additions & 20 deletions configs/det/db_r50_icdar15.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ loss_scaler:

train:
ckpt_save_dir: './tmp_det'
dataset_sink_mode: False
dataset_sink_mode: True
dataset:
type: DetDataset
dataset_root: /Users/Samit/Data/datasets
Expand All @@ -75,15 +75,14 @@ train:
img_mode: RGB
to_float32: False
- DetLabelEncode:
- MZRandomScaleByShortSide:
short_side: 736
- RandomScale:
scale_range: [ 1.022, 3.0 ]
- IaaAugment:
augmenter_args:
- { 'type': 'Affine', 'args': { 'rotate': [ -10, 10 ] } }
- { 'type': 'Fliplr', 'args': { 'p': 0.5 } }
- MZRandomCropData:
max_tries: 100
min_crop_side_ratio: 0.1
Affine: { rotate: [ -10, 10 ] }
Fliplr: { p: 0.5 }
- RandomCropWithBBox:
max_tries: 10
min_crop_ratio: 0.1
crop_size: [ 640, 640 ]
- ShrinkBinaryMap:
min_text_size: 8
Expand All @@ -92,17 +91,16 @@ train:
shrink_ratio: 0.4
thresh_min: 0.3
thresh_max: 0.7
- MZRandomColorAdjust:
brightness: 0.1255 #32.0 / 255
- RandomColorAdjust:
brightness: 0.1255 # 32.0 / 255
saturation: 0.5
to_numpy: True
- NormalizeImage:
bgr_to_rgb: False
is_hwc: True
mean: imagenet
std: imagenet
- ToCHWImage:
# the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visaulize
# the order of the dataloader list, matching the network input and the input labels for the loss function, and optional data for debug/visualize
output_columns: [ 'image', 'binary_map', 'mask', 'thresh_map', 'thresh_mask' ] #'img_path']
# output_columns: ['image'] # for debug op performance
num_columns_to_net: 1 # num inputs for network forward func in output_columns
Expand All @@ -128,13 +126,12 @@ eval:
img_mode: RGB
to_float32: False
- DetLabelEncode:
- MZResizeByGrid:
divisor: 32
transform_polys: True
# MZResizeByGrid already sets the evaluation size to [ 736, 1280 ].
# Uncomment MZScalePad block for other resolutions.
# - MZScalePad:
# eval_size: [ 736, 1280 ] # h, w
- GridResize:
factor: 32
# GridResize already sets the evaluation size to [ 736, 1280 ].
# Uncomment ScalePadImage block for other resolutions.
# - ScalePadImage:
# target_size: [ 736, 1280 ] # h, w
- NormalizeImage:
bgr_to_rgb: False
is_hwc: True
Expand Down
180 changes: 177 additions & 3 deletions mindocr/data/transforms/general_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@
import cv2
import numpy as np
from PIL import Image
from mindspore.dataset.vision import RandomColorAdjust as MSRandomColorAdjust

from mindcv.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

__all__ = ['DecodeImage', 'NormalizeImage', 'ToCHWImage', 'PackLoaderInputs']
__all__ = ['DecodeImage', 'NormalizeImage', 'ToCHWImage', 'PackLoaderInputs', 'ScalePadImage', 'GridResize',
'RandomScale', 'RandomCropWithBBox', 'RandomColorAdjust']


# TODO: use mindspore C.decode for efficiency
class DecodeImage(object):
class DecodeImage:
"""
img_mode (str): The channel order of the output, 'BGR' and 'RGB'. Default to 'BGR'.
    channel_first (bool): if True, image shape is CHW. If False, HWC. Default to False
Expand Down Expand Up @@ -118,4 +120,176 @@ def __call__(self, data):
assert k in data, f'key {k} does not exists in data, availabe keys are {data.keys()}'
out.append(data[k])

return tuple(out)
return tuple(out)


class ScalePadImage:
    """
    Scale image and polys by the shorter side, then pad to the target_size.
    input image format: hwc

    Args:
        target_size: [H, W] of the output image.
    """
    def __init__(self, target_size: list):
        self._target_size = np.array(target_size)

    def __call__(self, data: dict):
        """
        required keys:
            image, HWC
            (polys)
        modified keys:
            image
            (polys)
        added keys:
            shape: [src_h, src_w, scale_ratio_h, scale_ratio_w]
        """
        size = np.array(data['image'].shape[:2])
        # single uniform ratio (the smaller of the two per-axis ratios) so the
        # aspect ratio is preserved and the scaled image fits inside target_size
        scale = min(self._target_size / size)
        # FIX: `np.int` was removed in NumPy 1.24 — use the builtin `int` instead
        new_size = np.round(scale * size).astype(int)

        # cv2.resize takes (w, h), hence the reversed order
        data['image'] = cv2.resize(data['image'], new_size[::-1])
        # pad bottom/right up to target_size; the channel axis is not padded
        data['image'] = np.pad(data['image'],
                               (*tuple((0, ts - ns) for ts, ns in zip(self._target_size, new_size)), (0, 0)))

        if 'polys' in data:
            data['polys'] *= scale

        # record the source size and the (uniform) scale used on both axes
        data['shape'] = np.concatenate((size, np.array([scale, scale])), dtype=np.float32)
        return data


class GridResize:
    """
    Resize an image so that each side becomes an exact multiple of a given factor.
    Polygons, when present, are rescaled to match.
    """
    def __init__(self, factor: int = 32):
        self._factor = factor

    def __call__(self, data: dict):
        """
        required keys:
            image, HWC
            (polys)
        modified keys:
            image
            (polys)
        """
        current = np.array(data['image'].shape[:2])
        # per-axis ratios that round each side UP to the nearest multiple of factor
        ratios = np.ceil(current / self._factor) * self._factor / current
        fy, fx = ratios
        data['image'] = cv2.resize(data['image'], None, fx=fx, fy=fy)

        if 'polys' in data:
            # polygon points are stored as (x, y), i.e. (w, h) order
            data['polys'] *= ratios[::-1]
        return data


class RandomScale:
    """
    Rescale an image (and its polygons) by a factor drawn uniformly at random.

    Args:
        scale_range: (min, max) scale range.
    """
    def __init__(self, scale_range: Union[tuple, list]):
        self._range = scale_range

    def __call__(self, data: dict):
        """
        required keys:
            image, HWC
            (polys)
        modified keys:
            image
            (polys)
        """
        factor = np.random.uniform(self._range[0], self._range[1])
        data['image'] = cv2.resize(data['image'], dsize=None, fx=factor, fy=factor)

        if 'polys' in data:
            # polygon coordinates scale uniformly with the image
            data['polys'] *= factor
        return data


class RandomCropWithBBox:
    """
    Randomly cuts a crop from an image along with polygons.

    Args:
        max_tries: number of attempts to try to cut a crop with a polygon in it.
        min_crop_ratio: minimum size of a crop in respect to an input image size.
        crop_size: target size of the crop (resized and padded, if needed), preserves sides ratio.
    """
    def __init__(self, max_tries=10, min_crop_ratio=0.1, crop_size=(640, 640)):
        self._crop_size = crop_size
        self._ratio = min_crop_ratio
        self._max_tries = max_tries

    def __call__(self, data):
        # crop window in (y, x) order; falls back to the full image when no
        # valid crop is found (see _find_crop)
        start, end = self._find_crop(data)
        # uniform scale: the smaller per-axis ratio keeps the crop's aspect ratio
        # while fitting it inside crop_size (tuple / ndarray broadcasts via numpy)
        scale = min(self._crop_size / (end - start))

        data['image'] = cv2.resize(data['image'][start[0]: end[0], start[1]: end[1]], None, fx=scale, fy=scale)
        # pad bottom/right up to crop_size; channel axis untouched
        data['image'] = np.pad(data['image'],
                               (*tuple((0, cs - ds) for cs, ds in zip(self._crop_size, data['image'].shape[:2])), (0, 0)))

        start, end = start[::-1], end[::-1]  # convert to x, y coord
        # keep only annotations that intersect the crop, shifted/scaled into crop space
        new_polys, new_texts, new_ignores = [], [], []
        for _id in range(len(data['polys'])):
            # if the polygon is within the crop
            if (data['polys'][_id].max(axis=0) > start).all() and (data['polys'][_id].min(axis=0) < end).all():  # NOQA
                new_polys.append((data['polys'][_id] - start) * scale)
                new_texts.append(data['texts'][_id])
                new_ignores.append(data['ignore_tags'][_id])

        # preserve the container type of 'polys' (ndarray in, ndarray out)
        data['polys'] = np.array(new_polys) if isinstance(data['polys'], np.ndarray) else new_polys
        data['texts'] = new_texts
        data['ignore_tags'] = new_ignores

        return data

    def _find_crop(self, data):
        """Pick a crop window that avoids cutting through any non-ignored polygon.

        Returns (start, end) corners in (y, x) order; returns the full image
        when no polygons exist or no valid window is found within max_tries.
        """
        size = np.array(data['image'].shape[:2])
        # ignored polygons impose no constraints on the crop
        polys = [poly for poly, ignore in zip(data['polys'], data['ignore_tags']) if not ignore]

        if polys:
            # do not crop through polys => find available coordinates
            # h_array/w_array mark rows/columns covered by some polygon's bbox
            h_array, w_array = np.zeros(size[0], dtype=np.int32), np.zeros(size[1], dtype=np.int32)
            for poly in polys:
                points = np.maximum(np.round(poly).astype(np.int32), 0)
                w_array[points[:, 0].min(): points[:, 0].max() + 1] = 1
                h_array[points[:, 1].min(): points[:, 1].max() + 1] = 1
            # find available coordinates that don't include text
            # NOTE(review): if text spans the full height or width these arrays are
            # empty and np.random.choice below would raise — presumably the datasets
            # used always leave free rows/columns; verify against callers
            h_avail = np.where(h_array == 0)[0]
            w_avail = np.where(w_array == 0)[0]

            min_size = np.ceil(size * self._ratio).astype(np.int32)
            for _ in range(self._max_tries):
                # sample two free rows and two free columns as window corners
                y = np.sort(np.random.choice(h_avail, size=2))
                x = np.sort(np.random.choice(w_avail, size=2))
                start, end = np.array([y[0], x[0]]), np.array([y[1], x[1]])

                # reject windows smaller than min_crop_ratio of the image
                if ((end - start) < min_size).any():  # NOQA
                    continue

                # check that at least one polygon is within the crop
                for poly in polys:
                    if (poly.max(axis=0) > start[::-1]).all() and (poly.min(axis=0) < end[::-1]).all():  # NOQA
                        return start, end

        # failed to generate a crop or all polys are marked as ignored
        return np.array([0, 0]), size


class RandomColorAdjust:
    """Randomly jitter image brightness and saturation (delegates to MindSpore's
    RandomColorAdjust vision op)."""

    def __init__(self, brightness=32.0 / 255, saturation=0.5):
        self._jitter = MSRandomColorAdjust(brightness=brightness, saturation=saturation)

    def __call__(self, data):
        """
        required keys: image
        modified keys: image
        """
        image = data['image']
        data['image'] = self._jitter(image)
        return data
82 changes: 12 additions & 70 deletions mindocr/data/transforms/iaa_augment.py
Original file line number Diff line number Diff line change
@@ -1,83 +1,25 @@
"""
iaa transform
"""
import numpy as np
import imgaug
import imgaug.augmenters as iaa

__all__ = ['IaaAugment']

def compose_iaa_transforms(args, root=True):
    """Recursively build an imgaug augmenter from a nested list/dict spec.

    A top-level list composes each entry and chains them in an ``iaa.Sequential``;
    a nested list is interpreted as ``[augmenter_name, *positional_args]``;
    a dict must provide ``type`` and ``args`` keys. An empty or None spec
    composes to ``None``.
    """
    if (args is None) or (len(args) == 0):
        return None

    if isinstance(args, list):
        if not root:
            # nested list: first element names the augmenter, the rest are its args
            positional = [_list_to_tuple(item) for item in args[1:]]
            return getattr(iaa, args[0])(*positional)
        # top level: compose every child spec and run them sequentially
        return iaa.Sequential([compose_iaa_transforms(child, root=False) for child in args])

    if isinstance(args, dict):
        keyword = {key: _list_to_tuple(value) for key, value in args['args'].items()}
        return getattr(iaa, args['type'])(**keyword)

    raise ValueError('Unknown augmenter arg: ' + str(args))


def _list_to_tuple(obj):
if isinstance(obj, list):
return tuple(obj)
return obj


class IaaAugment():
def __init__(self, augmenter_args=None, **kwargs):
if augmenter_args is None:
augmenter_args = [{
'type': 'Fliplr',
'args': {
'p': 0.5
}
}, {
'type': 'Affine',
'args': {
'rotate': [-10, 10]
}
}, {
'type': 'Resize',
'args': {
'size': [0.5, 3]
}
}]
self.augmenter = compose_iaa_transforms(augmenter_args)
class IaaAugment:
def __init__(self, **augments):
self._augmenter = iaa.Sequential([getattr(iaa, aug)(**args) for aug, args in augments.items()])

def __call__(self, data):
image = data['image']
shape = image.shape
aug = self._augmenter.to_deterministic() # to augment an image and its keypoints identically
data['image'] = aug.augment_image(data['image'])

if self.augmenter:
aug = self.augmenter.to_deterministic()
data['image'] = aug.augment_image(image)
data = self.may_augment_annotation(aug, data, shape)
return data
if 'polys' in data:
new_polys = []
for poly in data['polys']:
kps = imgaug.KeypointsOnImage([imgaug.Keypoint(p[0], p[1]) for p in poly], shape=data['image'].shape)
kps = aug.augment_keypoints(kps)
new_polys.append(np.array([[kp.x, kp.y] for kp in kps.keypoints]))

def may_augment_annotation(self, aug, data, shape):
if aug is None:
return data
data['polys'] = np.array(new_polys) if isinstance(data['polys'], np.ndarray) else new_polys

line_polys = []
for poly in data['polys']:
new_poly = self.may_augment_poly(aug, shape, poly)
line_polys.append(new_poly)
data['polys'] = np.array(line_polys)
return data

def may_augment_poly(self, aug, img_shape, poly):
keypoints = [imgaug.Keypoint(p[0], p[1]) for p in poly]
keypoints = aug.augment_keypoints(
[imgaug.KeypointsOnImage(
keypoints, shape=img_shape)])[0].keypoints
poly = [(p.x, p.y) for p in keypoints]
return poly
Loading