Skip to content

Commit

Permalink
add mixup
Browse files Browse the repository at this point in the history
  • Loading branch information
shippingwang committed Dec 15, 2020
1 parent fff82ed commit 4783235
Show file tree
Hide file tree
Showing 6 changed files with 214 additions and 40 deletions.
6 changes: 4 additions & 2 deletions paddlevideo/loader/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .builder import build_dataset, build_dataloader
from .builder import build_dataset, build_dataloader, build_batch_pipeline
from .dataset import VideoDataset

__all__ = ['build_dataset', 'build_dataloader', 'VideoDataset']
__all__ = [
'build_dataset', 'build_dataloader', 'build_batch_pipeline', 'VideoDataset'
]
49 changes: 32 additions & 17 deletions paddlevideo/loader/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,11 @@
from ..utils.build_utils import build
from .pipelines.compose import Compose
from paddlevideo.utils import get_logger
import numpy as np

logger = get_logger("paddlevideo")


def build_pipeline(cfg):
"""Build pipeline.
Args:
Expand All @@ -32,7 +34,7 @@ def build_pipeline(cfg):

def build_dataset(cfg):
"""Build dataset.
Args:
Args:
cfg (dict): root config dict.
Returns:
Expand All @@ -44,37 +46,51 @@ def build_dataset(cfg):
dataset = build(cfg_dataset, DATASETS, key="format")
return dataset


def build_batch_pipeline(cfg):
    """Build the batch-level (mix) pipeline operator.

    Args:
        cfg: either a root config exposing a ``MIX`` entry, or the mix
            operator config itself (``build_dataloader`` passes the latter).

    Returns:
        The object produced by ``build`` for the mix config against the
        ``PIPELINES`` registry.
    """
    # Callers are not consistent about which level of the config they hand
    # over, so fall back to treating ``cfg`` as the mix config when it has no
    # ``MIX`` entry. Dict-backed configs signal a missing key with KeyError,
    # plain objects with AttributeError.
    try:
        mix_cfg = cfg.MIX
    except (AttributeError, KeyError):
        mix_cfg = cfg
    return build(mix_cfg, PIPELINES)


def build_dataloader(dataset,
                     batch_size,
                     num_workers,
                     places,
                     shuffle=True,
                     drop_last=True,
                     collate_fn=None,
                     **kwargs):
    """Build Paddle Dataloader.

    Args:
        dataset (paddle.dataset): A PaddlePaddle dataset object.
        batch_size (int): batch size on single card.
        num_workers (int): forwarded to ``DataLoader`` as ``num_workers``.
        places: forwarded to ``DataLoader`` as ``places``.
        shuffle (bool): whether to shuffle the data at every epoch.
        drop_last (bool): forwarded to the batch sampler as ``drop_last``.
        collate_fn: mix operator config (e.g. for Mixup / Cutmix), or
            ``None`` to leave batch collation to ``DataLoader``.
        **kwargs: extra keyword arguments forwarded to ``DataLoader``.

    Returns:
        The constructed ``DataLoader``.
    """
    sampler = DistributedBatchSampler(dataset,
                                      batch_size=batch_size,
                                      shuffle=shuffle,
                                      drop_last=drop_last)

    # NOTE(shipping): when a mix operator such as mixup or cutmix is switched
    # on, a batch like
    #   [[img, label, attribute, ...], [img, label, attribute, ...], ...]
    # is re-collated to
    #   [[img, img, ...], [label, label, ...], [attribute, attribute, ...], ...]
    # using numpy transpose.
    if collate_fn is not None:
        # ``collate_fn`` arrives as the mix operator *config*, not a callable.
        # Build the operator from it directly and wrap it in the callable that
        # DataLoader expects. The wrapper must be passed on to DataLoader, not
        # returned from this function.
        pipeline = build(collate_fn, PIPELINES)

        def mix_collate_fn(batch):
            return np.array(pipeline(batch)).T

        collate_fn = mix_collate_fn

    data_loader = DataLoader(dataset,
                             batch_sampler=sampler,
                             places=places,
                             num_workers=num_workers,
                             collate_fn=collate_fn,
                             return_list=True,
                             **kwargs)

    return data_loader

