Skip to content

Commit

Permalink
Optimize data transforms for yolo training (PaddlePaddle#28)
Browse files Browse the repository at this point in the history
* Optimize data transforms for yolo training

* Simplify and add docstring
  • Loading branch information
willthefrog authored Nov 20, 2019
1 parent 3ff1060 commit 576b06f
Show file tree
Hide file tree
Showing 2 changed files with 389 additions and 23 deletions.
35 changes: 15 additions & 20 deletions ppdet/data/data_feed.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
from ppdet.data.transform.operators import (
DecodeImage, MixupImage, NormalizeBox, NormalizeImage, RandomDistort,
RandomFlipImage, RandomInterpImage, ResizeImage, ExpandImage, CropImage,
Permute, MultiscaleTestResize)
Permute, MultiscaleTestResize, Resize, ColorDistort, NormalizePermute,
RandomExpand, RandomCrop)
from ppdet.data.transform.arrange_sample import (
ArrangeRCNN, ArrangeEvalRCNN, ArrangeTestRCNN, ArrangeSSD, ArrangeEvalSSD,
ArrangeTestSSD, ArrangeYOLO, ArrangeEvalYOLO, ArrangeTestYOLO)
Expand Down Expand Up @@ -195,7 +196,7 @@ def __init__(self, sizes=[]):
class PadMSTest(object):
"""
Padding for multi-scale test
Args:
pad_to_stride (int): pad to multiple of strides, e.g., 32
"""
Expand Down Expand Up @@ -896,25 +897,15 @@ def __init__(self,
sample_transforms=[
DecodeImage(to_rgb=True, with_mixup=True),
MixupImage(alpha=1.5, beta=1.5),
ColorDistort(),
RandomExpand(fill_value=[123.675, 116.28, 103.53]),
RandomCrop(),
RandomFlipImage(is_normalized=False),
Resize(target_dim=608, interp='random'),
NormalizePermute(
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.120, 57.375]),
NormalizeBox(),
RandomDistort(),
ExpandImage(max_ratio=4., prob=.5,
mean=[123.675, 116.28, 103.53]),
CropImage([[1, 1, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.1, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.3, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.5, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.7, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.9, 1.0],
[1, 50, 0.3, 1.0, 0.5, 2.0, 0.0, 1.0]]),
RandomInterpImage(target_size=608),
RandomFlipImage(is_normalized=True),
NormalizeImage(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
is_scale=True,
is_channel_first=False),
Permute(to_bgr=False),
],
batch_transforms=[
RandomShape(sizes=[
Expand Down Expand Up @@ -1010,6 +1001,8 @@ def __init__(self,
sample_transforms[i] = ResizeImage(
target_size=self.image_shape[-1],
interp=trans.interp)
if isinstance(trans, Resize):
sample_transforms[i].target_dim = self.image_shape[-1]


@register
Expand Down Expand Up @@ -1066,4 +1059,6 @@ def __init__(self,
sample_transforms[i] = ResizeImage(
target_size=self.image_shape[-1],
interp=trans.interp)
if isinstance(trans, Resize):
sample_transforms[i].target_dim = self.image_shape[-1]
# yapf: enable
Loading

0 comments on commit 576b06f

Please sign in to comment.