handle multiple training corpora and enable weighting
francoishernandez committed May 2, 2019
1 parent ded1853 commit b8a82e1
Showing 5 changed files with 262 additions and 131 deletions.
98 changes: 92 additions & 6 deletions onmt/inputters/inputter.py
@@ -501,9 +501,12 @@ def __init__(self,
dataset,
batch_size,
batch_size_multiple=1,
yield_raw_example=False,
**kwargs):
super(OrderedIterator, self).__init__(dataset, batch_size, **kwargs)
self.batch_size_multiple = batch_size_multiple
self.yield_raw_example = yield_raw_example
self.dataset = dataset

def create_batches(self):
if self.train:
@@ -527,6 +530,80 @@ def _pool(data, random_shuffler):
batch_size_multiple=self.batch_size_multiple):
self.batches.append(sorted(b, key=self.sort_key))

def __iter__(self):
"""
Extended version of the definition in torchtext.data.Iterator.
Added yield_raw_example behaviour to yield a torchtext.data.Example
instead of a torchtext.data.Batch object.
"""
while True:
self.init_epoch()
for idx, minibatch in enumerate(self.batches):
# fast-forward if loaded from state
if self._iterations_this_epoch > idx:
continue
self.iterations += 1
self._iterations_this_epoch += 1
if self.sort_within_batch:
# NOTE: `rnn.pack_padded_sequence` requires that a
# minibatch be sorted by decreasing order, which
# requires reversing relative to typical sort keys
if self.sort:
minibatch.reverse()
else:
minibatch.sort(key=self.sort_key, reverse=True)
if self.yield_raw_example:
yield minibatch[0]
else:
yield torchtext.data.Batch(
minibatch,
self.dataset,
self.device)
if not self.repeat:
return


class MultipleDatasetIterator(object):
"""
This takes a list of iterable objects (DatasetLazyIter) and their
respective weights, and yields batches whose examples are drawn from
each corpus in proportion to its weight.
"""
def __init__(self,
iterables,
weights,
batch_size,
batch_size_fn,
batch_size_multiple,
device):
self.index = -1
self.iterators = [iter(iterable) for iterable in iterables]
self.iterables = iterables
self.weights = weights
self.batch_size = batch_size
self.batch_size_fn = batch_size_fn
self.batch_size_multiple = batch_size_multiple
self.device = device

def _iter_datasets(self):
for weight in self.weights:
self.index = (self.index + 1) % len(self.iterators)
for i in range(weight):
yield self.iterators[self.index]

def _iter_examples(self):
for iterator in cycle(self._iter_datasets()):
yield next(iterator)

def __iter__(self):
for minibatch in batch_iter(
self._iter_examples(),
self.batch_size,
batch_size_fn=self.batch_size_fn,
batch_size_multiple=self.batch_size_multiple):
dataset = self.iterables[0].dataset
minibatch = sorted(minibatch, key=dataset.sort_key, reverse=True)
yield torchtext.data.Batch(minibatch, dataset, self.device)


class DatasetLazyIter(object):
"""Yield data from sharded dataset files.
@@ -543,7 +620,7 @@ class DatasetLazyIter(object):

def __init__(self, dataset_paths, fields, batch_size, batch_size_fn,
batch_size_multiple, device, is_train, repeat=True,
num_batches_multiple=1):
num_batches_multiple=1, yield_raw_example=True):
self._paths = dataset_paths
self.fields = fields
self.batch_size = batch_size
@@ -553,6 +630,7 @@ def __init__(self, dataset_paths, batch_size, batch_size_fn,
self.is_train = is_train
self.repeat = repeat
self.num_batches_multiple = num_batches_multiple
self.yield_raw_example = yield_raw_example

def _iter_dataset(self, path):
cur_dataset = torch.load(path)
@@ -568,9 +646,11 @@ def _iter_dataset(self, path):
train=self.is_train,
sort=False,
sort_within_batch=True,
repeat=False
repeat=False,
yield_raw_example=self.yield_raw_example
)
for batch in cur_iter:
self.dataset = cur_iter.dataset
yield batch

cur_dataset.examples = None
@@ -623,7 +703,7 @@ def max_tok_len(new, count, sofar):
return max(src_elements, tgt_elements)


def build_dataset_iter(corpus_type, fields, opt, is_train=True):
def build_dataset_iter(corpus_type, fields, opt, is_train=True, multi=False):
"""
This returns user-defined train/validate data iterator for the trainer
to iterate over. We implement simple ordered iterator strategy here,
@@ -633,9 +713,15 @@ def build_dataset_iter(corpus_type, fields, opt, is_train=True):
glob.glob(opt.data + '.' + corpus_type + '*.pt')))
if not dataset_paths:
return None
batch_size = opt.batch_size if is_train else opt.valid_batch_size
batch_fn = max_tok_len if is_train and opt.batch_type == "tokens" else None
batch_size_multiple = 8 if opt.model_dtype == "fp16" else 1
if multi:
batch_size = 1
batch_fn = None
batch_size_multiple = 1
else:
batch_size = opt.batch_size if is_train else opt.valid_batch_size
batch_fn = max_tok_len if is_train and opt.batch_type == "tokens" else None
batch_size_multiple = 8 if opt.model_dtype == "fp16" else 1

device = "cuda" if opt.gpu_ranks else "cpu"

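The weighted mixing above boils down to a fixed round-robin schedule over per-corpus iterators: corpus i is visited weights[i] times per round, one raw example is pulled per visit, and batch_iter then re-batches the pooled stream. Below is a minimal standalone sketch of that schedule, with plain generators standing in for DatasetLazyIter (names and data are illustrative only):

    from itertools import cycle, islice

    def weighted_example_stream(iterables, weights):
        # One iterator per corpus; the schedule visits corpus i weights[i]
        # times per round, much like MultipleDatasetIterator._iter_datasets.
        iterators = [iter(it) for it in iterables]
        schedule = [it for it, w in zip(iterators, weights) for _ in range(w)]
        for iterator in cycle(schedule):
            # In training the underlying iterators repeat forever, so next()
            # never runs dry; this demo just takes a finite slice below.
            yield next(iterator)

    corpus_a = ("a%d" % i for i in range(100))   # stand-in for one corpus
    corpus_b = ("b%d" % i for i in range(100))   # stand-in for another
    print(list(islice(weighted_example_stream([corpus_a, corpus_b], [2, 1]), 9)))
    # ['a0', 'a1', 'b0', 'a2', 'a3', 'b1', 'a4', 'a5', 'b2']

With weights [2, 1] examples are drawn 2:1, so each batch assembled from this stream mixes the two corpora in roughly that ratio.
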
16 changes: 12 additions & 4 deletions onmt/opts.py
@@ -189,10 +189,12 @@ def preprocess_opts(parser):
help="Type of the source input. "
"Options are [text|img|audio].")

group.add('--train_src', '-train_src', required=True,
help="Path to the training source data")
group.add('--train_tgt', '-train_tgt', required=True,
help="Path to the training target data")
group.add('--train_src', '-train_src', required=True, nargs='+',
help="Path(s) to the training source data")
group.add('--train_tgt', '-train_tgt', required=True, nargs='+',
help="Path(s) to the training target data")
group.add('--train_ids', '-train_ids', nargs='+', default=[None],
help="ids to name training shards, used for corpus weighting")
group.add('--valid_src', '-valid_src',
help="Path to the validation source data")
group.add('--valid_tgt', '-valid_tgt',
@@ -308,6 +310,12 @@ def train_opts(parser):
help='Path prefix to the ".train.pt" and '
'".valid.pt" file path from preprocess.py')

group.add('--data_ids', '-data_ids', nargs='+', default=[None],
help="In case there are several corpora, ids of the corpora "
"to train on (should match the -train_ids given to preprocess.py).")
group.add('--data_weights', '-data_weights', type=int, nargs='+', default=[1],
help="""Weights of different corpora, should follow
the same order as in -data_ids.""")

group.add('--save_model', '-save_model', default='model',
help="Model filename (the model will be saved as "
"<save_model>_N.pt where N is the number "
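In the options above, nargs='+' turns -train_src, -train_tgt and -train_ids into lists, and default=[None] keeps the old single-corpus behaviour when no ids are given. Here is a minimal argparse sketch of that behaviour (the project actually builds its parser with configargparse, and the file names here are made up):

    import argparse

    parser = argparse.ArgumentParser()
    # Preprocess-side flags: several corpora, optionally named.
    parser.add_argument('-train_src', nargs='+', required=True)
    parser.add_argument('-train_tgt', nargs='+', required=True)
    parser.add_argument('-train_ids', nargs='+', default=[None])
    # Train-side flags: which corpora to use and their weights.
    parser.add_argument('-data_ids', nargs='+', default=[None])
    parser.add_argument('-data_weights', type=int, nargs='+', default=[1])

    opt = parser.parse_args(
        '-train_src euro.src news.src -train_tgt euro.tgt news.tgt '
        '-train_ids europarl news -data_ids europarl news '
        '-data_weights 2 1'.split())
    print(opt.train_ids)     # ['europarl', 'news']
    print(opt.data_weights)  # [2, 1]

    # Omitting the new flags keeps the defaults, i.e. a single unnamed corpus.
    opt = parser.parse_args('-train_src all.src -train_tgt all.tgt'.split())
    print(opt.train_ids, opt.data_ids, opt.data_weights)  # [None] [None] [1]
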
22 changes: 19 additions & 3 deletions onmt/train_single.py
@@ -4,8 +4,8 @@

import torch

from onmt.inputters.inputter import build_dataset_iter, \
load_old_vocab, old_style_vocab
from onmt.inputters.inputter import build_dataset_iter, max_tok_len, \
load_old_vocab, old_style_vocab, MultipleDatasetIterator
from onmt.model_builder import build_model
from onmt.utils.optimizers import Optimizer
from onmt.utils.misc import set_random_seed
@@ -98,7 +98,23 @@ def main(opt, device_id):
trainer = build_trainer(
opt, device_id, model, fields, optim, model_saver=model_saver)

train_iter = build_dataset_iter("train", fields, opt)
train_iterables = []
for train_id in opt.data_ids:
if train_id:
shard_base = "train_" + train_id
else:
shard_base = "train"
iterable = build_dataset_iter(shard_base, fields, opt, multi=True)
train_iterables.append(iterable)

train_iter = MultipleDatasetIterator(
train_iterables,
opt.data_weights,
opt.batch_size,
max_tok_len if opt.batch_type == "tokens" else None,
8 if opt.model_dtype == "fp16" else 1,
device_id)

valid_iter = build_dataset_iter(
"valid", fields, opt, is_train=False)

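Each id in -data_ids selects the training shards that preprocess.py saved under that name: build_dataset_iter globs opt.data + '.' + shard_base + '*.pt', and with multi=True each per-corpus DatasetLazyIter yields raw examples (batch_size 1) so that MultipleDatasetIterator can re-batch the mixed stream according to opt.data_weights. A small sketch of the shard-pattern construction (prefix and ids are hypothetical):

    # Hypothetical values for -data and -data_ids; the patterns mirror the
    # glob in build_dataset_iter (opt.data + '.' + corpus_type + '*.pt').
    data_prefix = "data/demo"
    data_ids = ["europarl", "news"]

    for train_id in data_ids:
        if train_id:
            shard_base = "train_" + train_id
        else:
            shard_base = "train"
        print(data_prefix + '.' + shard_base + '*.pt')
    # data/demo.train_europarl*.pt
    # data/demo.train_news*.pt

The resulting per-corpus iterables are then wrapped by MultipleDatasetIterator together with opt.data_weights, as in the main() change above.
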
13 changes: 10 additions & 3 deletions onmt/utils/parse.py
@@ -99,6 +99,8 @@ def validate_train_opts(cls, opt):
raise AssertionError(
"-gpu_ranks should have master(=0) rank "
"unless -world_size is greater than len(gpu_ranks).")
assert len(opt.data_ids) == len(opt.data_weights), \
"Please check -data_ids and -data_weights options!"

@classmethod
def validate_translate_opts(cls, opt):
@@ -114,9 +116,14 @@ def validate_preprocess_args(cls, opt):
"-shuffle is not implemented. Please shuffle \
your data before pre-processing."

assert os.path.isfile(opt.train_src) \
and os.path.isfile(opt.train_tgt), \
"Please check path of your train src and tgt files!"
assert len(opt.train_src) == len(opt.train_tgt), \
"Please provide same number of src and tgt train files!"

assert len(opt.train_src) == len(opt.train_ids), \
"Please provide proper -train_ids for your data!"

for file in opt.train_src + opt.train_tgt:
assert os.path.isfile(file), "Please check path of %s" % file

assert not opt.valid_src or os.path.isfile(opt.valid_src), \
"Please check path of your valid src file!"
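The new assertions tie the preprocess-side and train-side options together: source, target and ids must line up when preprocessing, and ids and weights must line up when training. A tiny standalone check mirroring those rules (the option values are illustrative):

    # Illustrative option values; the asserts mirror validate_preprocess_args
    # and validate_train_opts above.
    train_src, train_tgt = ["euro.src", "news.src"], ["euro.tgt", "news.tgt"]
    train_ids = ["europarl", "news"]
    data_ids, data_weights = ["europarl", "news"], [2, 1]

    assert len(train_src) == len(train_tgt), \
        "Please provide same number of src and tgt train files!"
    assert len(train_src) == len(train_ids), \
        "Please provide proper -train_ids for your data!"
    assert len(data_ids) == len(data_weights), \
        "Please check -data_ids and -data_weights options!"
    print("multi-corpus options are consistent")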