From c9ad16710875df6811db32660cfbb5ab8fa64045 Mon Sep 17 00:00:00 2001 From: Toru Ogawa Date: Sat, 7 Oct 2017 11:11:02 +0900 Subject: [PATCH] use ConcatenatedDataset --- examples/faster_rcnn/train.py | 19 +------------------ examples/ssd/train.py | 19 +------------------ 2 files changed, 2 insertions(+), 36 deletions(-) diff --git a/examples/faster_rcnn/train.py b/examples/faster_rcnn/train.py index 1fa7ed16b7..aad30113f2 100644 --- a/examples/faster_rcnn/train.py +++ b/examples/faster_rcnn/train.py @@ -4,6 +4,7 @@ import numpy as np import chainer +from chainer.datasets import ConcatenatedDataset from chainer.datasets import TransformDataset from chainer import training from chainer.training import extensions @@ -17,24 +18,6 @@ from chainercv import transforms -class ConcatenatedDataset(chainer.dataset.DatasetMixin): - - def __init__(self, *datasets): - self._datasets = datasets - - def __len__(self): - return sum(len(dataset) for dataset in self._datasets) - - def get_example(self, i): - if i < 0: - raise IndexError - for dataset in self._datasets: - if i < len(dataset): - return dataset[i] - i -= len(dataset) - raise IndexError - - class Transform(object): def __init__(self, faster_rcnn): diff --git a/examples/ssd/train.py b/examples/ssd/train.py index ce1822aa25..8d89f4f6d6 100644 --- a/examples/ssd/train.py +++ b/examples/ssd/train.py @@ -3,6 +3,7 @@ import numpy as np import chainer +from chainer.datasets import ConcatenatedDataset from chainer.datasets import TransformDataset from chainer.optimizer import WeightDecay from chainer import serializers @@ -24,24 +25,6 @@ from chainercv.links.model.ssd import resize_with_random_interpolation -class ConcatenatedDataset(chainer.dataset.DatasetMixin): - - def __init__(self, *datasets): - self._datasets = datasets - - def __len__(self): - return sum(len(dataset) for dataset in self._datasets) - - def get_example(self, i): - if i < 0: - raise IndexError - for dataset in self._datasets: - if i < len(dataset): - return dataset[i] - i -= len(dataset) - raise IndexError - - class MultiboxTrainChain(chainer.Chain): def __init__(self, model, alpha=1, k=3):