Skip to content
This repository has been archived by the owner on Jul 2, 2021. It is now read-only.

Commit

Permalink
use ConcatenatedDataset
Browse files Browse the repository at this point in the history
  • Loading branch information
Hakuyume committed Oct 7, 2017
1 parent a3721d4 commit c9ad167
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 36 deletions.
19 changes: 1 addition & 18 deletions examples/faster_rcnn/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import numpy as np

import chainer
from chainer.datasets import ConcatenatedDataset
from chainer.datasets import TransformDataset
from chainer import training
from chainer.training import extensions
Expand All @@ -17,24 +18,6 @@
from chainercv import transforms


class ConcatenatedDataset(chainer.dataset.DatasetMixin):

def __init__(self, *datasets):
self._datasets = datasets

def __len__(self):
return sum(len(dataset) for dataset in self._datasets)

def get_example(self, i):
if i < 0:
raise IndexError
for dataset in self._datasets:
if i < len(dataset):
return dataset[i]
i -= len(dataset)
raise IndexError


class Transform(object):

def __init__(self, faster_rcnn):
Expand Down
19 changes: 1 addition & 18 deletions examples/ssd/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np

import chainer
from chainer.datasets import ConcatenatedDataset
from chainer.datasets import TransformDataset
from chainer.optimizer import WeightDecay
from chainer import serializers
Expand All @@ -24,24 +25,6 @@
from chainercv.links.model.ssd import resize_with_random_interpolation


class ConcatenatedDataset(chainer.dataset.DatasetMixin):

def __init__(self, *datasets):
self._datasets = datasets

def __len__(self):
return sum(len(dataset) for dataset in self._datasets)

def get_example(self, i):
if i < 0:
raise IndexError
for dataset in self._datasets:
if i < len(dataset):
return dataset[i]
i -= len(dataset)
raise IndexError


class MultiboxTrainChain(chainer.Chain):

def __init__(self, model, alpha=1, k=3):
Expand Down

0 comments on commit c9ad167

Please sign in to comment.