Skip to content

Commit

Permalink
add mosaic augment from YOLOv5
Browse files Browse the repository at this point in the history
  • Loading branch information
david8862 committed May 30, 2021
1 parent 4f5b7ea commit 42622f7
Show file tree
Hide file tree
Showing 6 changed files with 189 additions and 11 deletions.
157 changes: 156 additions & 1 deletion common/data_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def random_hsv_distort(image, hue=.1, sat=1.5, val=1.5):
usually for training data preprocess
# Arguments
image: origin image to be resize
image: origin image for HSV distort
PIL Image object containing image data
hue: distort range for Hue
scalar
Expand Down Expand Up @@ -761,6 +761,161 @@ def get_mosaic_samples():
return new_images, new_boxes


def _mosaic_tile_regions(tile_index, x_center, y_center, width, height, canvas_width, canvas_height):
    """
    Compute paste/crop rectangles for one of the 4 mosaic tiles

    # Arguments
        tile_index: tile position in the mosaic,
            0: top left, 1: top right, 2: bottom left, 3: bottom right
        x_center, y_center: mosaic center point on the canvas
        width, height: size of the source tile image
        canvas_width, canvas_height: size of the large mosaic canvas

    # Returns
        target: (xmin, ymin, xmax, ymax) area on the mosaic canvas
        source: (xmin, ymin, xmax, ymax) area cropped from the tile image,
            always the same size as the target area
    """
    if tile_index == 0:  # top left: keep the bottom-right part of the tile
        target = (max(x_center - width, 0), max(y_center - height, 0), x_center, y_center)
        source = (width - (target[2] - target[0]), height - (target[3] - target[1]), width, height)
    elif tile_index == 1:  # top right: keep the bottom-left part of the tile
        target = (x_center, max(y_center - height, 0), min(x_center + width, canvas_width), y_center)
        source = (0, height - (target[3] - target[1]), min(width, target[2] - target[0]), height)
    elif tile_index == 2:  # bottom left: keep the top-right part of the tile
        target = (max(x_center - width, 0), y_center, x_center, min(canvas_height, y_center + height))
        source = (width - (target[2] - target[0]), 0, width, min(target[3] - target[1], height))
    elif tile_index == 3:  # bottom right: keep the top-left part of the tile
        target = (x_center, y_center, min(x_center + width, canvas_width), min(canvas_height, y_center + height))
        source = (0, 0, min(width, target[2] - target[0]), min(target[3] - target[1], height))
    else:
        raise ValueError('mosaic tile index should be in range [0, 3]')

    return target, source