Expand All @@ -84,11 +100,10 @@ def term_mp(sig_num, frame):
"""
pid = os.getpid()
pgid = os.getpgid(os.getpid())
logger.info("main proc {} exit, kill process group "
"{}".format(pid, pgid))
logger.info("main proc {} exit, kill process group " "{}".format(pid, pgid))
os.killpg(pgid, signal.SIGKILL)
return


signal.signal(signal.SIGINT, term_mp)
signal.signal(signal.SIGTERM, term_mp)
34 changes: 17 additions & 17 deletions paddlevideo/loader/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,24 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .augmentations import (Scale,
RandomCrop,
CenterCrop,
RandomFlip,
Image2Array,
Normalization)
from .augmentations import (Scale, RandomCrop, CenterCrop, RandomFlip,
Image2Array, Normalization)
from .compose import Compose
from .decode import VideoDecoder, FrameDecoder
from .sample import Sampler
from .mix import Mixup, Cutmix

__all__ = ['Scale',
'RandomCrop',
'CenterCrop',
'RandomFlip',
'Image2Array',
'Normalization',
'Compose',
'VideoDecoder',
'FrameDecoder',
'Sample',]

__all__ = [
'Scale',
'RandomCrop',
'CenterCrop',
'RandomFlip',
'Image2Array',
'Normalization',
'Compose',
'VideoDecoder',
'FrameDecoder',
'Sampler',
'Mixup',
'Cutmix',
]
83 changes: 83 additions & 0 deletions paddlevideo/loader/pipelines/mix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import numpy as np
from ..registry import PIPELINES


@PIPELINES.register()
class Mixup(object):
    """
    Mixup operator.

    Blends every image of a batch with another image of the same batch,
    using one mixing coefficient drawn from Beta(alpha, alpha) per call.

    Args:
        alpha(float): alpha value; both parameters of the Beta distribution.
    """
    def __init__(self, alpha=0.2):
        assert alpha > 0., \
            'parameter alpha[%f] should > 0.0' % (alpha)
        self.alpha = alpha

    def __call__(self, batch):
        """Apply mixup to one batch.

        Args:
            batch: a pair ``(imgs, labels)``; both are converted with
                ``np.asarray`` and indexed along their first axis.

        Returns:
            list: ``[mixed_imgs, labels, permuted_labels, lams]`` where
            ``lams`` is a float32 array holding the mixing coefficient once
            per sample.
        """
        imgs, labels = batch
        imgs = np.asarray(imgs)
        labels = np.asarray(labels)
        # The batch size is the number of samples, not the length of the
        # (imgs, labels) pair, which is always 2.
        bs = len(imgs)
        idx = np.random.permutation(bs)
        lam = np.random.beta(self.alpha, self.alpha)
        lams = np.array([lam] * bs, dtype=np.float32)
        imgs = lam * imgs + (1 - lam) * imgs[idx]
        return [imgs, labels, labels[idx], lams]


@PIPELINES.register()
class Cutmix(object):
    """ Cutmix operator

    Pastes one random rectangular region from a permuted copy of the batch
    into every image, and reports the fraction of each image left untouched.

    Args:
        alpha(float): alpha value; both parameters of the Beta distribution.
    """
    def __init__(self, alpha=0.2):
        assert alpha > 0., \
            'parameter alpha[%f] should > 0.0' % (alpha)
        self.alpha = alpha

    def _unpack(self, batch):
        """Split a ``(imgs, labels)`` pair into arrays plus the batch size."""
        imgs, labels = batch
        imgs = np.asarray(imgs)
        labels = np.asarray(labels)
        return imgs, labels, len(imgs)

    def rand_bbox(self, size, lam):
        """Draw a random box whose side ratio is ``sqrt(1 - lam)``.

        Args:
            size: shape of the image batch; ``size[2]`` and ``size[3]`` are
                used as the extents of the two cropped axes.
            lam (float): mixing coefficient.

        Returns:
            tuple: ``(bbx1, bby1, bbx2, bby2)`` clipped to the axis extents.
        """
        w = size[2]
        h = size[3]
        cut_rat = np.sqrt(1. - lam)
        # Builtin int: the ``np.int`` alias was removed in NumPy 1.24.
        cut_w = int(w * cut_rat)
        cut_h = int(h * cut_rat)

        # uniform
        cx = np.random.randint(w)
        cy = np.random.randint(h)

        bbx1 = np.clip(cx - cut_w // 2, 0, w)
        bby1 = np.clip(cy - cut_h // 2, 0, h)
        bbx2 = np.clip(cx + cut_w // 2, 0, w)
        bby2 = np.clip(cy + cut_h // 2, 0, h)

        return bbx1, bby1, bbx2, bby2

    def __call__(self, batch):
        """Apply cutmix to one batch.

        Args:
            batch: a pair ``(imgs, labels)``. If ``imgs`` is already an
                ndarray it is modified in place.
                NOTE(review): the box is cut on axes 2 and 3 while the area
                is normalised by the last two axes, so this presumably
                expects 4-D ``imgs`` - confirm against the caller.

        Returns:
            list: ``[imgs, labels, permuted_labels, lams]`` where ``lams`` is
            a float32 array holding the untouched-area fraction once per
            sample.
        """
        imgs, labels, bs = self._unpack(batch)
        idx = np.random.permutation(bs)
        lam = np.random.beta(self.alpha, self.alpha)

        bbx1, bby1, bbx2, bby2 = self.rand_bbox(imgs.shape, lam)
        # Indexing with ``idx`` copies the source patches first, so the
        # in-place write below cannot read already-overwritten data.
        imgs[:, :, bbx1:bbx2, bby1:bby2] = imgs[idx, :, bbx1:bbx2, bby1:bby2]
        # Recompute lam from the clipped box actually pasted.
        lam = 1 - (float(bbx2 - bbx1) * (bby2 - bby1) /
                   (imgs.shape[-2] * imgs.shape[-1]))
        lams = np.array([lam] * bs, dtype=np.float32)

        return [imgs, labels, labels[idx], lams]
15 changes: 11 additions & 4 deletions paddlevideo/tasks/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,22 +37,27 @@ def train_model(model, dataset, cfg, parallel=True, validate=True):
"""
logger = get_logger("paddlevideo")
batch_size = cfg.DATASET.get('batch_size', 2)
#single card batch size
batch_size = cfg.DATASET.get('batch_size', 8)
places = paddle.set_device('gpu')
mix = cfg.PIPELINE.get("mix", None)
num_workers = cfg.DATASET.get('num_workers', 0)

