Skip to content

Commit

Permalink
[Feature] Add MultiImageMixDataset (#1105)
Browse files Browse the repository at this point in the history
* Fix typo in usage example

* original MultiImageMixDataset code in mmdet

* Add MultiImageMixDataset unittests in test_dataset_wrapper

* fix lint error

* fix value name ann_file to ann_dir

* modify retrieve_data_cfg (#1)

* remove dynamic_scale & add palette

* modify retrieve_data_cfg method

* modify retrieve_data_cfg func

* fix error

* improve the unittests coverage

* fix unittests error

* Dataset (#2)

* add cfg-options

* Add unittest in test_build_dataset

* add blank line

* add blank line

* add a blank line

Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>

Co-authored-by: Younghoon-Lee <72462227+Younghoon-Lee@users.noreply.github.com>
Co-authored-by: MeowZheng <meowzheng@outlook.com>
Co-authored-by: Miao Zheng <76149310+MeowZheng@users.noreply.github.com>
  • Loading branch information
4 people authored Jan 11, 2022
1 parent f0262fa commit 6c3e63e
Show file tree
Hide file tree
Showing 6 changed files with 190 additions and 16 deletions.
5 changes: 3 additions & 2 deletions mmseg/datasets/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from .coco_stuff import COCOStuffDataset
from .custom import CustomDataset
from .dark_zurich import DarkZurichDataset
from .dataset_wrappers import ConcatDataset, RepeatDataset
from .dataset_wrappers import (ConcatDataset, MultiImageMixDataset,
RepeatDataset)
from .drive import DRIVEDataset
from .hrf import HRFDataset
from .loveda import LoveDADataset
Expand All @@ -21,5 +22,5 @@
'PascalVOCDataset', 'ADE20KDataset', 'PascalContextDataset',
'PascalContextDataset59', 'ChaseDB1Dataset', 'DRIVEDataset', 'HRFDataset',
'STAREDataset', 'DarkZurichDataset', 'NightDrivingDataset',
'COCOStuffDataset', 'LoveDADataset'
'COCOStuffDataset', 'LoveDADataset', 'MultiImageMixDataset'
]
8 changes: 7 additions & 1 deletion mmseg/datasets/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,12 +64,18 @@ def _concat_dataset(cfg, default_args=None):

def build_dataset(cfg, default_args=None):
"""Build datasets."""
from .dataset_wrappers import ConcatDataset, RepeatDataset
from .dataset_wrappers import (ConcatDataset, RepeatDataset,
MultiImageMixDataset)
if isinstance(cfg, (list, tuple)):
dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
elif cfg['type'] == 'RepeatDataset':
dataset = RepeatDataset(
build_dataset(cfg['dataset'], default_args), cfg['times'])
elif cfg['type'] == 'MultiImageMixDataset':
cp_cfg = copy.deepcopy(cfg)
cp_cfg['dataset'] = build_dataset(cp_cfg['dataset'])
cp_cfg.pop('type')
dataset = MultiImageMixDataset(**cp_cfg)
elif isinstance(cfg.get('img_dir'), (list, tuple)) or isinstance(
cfg.get('split', None), (list, tuple)):
dataset = _concat_dataset(cfg, default_args)
Expand Down
91 changes: 89 additions & 2 deletions mmseg/datasets/dataset_wrappers.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
import bisect
import collections
import copy
from itertools import chain

import mmcv
import numpy as np
from mmcv.utils import print_log
from mmcv.utils import build_from_cfg, print_log
from torch.utils.data.dataset import ConcatDataset as _ConcatDataset

from .builder import DATASETS
from .builder import DATASETS, PIPELINES
from .cityscapes import CityscapesDataset


Expand Down Expand Up @@ -188,3 +190,88 @@ def __getitem__(self, idx):
def __len__(self):
"""The length is multiplied by ``times``"""
return self.times * self._ori_len


@DATASETS.register_module()
class MultiImageMixDataset:
"""A wrapper of multiple images mixed dataset.
Suitable for training on multiple images mixed data augmentation like
mosaic and mixup. For the augmentation pipeline of mixed image data,
the `get_indexes` method needs to be provided to obtain the image
indexes, and you can set `skip_flags` to change the pipeline running
process.
Args:
dataset (:obj:`CustomDataset`): The dataset to be mixed.
pipeline (Sequence[dict]): Sequence of transform object or
config dict to be composed.
skip_type_keys (list[str], optional): Sequence of type string to
be skip pipeline. Default to None.
"""

def __init__(self, dataset, pipeline, skip_type_keys=None):
assert isinstance(pipeline, collections.abc.Sequence)
if skip_type_keys is not None:
assert all([
isinstance(skip_type_key, str)
for skip_type_key in skip_type_keys
])
self._skip_type_keys = skip_type_keys

self.pipeline = []
self.pipeline_types = []
for transform in pipeline:
if isinstance(transform, dict):
self.pipeline_types.append(transform['type'])
transform = build_from_cfg(transform, PIPELINES)
self.pipeline.append(transform)
else:
raise TypeError('pipeline must be a dict')

self.dataset = dataset
self.CLASSES = dataset.CLASSES
self.PALETTE = dataset.PALETTE
self.num_samples = len(dataset)

def __len__(self):
return self.num_samples

def __getitem__(self, idx):
results = copy.deepcopy(self.dataset[idx])
for (transform, transform_type) in zip(self.pipeline,
self.pipeline_types):
if self._skip_type_keys is not None and \
transform_type in self._skip_type_keys:
continue

if hasattr(transform, 'get_indexes'):
indexes = transform.get_indexes(self.dataset)
if not isinstance(indexes, collections.abc.Sequence):
indexes = [indexes]
mix_results = [
copy.deepcopy(self.dataset[index]) for index in indexes
]
results['mix_results'] = mix_results

results = transform(results)

if 'mix_results' in results:
results.pop('mix_results')

return results

def update_skip_type_keys(self, skip_type_keys):
"""Update skip_type_keys.
It is called by an external hook.
Args:
skip_type_keys (list[str], optional): Sequence of type
string to be skip pipeline.
"""
assert all([
isinstance(skip_type_key, str) for skip_type_key in skip_type_keys
])
self._skip_type_keys = skip_type_keys
63 changes: 62 additions & 1 deletion tests/test_data/test_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
from mmseg.core.evaluation import get_classes, get_palette
from mmseg.datasets import (DATASETS, ADE20KDataset, CityscapesDataset,
ConcatDataset, CustomDataset, LoveDADataset,
PascalVOCDataset, RepeatDataset, build_dataset)
MultiImageMixDataset, PascalVOCDataset,
RepeatDataset, build_dataset)


def test_classes():
Expand Down Expand Up @@ -95,6 +96,66 @@ def test_dataset_wrapper():
assert repeat_dataset[27] == 7
assert len(repeat_dataset) == 10 * len(dataset_a)

img_scale = (60, 60)
pipeline = [
# dict(type='Mosaic', img_scale=img_scale, pad_val=255),
# need to merge mosaic
dict(type='RandomFlip', prob=0.5),
dict(type='Resize', img_scale=img_scale, keep_ratio=False),
]

CustomDataset.load_annotations = MagicMock()
results = []
for _ in range(2):
height = np.random.randint(10, 30)
weight = np.random.randint(10, 30)
img = np.ones((height, weight, 3))
gt_semantic_seg = np.random.randint(5, size=(height, weight))
results.append(dict(gt_semantic_seg=gt_semantic_seg, img=img))

classes = ['0', '1', '2', '3', '4']
palette = [(0, 0, 0), (1, 1, 1), (2, 2, 2), (3, 3, 3), (4, 4, 4)]
CustomDataset.__getitem__ = MagicMock(side_effect=lambda idx: results[idx])
dataset_a = CustomDataset(
img_dir=MagicMock(),
pipeline=[],
test_mode=True,
classes=classes,
palette=palette)
len_a = 2
cat_ids_list_a = [
np.random.randint(0, 80, num).tolist()
for num in np.random.randint(1, 20, len_a)
]
dataset_a.data_infos = MagicMock()
dataset_a.data_infos.__len__.return_value = len_a
dataset_a.get_cat_ids = MagicMock(
side_effect=lambda idx: cat_ids_list_a[idx])

multi_image_mix_dataset = MultiImageMixDataset(dataset_a, pipeline)
assert len(multi_image_mix_dataset) == len(dataset_a)

for idx in range(len_a):
results_ = multi_image_mix_dataset[idx]

# test skip_type_keys
multi_image_mix_dataset = MultiImageMixDataset(
dataset_a, pipeline, skip_type_keys=('RandomFlip'))
for idx in range(len_a):
results_ = multi_image_mix_dataset[idx]
assert results_['img'].shape == (img_scale[0], img_scale[1], 3)

skip_type_keys = ('RandomFlip', 'Resize')
multi_image_mix_dataset.update_skip_type_keys(skip_type_keys)
for idx in range(len_a):
results_ = multi_image_mix_dataset[idx]
assert results_['img'].shape[:2] != img_scale

# test pipeline
with pytest.raises(TypeError):
pipeline = [['Resize']]
multi_image_mix_dataset = MultiImageMixDataset(dataset_a, pipeline)


def test_custom_dataset():
img_norm_cfg = dict(
Expand Down
9 changes: 7 additions & 2 deletions tests/test_data/test_dataset_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
from torch.utils.data import (DistributedSampler, RandomSampler,
SequentialSampler)

from mmseg.datasets import (DATASETS, ConcatDataset, build_dataloader,
build_dataset)
from mmseg.datasets import (DATASETS, ConcatDataset, MultiImageMixDataset,
build_dataloader, build_dataset)


@DATASETS.register_module()
Expand Down Expand Up @@ -48,6 +48,11 @@ def test_build_dataset():
assert isinstance(dataset, ConcatDataset)
assert len(dataset) == 10

cfg = dict(type='MultiImageMixDataset', dataset=cfg, pipeline=[])
dataset = build_dataset(cfg)
assert isinstance(dataset, MultiImageMixDataset)
assert len(dataset) == 10

# with ann_dir, split
cfg = dict(
type='CustomDataset',
Expand Down
30 changes: 22 additions & 8 deletions tools/browse_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import mmcv
import numpy as np
from mmcv import Config
from mmcv import Config, DictAction

from mmseg.datasets.builder import build_dataset

Expand Down Expand Up @@ -42,6 +42,16 @@ def parse_args():
type=float,
default=0.5,
help='the opacity of semantic map')
parser.add_argument(
'--cfg-options',
nargs='+',
action=DictAction,
help='override some settings in the used config, the key-value pair '
'in xxx=yyy format will be merged into config file. If the value to '
'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
'Note that the quotation marks are necessary and that no white space '
'is allowed.')
args = parser.parse_args()
return args

Expand Down Expand Up @@ -122,28 +132,32 @@ def _retrieve_data_cfg(_data_cfg, skip_type, show_origin):
]


def retrieve_data_cfg(config_path, skip_type, show_origin=False):
def retrieve_data_cfg(config_path, skip_type, cfg_options, show_origin=False):
cfg = Config.fromfile(config_path)
if cfg_options is not None:
cfg.merge_from_dict(cfg_options)
train_data_cfg = cfg.data.train
if isinstance(train_data_cfg, list):
for _data_cfg in train_data_cfg:
while 'dataset' in _data_cfg and _data_cfg[
'type'] != 'MultiImageMixDataset':
_data_cfg = _data_cfg['dataset']
if 'pipeline' in _data_cfg:
_retrieve_data_cfg(_data_cfg, skip_type, show_origin)
elif 'dataset' in _data_cfg:
_retrieve_data_cfg(_data_cfg['dataset'], skip_type,
show_origin)
else:
raise ValueError
elif 'dataset' in train_data_cfg:
_retrieve_data_cfg(train_data_cfg['dataset'], skip_type, show_origin)
else:
while 'dataset' in train_data_cfg and train_data_cfg[
'type'] != 'MultiImageMixDataset':
train_data_cfg = train_data_cfg['dataset']
_retrieve_data_cfg(train_data_cfg, skip_type, show_origin)
return cfg


def main():
args = parse_args()
cfg = retrieve_data_cfg(args.config, args.skip_type, args.show_origin)
cfg = retrieve_data_cfg(args.config, args.skip_type, args.cfg_options,
args.show_origin)
dataset = build_dataset(cfg.data.train)
progress_bar = mmcv.ProgressBar(len(dataset))
for item in dataset:
Expand Down

0 comments on commit 6c3e63e

Please sign in to comment.