Skip to content

Commit

Permalink
add temporal transform
Browse files Browse the repository at this point in the history
  • Loading branch information
HaydenFaulkner committed Dec 11, 2019
1 parent 9eca489 commit c821f4d
Showing 1 changed file with 163 additions and 8 deletions.
171 changes: 163 additions & 8 deletions models/definitions/yolo/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import numpy as np
import mxnet as mx
from mxnet import autograd
from gluoncv.data.transforms import bbox as tbbox
from gluoncv.data.transforms import image as timage
from gluoncv.data.transforms import experimental

from ...transforms import bbox as tbbox
from ...transforms import video as tvideo


class YOLO3DefaultTrainTransform(object):
"""Default YOLO training transform which includes tons of image augmentations.
Parameters
Expand Down Expand Up @@ -137,9 +138,8 @@ def __call__(self, src, label, idx=None):
return img, bbox.astype(img.dtype)


class YOLO3VideoTrainTransform(object):
class YOLO3VideoTrainTransformOld(object): # todo delete... new one allows both single and t label output
"""Video YOLO training transform which includes tons of image augmentations.
Parameters
----------
width : int
Expand All @@ -161,6 +161,7 @@ class YOLO3VideoTrainTransform(object):
box_norm : array-like of size 4, default is (0.1, 0.1, 0.2, 0.2)
Std value to be divided from encoded values.
"""

def __init__(self, k, width, height, net=None, mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225), mixup=False, **kwargs):
self._k = k
Expand Down Expand Up @@ -197,7 +198,7 @@ def __call__(self, src, label):
was_three = True

# random color jittering
img = tvideo.random_color_distort(img)
img = tvideo.random_color_distort(img)

# random expansion with prob 0.5
if np.random.uniform(0, 1) > 0.5:
Expand All @@ -210,12 +211,12 @@ def __call__(self, src, label):
k, h, w, c = img.shape
bbox, crop = experimental.bbox.random_crop_with_constraints(bbox, (w, h))
x0, y0, w, h = crop
img = img[:, y0:y0+h, x0:x0+w, :]
img = img[:, y0:y0 + h, x0:x0 + w, :]

# resize with random interpolation
k, h, w, c = img.shape
interp = np.random.randint(0, 5)
tmp = mx.nd.ones((k, self._height, self._width, c), ctx=img.context)
tmp = mx.nd.ones((k, self._height, self._width, c), ctx=img.context)
for i in range(k):
tmp[i] = timage.imresize(img[i], self._width, self._height, interp=interp)
img = tmp
Expand All @@ -234,7 +235,7 @@ def __call__(self, src, label):

if was_three: # remove the k dimension so backwards compat with single frame
img = mx.nd.squeeze(img)

if self._target_generator is None:
return img, bbox.astype(img.dtype)

Expand All @@ -252,6 +253,151 @@ def __call__(self, src, label):
class_targets[0], gt_bboxes[0])


class YOLO3VideoTrainTransform(object):
"""Video YOLO training transform which includes tons of image augmentations.
Parameters
----------
width : int
Image width.
height : int
Image height.
net : mxnet.gluon.HybridBlock, optional
The yolo network.
.. hint::
If net is ``None``, the transformation will not generate training targets.
Otherwise it will generate training targets to accelerate the training phase
since we push some workload to CPU workers instead of GPUs.
mean : array-like of size 3
Mean pixel values to be subtracted from image tensor. Default is [0.485, 0.456, 0.406].
std : array-like of size 3
Standard deviation to be divided from image. Default is [0.229, 0.224, 0.225].
iou_thresh : float
IOU overlap threshold for maximum matching, default is 0.5.
box_norm : array-like of size 4, default is (0.1, 0.1, 0.2, 0.2)
Std value to be divided from encoded values.
"""
def __init__(self, k, width, height, net=None, mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225), mixup=False, **kwargs):
self._k = k
self._width = width
self._height = height
self._mean = mean
self._std = std
self._mixup = mixup
self._target_generator = None
self._pad = True
if net is None:
return

# in case network has reset_ctx to gpu
if k > 1:
self._fake_x = mx.nd.zeros((1, k, 3, height, width))
else:
self._fake_x = mx.nd.zeros((1, 3, height, width))
net = copy.deepcopy(net)
net.collect_params().reset_ctx(None)
with autograd.train_mode():
_, self._anchors, self._offsets, self._feat_maps, _, _, _, _ = net(self._fake_x)

self._fake_x = mx.nd.zeros((1, 3, height, width))
from gluoncv.model_zoo.yolo.yolo_target import YOLOV3PrefetchTargetGenerator
self._target_generator = YOLOV3PrefetchTargetGenerator(num_class=len(net.classes), **kwargs)

def __call__(self, src, label):
"""Apply transform to training image/label."""

img = src
was_three = False
if len(img.shape) == 3:
img = mx.nd.expand_dims(img, axis=0)
was_three = True

# random color jittering
img = tvideo.random_color_distort(img)

# random expansion with prob 0.5
if np.random.uniform(0, 1) > 0.5:
img, expand = tvideo.random_expand(img, fill=[m * 255 for m in self._mean])
bbox = tbbox.translate(label, x_offset=expand[0], y_offset=expand[1])
else:
img, bbox = img, label

# random cropping
k, h, w, c = img.shape
bbox, crop = tbbox.random_crop_with_constraints(bbox, (w, h))
x0, y0, w, h = crop
img = img[:, y0:y0+h, x0:x0+w, :]

# resize with random interpolation
k, h, w, c = img.shape
interp = np.random.randint(0, 5)
tmp = mx.nd.ones((k, self._height, self._width, c), ctx=img.context)
for i in range(k):
tmp[i] = timage.imresize(img[i], self._width, self._height, interp=interp)
img = tmp
bbox = tbbox.resize(bbox, (w, h), (self._width, self._height))

# random horizontal flip with prob 0.5
k, h, w, c = img.shape
if np.random.uniform(0, 1) > 0.5:
img = mx.nd.flip(img, axis=2)
bbox = tbbox.flip(bbox, (w, h), flip_x=True)

img = mx.nd.image.to_tensor(img) # to tensor, also transforms from k,h,w,c to k,c,h,w
# normalise
for i in range(k):
img[i] = mx.nd.image.normalize(img[i], mean=self._mean, std=self._std) # normalise

if was_three: # remove the k dimension so backwards compat with single frame
img = mx.nd.squeeze(img)

if self._target_generator is None:
return img, bbox.astype(img.dtype)

bboxs = bbox

max_boxes = 0
gt_bboxes_t = mx.nd.ones((len(bboxs), 100, 4))*-1 # max is 100
objectness_t = list()
center_targets_t = list()
scale_targets_t = list()
weights_t = list()
class_targets_t = list()
for ts, bbox in enumerate(bboxs):
# generate training target so cpu workers can help reduce the workload on gpu
gt_bboxes = mx.nd.array(bbox[np.newaxis, :, :4])
gt_ids = mx.nd.array(bbox[np.newaxis, :, 4:5])
if self._mixup:
gt_mixratio = mx.nd.array(bbox[np.newaxis, :, -1:])
else:
gt_mixratio = None
objectness, center_targets, scale_targets, weights, class_targets = self._target_generator(
self._fake_x, self._feat_maps, self._anchors, self._offsets,
gt_bboxes, gt_ids, gt_mixratio)

if len(bboxs) == 1:
return (img, objectness[0], center_targets[0], scale_targets[0], weights[0],
class_targets[0], gt_bboxes[0])

objectness_t.append(objectness)
center_targets_t.append(center_targets)
scale_targets_t.append(scale_targets)
weights_t.append(weights)
class_targets_t.append(class_targets)

max_boxes = max(max_boxes, gt_bboxes.shape[1])
gt_bboxes_t[ts, :gt_bboxes.shape[1], :] = gt_bboxes[0]

objectness_t = mx.nd.concat(*objectness_t, dim=0)
center_targets_t = mx.nd.concat(*center_targets_t, dim=0)
scale_targets_t = mx.nd.concat(*scale_targets_t, dim=0)
weights_t = mx.nd.concat(*weights_t, dim=0)
class_targets_t = mx.nd.concat(*class_targets_t, dim=0)

return img, objectness_t, center_targets_t, scale_targets_t, weights_t, class_targets_t, gt_bboxes_t[:, :max_boxes, :]


class YOLO3VideoInferenceTransform(object):
"""Default YOLO validation transform.
Parameters
Expand Down Expand Up @@ -293,7 +439,16 @@ def __call__(self, src, label, idx=None):

if was_three: # remove the k dimension so backwards compat with single frame
img = mx.nd.squeeze(img)


# if multiple temporal outputs
if isinstance(bbox, list):
max_boxes = 0
gt_bboxes_t = mx.nd.ones((len(bbox), 100, 5)) * -1 # max is 100
for t in range(len(bbox)):
max_boxes = max(max_boxes, bbox[t].shape[0])
gt_bboxes_t[t, :bbox[t].shape[0], :] = bbox[t].astype(gt_bboxes_t.dtype)
bbox = gt_bboxes_t[:, :max_boxes, :]

if idx is not None:
return img, bbox.astype(img.dtype), idx
return img, bbox.astype(img.dtype)
Expand Down

0 comments on commit c821f4d

Please sign in to comment.