train_dataset = dataset[0]

train_dataloader_setting = dict(
batch_size=batch_size,
# default num worker: 0, which means no subprocess will be created
num_workers=cfg.DATASET.get('num_workers', 0),
num_workers=num_workers,
collate_fn=cfg.get('MIX', None),
places=places)

train_loader = build_dataloader(train_dataset, **train_dataloader_setting)

if validate:
valid_dataset = dataset[1]
validate_dataloader_setting = dict(batch_size=batch_size,
num_workers=cfg.DATASET.get(
'num_workers', 0),
num_workers=num_workers,
places=places,
drop_last=False,
shuffle=False)
Expand Down Expand Up @@ -110,6 +115,7 @@ def evaluate(best):
metric_list.pop('lr')
tic = time.time()
for i, data in enumerate(valid_loader):

if parallel:
outputs = model._layers.val_step(data)
else:
Expand All @@ -126,6 +132,7 @@ def evaluate(best):
ips = "ips: {:.5f} instance/sec.".format(
batch_size / metric_list["batch_time"].val)
log_batch(metric_list, i, epoch, cfg.epochs, "val", ips)

ips = "ips: {:.5f} instance/sec.".format(
batch_size * metric_list["batch_time"].count /
metric_list["batch_time"].sum)
Expand Down
67 changes: 67 additions & 0 deletions tools/summary.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import os
import sys
import os.path as osp

import paddle
import paddle.nn.functional as F
from paddle.jit import to_static

# Append the parent of this script's directory to the import search path so
# the `paddlevideo` imports below resolve when the script is run directly.
__dir__ = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(__dir__, '../')))

from paddlevideo.modeling.builder import build_model
from paddlevideo.utils import get_config


def parse_args():
    """Parse the command-line options of the summary tool.

    Returns:
        argparse.Namespace: with ``config`` (str, default
        ``configs/example.yaml``) and ``img_size`` (int, default 224).
    """
    cli = argparse.ArgumentParser("PaddleVideo Summary")

    config_options = dict(type=str,
                          default='configs/example.yaml',
                          help='config file path')
    cli.add_argument('-c', '--config', **config_options)
    cli.add_argument("--img_size", type=int, default=224)

    parsed = cli.parse_args()
    return parsed


def _trim(cfg):
    """
    Reduce a training config to what building the model needs.

    Reusing the training config brings along attributes that are useless
    here, such as the backbone's pretrained model, so that entry is blanked
    on the config's ``MODEL`` section (which is modified in place).

    Returns:
        tuple: ``(model_cfg, model_name)``.
    """
    name = cfg.model_name
    model_cfg = cfg.MODEL
    model_cfg.backbone.pretrained = ""
    return model_cfg, name


def main():
    """Build the configured model and print its ``paddle.summary`` report."""
    args = parse_args()
    cfg, model_name = _trim(get_config(args.config, show=False))
    print(f"Building model({model_name})...")
    model = build_model(cfg)

    # Honour --img_size (default 224) instead of a hard-coded 224.
    # NOTE(review): the leading dims are presumably
    # (batch, frames, channels) - confirm against the model.
    input_size = (1, 8, 3, args.img_size, args.img_size)
    params_info = paddle.summary(model, input_size)

    print(params_info)


if __name__ == "__main__":
    main()

0 comments on commit 4783235

Please sign in to comment.