Using Producer-Consumer for batches (OpenNMT#1450)
* Working queues on multi-GPU on text and audio
* Working quite well, even with dynamic_dict
* Remove explicit garbage collect making some queue hang and other fixes
* fix process not ending
* properly set random seed and fill queues sequentially
* make queues work with distributed training
pltrdy authored and vince62s committed Jun 3, 2019
1 parent 8292a0e commit b731e04
Showing 8 changed files with 175 additions and 53 deletions.
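The file diffs below cover the data pipeline and the consumer side of the change; the producer process that fills the per-GPU batch queues presumably lives in train.py, the one changed file whose diff is not rendered in this view. As background, here is a minimal, self-contained sketch of the producer-consumer pattern the commit message describes: a bounded torch.multiprocessing queue per training process plus a shared semaphore so the producer cannot run arbitrarily far ahead of the consumers. Everything in this sketch (the function names, the queue_size value, the poison-pill shutdown) is illustrative and not taken from the commit itself.

import torch.multiprocessing as mp

def producer(queue, semaphore, n_batches):
    # Acquire one slot per batch; the consumer releases the slot after
    # taking a batch, so at most queue_size batches are in flight.
    for i in range(n_batches):
        semaphore.acquire()
        queue.put(("batch", i))
    queue.put(None)  # poison pill so the consumer can stop

def consumer(queue, semaphore):
    while True:
        item = queue.get()
        semaphore.release()
        if item is None:
            break
        # a real consumer would run one training step on `item` here

if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    queue_size = 4  # plays the role of the new -queue_size option
    queue = ctx.Queue(queue_size)
    semaphore = ctx.Semaphore(queue_size)
    proc = ctx.Process(target=consumer, args=(queue, semaphore))
    proc.start()
    producer(queue, semaphore, n_batches=10)
    proc.join()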
1 change: 1 addition & 0 deletions onmt/inputters/dataset_base.py
@@ -50,6 +50,7 @@ def _dynamic_dict(example, src_field, tgt_field):
# Map source tokens to indices in the dynamic dict.
src_map = torch.LongTensor([src_ex_vocab.stoi[w] for w in src])
example["src_map"] = src_map
example["src_ex_vocab"] = src_ex_vocab

if "tgt" in example:
tgt = tgt_field.tokenize(example["tgt"])
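With the src_ex_vocab entry added above, each example now carries its own source-side Vocab instead of that vocab living only in a dataset-level src_vocabs list. A rough, self-contained miniature of what _dynamic_dict builds per example, assuming a torchtext release from the same era (the sentence and variable layout are only illustrative):

from collections import Counter
import torch
from torchtext.vocab import Vocab

src = "we went to the park and then went home".split()
# tiny per-example vocab over the source tokens, plus the copy-attention map
src_ex_vocab = Vocab(Counter(src), specials=["<unk>", "<blank>"])
src_map = torch.LongTensor([src_ex_vocab.stoi[w] for w in src])
example = {"src_map": src_map, "src_ex_vocab": src_ex_vocab}
print(len(src_ex_vocab), example["src_map"].tolist())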
72 changes: 46 additions & 26 deletions onmt/inputters/inputter.py
@@ -9,7 +9,7 @@

import torch
import torchtext.data
from torchtext.data import Field
from torchtext.data import Field, RawField
from torchtext.vocab import Vocab
from torchtext.data.utils import RandomShuffler

@@ -126,6 +126,9 @@ def get_fields(
postprocessing=make_src, sequential=False)
fields["src_map"] = src_map

src_ex_vocab = RawField()
fields["src_ex_vocab"] = src_ex_vocab

align = Field(
use_vocab=False, dtype=torch.long,
postprocessing=make_tgt, sequential=False)
@@ -517,12 +520,12 @@ def _pool(data, batch_size, batch_size_fn, batch_size_multiple,
for p in torchtext.data.batch(
data, batch_size * pool_factor,
batch_size_fn=batch_size_fn):
p_batch = batch_iter(
p_batch = list(batch_iter(
sorted(p, key=sort_key),
batch_size,
batch_size_fn=batch_size_fn,
batch_size_multiple=batch_size_multiple)
for b in random_shuffler(list(p_batch)):
batch_size_multiple=batch_size_multiple))
for b in random_shuffler(p_batch):
yield b


@@ -606,18 +609,22 @@ class MultipleDatasetIterator(object):
respective weights, and yields a batch in the wanted proportions.
"""
def __init__(self,
iterables,
train_shards,
fields,
device,
opt):
self.index = -1
self.iterators = [iter(iterable) for iterable in iterables]
self.iterables = iterables
self.iterables = []
for shard in train_shards:
self.iterables.append(
build_dataset_iter(shard, fields, opt, multi=True))
self.init_iterators = True
self.weights = opt.data_weights
self.batch_size = opt.batch_size
self.batch_size_fn = max_tok_len \
if opt.batch_type == "tokens" else None
self.batch_size_multiple = 8 if opt.model_dtype == "fp16" else 1
self.device = "cuda" if device >= 0 else "cpu"
self.device = device
# Temporarily load one shard to retrieve sort_key for data_type
temp_dataset = torch.load(self.iterables[0]._paths[0])
self.sort_key = temp_dataset.sort_key
@@ -626,6 +633,9 @@ def __init__(self,
del temp_dataset

def _iter_datasets(self):
if self.init_iterators:
self.iterators = [iter(iterable) for iterable in self.iterables]
self.init_iterators = False
for weight in self.weights:
self.index = (self.index + 1) % len(self.iterators)
for i in range(weight):
@@ -636,18 +646,19 @@ def _iter_examples(self):
yield next(iterator)

def __iter__(self):
for minibatch in _pool(
self._iter_examples(),
self.batch_size,
self.batch_size_fn,
self.batch_size_multiple,
self.sort_key,
self.random_shuffler,
self.pool_factor):
minibatch = sorted(minibatch, key=self.sort_key, reverse=True)
yield torchtext.data.Batch(minibatch,
self.iterables[0].dataset,
self.device)
while True:
for minibatch in _pool(
self._iter_examples(),
self.batch_size,
self.batch_size_fn,
self.batch_size_multiple,
self.sort_key,
self.random_shuffler,
self.pool_factor):
minibatch = sorted(minibatch, key=self.sort_key, reverse=True)
yield torchtext.data.Batch(minibatch,
self.iterables[0].dataset,
self.device)


class DatasetLazyIter(object):
@@ -679,9 +690,9 @@ def __init__(self, dataset_paths, fields, batch_size, batch_size_fn,
self.pool_factor = pool_factor

def _iter_dataset(self, path):
logger.info('Loading dataset from %s' % path)
cur_dataset = torch.load(path)
logger.info('Loading dataset from %s, number of examples: %d' %
(path, len(cur_dataset)))
logger.info('number of examples: %d' % len(cur_dataset))
cur_dataset.fields = self.fields
cur_iter = OrderedIterator(
dataset=cur_dataset,
@@ -700,10 +711,12 @@ def _iter_dataset(self, path):
self.dataset = cur_iter.dataset
yield batch

cur_dataset.examples = None
gc.collect()
del cur_dataset
gc.collect()
# NOTE: This is causing some issues for consumer/producer,
# as we may still have some of those examples in some queue
# cur_dataset.examples = None
# gc.collect()
# del cur_dataset
# gc.collect()

def __iter__(self):
num_batches = 0
@@ -758,6 +771,8 @@ def build_dataset_iter(corpus_type, fields, opt, is_train=True, multi=False):
"""
dataset_paths = list(sorted(
glob.glob(opt.data + '.' + corpus_type + '.[0-9]*.pt')))
assert dataset_paths != [], \
"Check data %s - %s" % (opt.data, corpus_type)
if not dataset_paths:
return None
if multi:
@@ -784,3 +799,8 @@ def build_dataset_iter(corpus_type, fields, opt, is_train=True, multi=False):
repeat=not opt.single_pass,
num_batches_multiple=max(opt.accum_count) * opt.world_size,
yield_raw_example=multi)


def build_dataset_iter_multiple(train_shards, fields, opt):
return MultipleDatasetIterator(
train_shards, fields, "cuda" if opt.gpu_ranks else "cpu", opt)
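To make the weighted mixing in _iter_datasets/_iter_examples concrete, here is a simplified, self-contained stand-in (not the class itself): with opt.data_weights of [2, 1], two examples are drawn from the first shard for every one drawn from the second before the usual pooling and batching.

from itertools import cycle, islice

def weighted_examples(iterators, weights):
    # Cycle over the dataset iterators, drawing `weight` examples from each
    # before moving on to the next one.
    for it, weight in cycle(zip(iterators, weights)):
        for _ in range(weight):
            yield next(it)

a = iter(["a%d" % i for i in range(6)])
b = iter(["b%d" % i for i in range(3)])
print(list(islice(weighted_examples([a, b], [2, 1]), 9)))
# ['a0', 'a1', 'b0', 'a2', 'a3', 'b1', 'a4', 'a5', 'b2']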
15 changes: 10 additions & 5 deletions onmt/modules/copy_generator.py
@@ -5,7 +5,7 @@
from onmt.utils.loss import LossComputeBase


def collapse_copy_scores(scores, batch, tgt_vocab, src_vocabs,
def collapse_copy_scores(scores, batch, tgt_vocab, src_vocabs=None,
batch_dim=1, batch_offset=None):
"""
Given scores from an expanded dictionary
@@ -16,9 +16,14 @@ def collapse_copy_scores(scores, batch, tgt_vocab, src_vocabs,
for b in range(scores.size(batch_dim)):
blank = []
fill = []
batch_id = batch_offset[b] if batch_offset is not None else b
index = batch.indices.data[batch_id]
src_vocab = src_vocabs[index]

if src_vocabs is None:
src_vocab = batch.src_ex_vocab[b]
else:
batch_id = batch_offset[b] if batch_offset is not None else b
index = batch.indices.data[batch_id]
src_vocab = src_vocabs[index]

for i in range(1, len(src_vocab)):
sw = src_vocab.itos[i]
ti = tgt_vocab.stoi[sw]
@@ -216,7 +221,7 @@ def _compute_loss(self, batch, output, target, copy_attn, align):
# and is used only for stats
scores_data = collapse_copy_scores(
self._unbottle(scores.clone(), batch.batch_size),
batch, self.tgt_vocab, batch.dataset.src_vocabs)
batch, self.tgt_vocab, None)
scores_data = self._bottle(scores_data)

# this block does not depend on the loss value computed above
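The net effect is that collapse_copy_scores no longer needs the dataset-level src_vocabs list: when src_vocabs is None, the per-example vocab is read from the batch itself. A small, self-contained miniature of just that lookup branch, where FakeBatch and the string "vocabs" stand in for the real Batch and Vocab objects:

class FakeBatch:
    def __init__(self, src_ex_vocab):
        # one per-example vocab object, in batch order
        self.src_ex_vocab = src_ex_vocab

def pick_src_vocab(batch, b, src_vocabs=None, batch_offset=None):
    if src_vocabs is None:
        return batch.src_ex_vocab[b]
    batch_id = batch_offset[b] if batch_offset is not None else b
    index = batch.indices.data[batch_id]
    return src_vocabs[index]

batch = FakeBatch(["vocab_for_example_0", "vocab_for_example_1"])
print(pick_src_vocab(batch, 1))  # -> 'vocab_for_example_1'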
2 changes: 2 additions & 0 deletions onmt/opts.py
@@ -343,6 +343,8 @@ def train_opts(parser):
help="IP of master for torch.distributed training.")
group.add('--master_port', '-master_port', default=10000, type=int,
help="Port of master for torch.distributed training.")
group.add('--queue_size', '-queue_size', default=400, type=int,
help="Size of queue for each process in producer/consumer")

group.add('--seed', '-seed', type=int, default=-1,
help="Random seed used for the experiments "
34 changes: 23 additions & 11 deletions onmt/train_single.py
@@ -5,7 +5,7 @@
import torch

from onmt.inputters.inputter import build_dataset_iter, \
load_old_vocab, old_style_vocab, MultipleDatasetIterator
load_old_vocab, old_style_vocab, build_dataset_iter_multiple
from onmt.model_builder import build_model
from onmt.utils.optimizers import Optimizer
from onmt.utils.misc import set_random_seed
@@ -39,7 +39,7 @@ def configure_process(opt, device_id):
set_random_seed(opt.seed, device_id >= 0)


def main(opt, device_id):
def main(opt, device_id, batch_queue=None, semaphore=None):
# NOTE: It's important that ``opt`` has been validated and updated
# at this point.
configure_process(opt, device_id)
@@ -51,7 +51,6 @@ def main(opt, device_id):
logger.info('Loading checkpoint from %s' % opt.train_from)
checkpoint = torch.load(opt.train_from,
map_location=lambda storage, loc: storage)

model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"])
ArgumentParser.update_model_opts(model_opt)
ArgumentParser.validate_model_opts(model_opt)
@@ -98,15 +97,27 @@ def main(opt, device_id):
trainer = build_trainer(
opt, device_id, model, fields, optim, model_saver=model_saver)

train_iterables = []
if len(opt.data_ids) > 1:
for train_id in opt.data_ids:
shard_base = "train_" + train_id
iterable = build_dataset_iter(shard_base, fields, opt, multi=True)
train_iterables.append(iterable)
train_iter = MultipleDatasetIterator(train_iterables, device_id, opt)
if batch_queue is None:
if len(opt.data_ids) > 1:
train_shards = []
for train_id in opt.data_ids:
shard_base = "train_" + train_id
train_shards.append(shard_base)
train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
else:
train_iter = build_dataset_iter("train", fields, opt)

else:
train_iter = build_dataset_iter("train", fields, opt)
assert semaphore is not None, \
"Using batch_queue requires semaphore as well"

def _train_iter():
while True:
batch = batch_queue.get()
semaphore.release()
yield batch

train_iter = _train_iter()

valid_iter = build_dataset_iter(
"valid", fields, opt, is_train=False)
@@ -119,6 +130,7 @@ def main(opt, device_id):
if opt.single_pass and train_steps > 0:
logger.warning("Option single_pass is enabled, ignoring train_steps.")
train_steps = 0

trainer.train(
train_iter,
train_steps,
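main() now accepts the batch_queue and semaphore it should consume from; the code that spawns one such process per GPU and runs the producer presumably lives in train.py, which is not among the diffs shown here. A hypothetical sketch of that wiring — spawn_trainers, the daemon flag, and the world_size * queue_size semaphore value are assumptions, not the commit's actual code:

import torch.multiprocessing as mp
from onmt.train_single import main as single_main

def spawn_trainers(opt, world_size, queue_size):
    ctx = mp.get_context("spawn")
    # Shared semaphore: the producer acquires a slot per batch it enqueues,
    # and each consumer releases a slot after every batch_queue.get().
    semaphore = ctx.Semaphore(world_size * queue_size)
    queues, procs = [], []
    for device_id in range(world_size):
        q = ctx.Queue(queue_size)
        p = ctx.Process(target=single_main,
                        args=(opt, device_id, q, semaphore),
                        daemon=True)
        p.start()
        queues.append(q)
        procs.append(p)
    return queues, procs, semaphore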
6 changes: 0 additions & 6 deletions onmt/trainer.py
@@ -10,7 +10,6 @@
"""

from copy import deepcopy
import itertools
import torch
import traceback

Expand Down Expand Up @@ -221,14 +220,9 @@ def train(self,
report_stats = onmt.utils.Statistics()
self._start_report_manager(start_time=total_stats.start_time)

if self.n_gpu > 1:
train_iter = itertools.islice(
train_iter, self.gpu_rank, None, self.n_gpu)

for i, (batches, normalization) in enumerate(
self._accum_batches(train_iter)):
step = self.optim.training_step

# UPDATE DROPOUT
self._maybe_update_dropout(step)

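For context on the deleted block: before this change each GPU process built the full iterator and kept every n_gpu-th batch via itertools.islice, whereas with per-process queues the producer already hands each trainer only its own batches. The striding the old code performed, shown in isolation:

import itertools

batches = range(8)  # stand-in for an iterator of training batches
n_gpu, gpu_rank = 2, 1
own = list(itertools.islice(iter(batches), gpu_rank, None, n_gpu))
print(own)  # [1, 3, 5, 7] -- rank 1 of 2 consumed every other batch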
3 changes: 2 additions & 1 deletion preprocess.py
@@ -107,7 +107,8 @@ def build_save_dataset(corpus_type, fields, src_reader, tgt_reader, opt):
f_iter, all_data):
has_vocab = (sub_n == 'src' and src_vocab) or \
(sub_n == 'tgt' and tgt_vocab)
if sub_f.sequential and not has_vocab:
if (hasattr(sub_f, 'sequential')
and sub_f.sequential and not has_vocab):
val = fd
counters[sub_n].update(val)
if maybe_id:
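The hasattr guard is needed because the new src_ex_vocab field is a RawField, and in the torchtext releases this code targets RawField does not define a `sequential` attribute, so the old check would raise AttributeError. A quick way to see this, assuming such a torchtext version is installed:

from torchtext.data import Field, RawField

print(hasattr(Field(), "sequential"))      # True
print(hasattr(RawField(), "sequential"))   # False for the releases in question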
[Diff for the eighth changed file did not load in this view.]