def random_mosaic_augment_v5(image_data, boxes_data, prob=.1):
    """
    Random mosaic augment from YOLOv5 implementation

    Every output image is built from 4 images randomly picked from the
    batch: they are pasted around a random center point on a 2x size
    canvas, then the canvas (and its boxes) is scaled back to input size.

    reference:
        https://github.com/ultralytics/yolov5/blob/develop/utils/datasets.py

    # Arguments
        image_data: origin images for mosaic augment
            numpy array for normalized batch image data
        boxes_data: origin bboxes for mosaic augment
            numpy array for batch bboxes, each box is
            (xmin, ymin, xmax, ymax, class) and unused rows are all zero
        prob: probability for augment,
            scalar to control the augment probability.

    # Returns
        image_data: augmented batch image data.
        boxes_data: augmented batch bboxes data.
    """
    do_augment = rand() < prob
    if not do_augment:
        return image_data, boxes_data

    batch_size = len(image_data)
    assert batch_size >= 4, 'mosaic augment need batch size >= 4'

    def get_mosaic_samples():
        # random select 4 images from batch as mosaic samples
        random_index = random.sample(range(batch_size), 4)

        random_images = [image_data[idx] for idx in random_index]
        random_bboxes = [boxes_data[idx] for idx in random_index]
        return random_images, np.array(random_bboxes)

    input_height, input_width, input_channel = image_data[0].shape[:3]
    canvas_height, canvas_width = input_height * 2, input_width * 2
    mosaic_border = (input_width // 2, input_height // 2)

    # keep the dtype of the input images, so that float (normalized) images
    # are not truncated when they are pasted onto the canvas
    image_dtype = np.asarray(image_data[0]).dtype
    if np.issubdtype(image_dtype, np.floating):
        # NOTE(review): assumes float images are normalized to [0, 1] as the
        # docstring states, so gray is 128/255 -- confirm against the caller
        gray_value = 128 / 255
    else:
        gray_value = 128

    new_images = []
    new_boxes = []

    # each batch has batch_size images, so we also need to
    # generate batch_size mosaic images
    for _ in range(batch_size):
        images, bboxes = get_mosaic_samples()
        max_boxes = bboxes.shape[1]

        # mosaic center x, y
        x_center = int(random.uniform(mosaic_border[0], canvas_width - mosaic_border[0]))
        y_center = int(random.uniform(mosaic_border[1], canvas_height - mosaic_border[1]))

        # create large mosaic image with size (input_height*2, input_width*2)
        mosaic_image = np.full((canvas_height, canvas_width, input_channel), gray_value, dtype=image_dtype)
        mosaic_bbox = []

        for i in range(4):
            image = images[i]
            bbox = bboxes[i]
            height, width = image.shape[:2]

            # calculate padding area in each src & target image
            target, source = _mosaic_tile_regions(i, x_center, y_center, width, height, canvas_width, canvas_height)
            xmin_target, ymin_target, xmax_target, ymax_target = target
            xmin_src, ymin_src, xmax_src, ymax_src = source

            # padding src image to corresponding mosaic area
            mosaic_image[ymin_target:ymax_target, xmin_target:xmax_target] = image[ymin_src:ymax_src, xmin_src:xmax_src]

            # padding width & height for bbox
            padding_width = xmin_target - xmin_src
            padding_height = ymin_target - ymin_src

            # adjust bbox to new mosaic image, with padding width & height
            for box in bbox:
                # break loop when reach invalid box line
                if (box[:4] == 0).all():
                    break
                mosaic_bbox.append([box[0] + padding_width,
                                    box[1] + padding_height,
                                    box[2] + padding_width,
                                    box[3] + padding_height,
                                    box[4]])

        valid_boxes = np.array(mosaic_bbox, dtype=np.float64).reshape(-1, 5)

        # clip boxes to valid image area (columns 0,2 are x and 1,3 are y)
        valid_boxes[:, 0:4:2] = np.clip(valid_boxes[:, 0:4:2], 0, canvas_width - 1)
        valid_boxes[:, 1:4:2] = np.clip(valid_boxes[:, 1:4:2], 0, canvas_height - 1)

        # resize image & box back to input shape
        # NOTE: interpolation has to be passed by keyword, the 3rd
        # positional argument of cv2.resize() is "dst"
        mosaic_image = cv2.resize(mosaic_image, (input_width, input_height), interpolation=cv2.INTER_AREA)
        valid_boxes[:, :4] //= 2

        # drop boxes which collapsed to zero width/height (moved out of the
        # visible area), otherwise they would waste max_boxes slots and an
        # all-zero row could be taken as the end-of-boxes mark
        keep = (valid_boxes[:, 2] > valid_boxes[:, 0]) & (valid_boxes[:, 3] > valid_boxes[:, 1])
        valid_boxes = valid_boxes[keep][:max_boxes]

        box_data = np.zeros((max_boxes, 5))
        box_data[:len(valid_boxes)] = valid_boxes

        new_images.append(mosaic_image)
        new_boxes.append(box_data)

    new_images = np.stack(new_images)
    new_boxes = np.array(new_boxes)
    return new_images, new_boxes


def merge_cutmix_bboxes(bboxes, cut_xmin, cut_ymin, cut_xmax, cut_ymax, image_size):
# adjust & merge cutmix samples bboxes as following area order:
# -----------------
Expand Down
15 changes: 10 additions & 5 deletions tools/misc/augment_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
test enhace data argument functions (mosaic/cutmix)
test enhance data augment functions (mosaic/mosaic_v5/cutmix)
"""
import os, sys, argparse
import numpy as np
Expand All @@ -11,7 +11,7 @@
sys.path.append(os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', '..'))
from yolo3.data import get_ground_truth_data
from common.utils import get_dataset, get_classes, draw_label
from common.data_utils import random_mosaic_augment, random_cutmix_augment
from common.data_utils import random_mosaic_augment, random_mosaic_augment_v5, random_cutmix_augment


def draw_boxes(images, boxes, class_names, output_path):
Expand Down Expand Up @@ -44,7 +44,7 @@ def main():
parser.add_argument('--output_path', type=str, required=False, help='output path for augmented images, default=%(default)s', default='./test')
parser.add_argument('--batch_size', type=int, required=False, help = "batch size for test data, default=%(default)s", default=16)
parser.add_argument('--model_image_size', type=str, required=False, help='model image input size as <height>x<width>, default=%(default)s', default='416x416')
parser.add_argument('--augment_type', type=str, required=False, help = "enhance data augmentation type (mosaic/cutmix), default=%(default)s", default='mosaic', choices=['mosaic', 'cutmix'])
parser.add_argument('--enhance_augment', type=str, required=False, help = "enhance data augmentation type, default=%(default)s", default=None, choices=['mosaic', 'mosaic_v5', 'cutmix', None])

args = parser.parse_args()
class_names = get_classes(args.classes_path)
Expand All @@ -69,14 +69,19 @@ def main():
image_data = np.array(image_data)
boxes_data = np.array(boxes_data)

if args.augment_type == 'mosaic':
if args.enhance_augment == 'mosaic':
image_data, boxes_data = random_mosaic_augment(image_data, boxes_data, prob=1)
elif args.augment_type == 'cutmix':
elif args.enhance_augment == 'mosaic_v5':
image_data, boxes_data = random_mosaic_augment_v5(image_data, boxes_data, prob=1)
elif args.enhance_augment == 'cutmix':
image_data, boxes_data = random_cutmix_augment(image_data, boxes_data, prob=1)
elif args.enhance_augment == None:
print('No enhance augment type. Will only apply base augment')
else:
raise ValueError('Unsupported augment type')

draw_boxes(image_data, boxes_data, class_names, args.output_path)
print('Done. augment images have been saved in', args.output_path)


if __name__ == "__main__":
Expand Down
4 changes: 2 additions & 2 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,8 +296,8 @@ def main(args):
help='Whether to use multiscale training')
parser.add_argument('--rescale_interval', type=int, required=False, default=10,
help = "Number of iteration(batches) interval to rescale input size, default=%(default)s")
parser.add_argument('--enhance_augment', type=str, required=False, default=None, choices=[None, 'mosaic'],
help = "enhance data augmentation type (None/mosaic), default=%(default)s")
parser.add_argument('--enhance_augment', type=str, required=False, default=None, choices=[None, 'mosaic', 'mosaic_v5'],
help = "enhance data augmentation type (None/mosaic/mosaic_v5), default=%(default)s")
parser.add_argument('--label_smoothing', type=float, required=False, default=0,
help = "Label smoothing factor (between 0 and 1) for classification loss, default=%(default)s")
parser.add_argument('--multi_anchor_assign', default=False, action="store_true",
Expand Down
8 changes: 7 additions & 1 deletion yolo2/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import random, math
from PIL import Image
from tensorflow.keras.utils import Sequence
from common.data_utils import random_mosaic_augment
from common.data_utils import random_mosaic_augment, random_mosaic_augment_v5
from common.utils import get_multiscale_list
from yolo3.data import get_ground_truth_data

Expand Down Expand Up @@ -196,6 +196,9 @@ def __getitem__(self, index):
if self.enhance_augment == 'mosaic':
# add random mosaic augment on batch ground truth data
image_data, box_data = random_mosaic_augment(image_data, box_data, prob=0.2)
elif self.enhance_augment == 'mosaic_v5':
# mosaic augment from YOLOv5
image_data, box_data = random_mosaic_augment_v5(image_data, box_data, prob=0.2)

y_true_data = get_y_true_data(box_data, self.anchors, self.input_shape, self.num_classes, self.multi_anchor_assign)

Expand Down Expand Up @@ -237,6 +240,9 @@ def yolo2_data_generator(annotation_lines, batch_size, input_shape, anchors, num
if enhance_augment == 'mosaic':
# add random mosaic augment on batch ground truth data
image_data, box_data = random_mosaic_augment(image_data, box_data, prob=0.2)
elif enhance_augment == 'mosaic_v5':
# mosaic augment from YOLOv5
image_data, box_data = random_mosaic_augment_v5(image_data, box_data, prob=0.2)

y_true_data = get_y_true_data(box_data, anchors, input_shape, num_classes, multi_anchor_assign)

Expand Down
8 changes: 7 additions & 1 deletion yolo3/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import random, math
from PIL import Image
from tensorflow.keras.utils import Sequence
from common.data_utils import normalize_image, letterbox_resize, random_resize_crop_pad, reshape_boxes, random_hsv_distort, random_horizontal_flip, random_vertical_flip, random_grayscale, random_brightness, random_chroma, random_contrast, random_sharpness, random_blur, random_motion_blur, random_rotate, random_gridmask, random_mosaic_augment
from common.data_utils import normalize_image, letterbox_resize, random_resize_crop_pad, reshape_boxes, random_hsv_distort, random_horizontal_flip, random_vertical_flip, random_grayscale, random_brightness, random_chroma, random_contrast, random_sharpness, random_blur, random_motion_blur, random_rotate, random_gridmask, random_mosaic_augment, random_mosaic_augment_v5
from common.utils import get_multiscale_list


Expand Down Expand Up @@ -232,6 +232,9 @@ def __getitem__(self, index):
if self.enhance_augment == 'mosaic':
# add random mosaic augment on batch ground truth data
image_data, box_data = random_mosaic_augment(image_data, box_data, prob=0.2)
elif self.enhance_augment == 'mosaic_v5':
# mosaic augment from YOLOv5
image_data, box_data = random_mosaic_augment_v5(image_data, box_data, prob=0.2)

y_true = preprocess_true_boxes(box_data, self.input_shape, self.anchors, self.num_classes, self.multi_anchor_assign)

Expand Down Expand Up @@ -273,6 +276,9 @@ def yolo3_data_generator(annotation_lines, batch_size, input_shape, anchors, num
if enhance_augment == 'mosaic':
# add random mosaic augment on batch ground truth data
image_data, box_data = random_mosaic_augment(image_data, box_data, prob=0.2)
elif enhance_augment == 'mosaic_v5':
# mosaic augment from YOLOv5
image_data, box_data = random_mosaic_augment_v5(image_data, box_data, prob=0.2)

y_true = preprocess_true_boxes(box_data, input_shape, anchors, num_classes, multi_anchor_assign)
yield [image_data, *y_true], np.zeros(batch_size)
Expand Down
8 changes: 7 additions & 1 deletion yolo5/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import random, math
from PIL import Image
from tensorflow.keras.utils import Sequence
from common.data_utils import random_mosaic_augment
from common.data_utils import random_mosaic_augment, random_mosaic_augment_v5
from common.utils import get_multiscale_list
from yolo3.data import get_ground_truth_data

Expand Down Expand Up @@ -159,6 +159,9 @@ def __getitem__(self, index):
if self.enhance_augment == 'mosaic':
# add random mosaic augment on batch ground truth data
image_data, box_data = random_mosaic_augment(image_data, box_data, prob=0.2)
elif self.enhance_augment == 'mosaic_v5':
# mosaic augment from YOLOv5
image_data, box_data = random_mosaic_augment_v5(image_data, box_data, prob=0.2)

y_true = preprocess_true_boxes(box_data, self.input_shape, self.anchors, self.num_classes, self.multi_anchor_assign)

Expand Down Expand Up @@ -200,6 +203,9 @@ def yolo5_data_generator(annotation_lines, batch_size, input_shape, anchors, num
if enhance_augment == 'mosaic':
# add random mosaic augment on batch ground truth data
image_data, box_data = random_mosaic_augment(image_data, box_data, prob=0.2)
elif enhance_augment == 'mosaic_v5':
# mosaic augment from YOLOv5
image_data, box_data = random_mosaic_augment_v5(image_data, box_data, prob=0.2)

y_true = preprocess_true_boxes(box_data, input_shape, anchors, num_classes, multi_anchor_assign)
yield [image_data, *y_true], np.zeros(batch_size)
Expand Down

0 comments on commit 42622f7

Please sign in to comment.