From b731e04f3a8a842d18b0e475270c98b74c2198e3 Mon Sep 17 00:00:00 2001 From: Paul Tardy Date: Mon, 3 Jun 2019 20:08:59 +0200 Subject: [PATCH] Using Producer-Consumer for batches (#1450) * Working queues on multi-GPU on text and audio * Working quite well, even with dynamic_dict * Remove explicit garbage collect making some queue hang and other fixes * fix process not ending * properly set random seed and fill queues sequentially * make queues work with distributed training --- onmt/inputters/dataset_base.py | 1 + onmt/inputters/inputter.py | 72 ++++++++++++++++---------- onmt/modules/copy_generator.py | 15 ++++-- onmt/opts.py | 2 + onmt/train_single.py | 34 ++++++++---- onmt/trainer.py | 6 --- preprocess.py | 3 +- train.py | 95 ++++++++++++++++++++++++++++++++-- 8 files changed, 175 insertions(+), 53 deletions(-) diff --git a/onmt/inputters/dataset_base.py b/onmt/inputters/dataset_base.py index c561f6af84..5e895edd10 100644 --- a/onmt/inputters/dataset_base.py +++ b/onmt/inputters/dataset_base.py @@ -50,6 +50,7 @@ def _dynamic_dict(example, src_field, tgt_field): # Map source tokens to indices in the dynamic dict. src_map = torch.LongTensor([src_ex_vocab.stoi[w] for w in src]) example["src_map"] = src_map + example["src_ex_vocab"] = src_ex_vocab if "tgt" in example: tgt = tgt_field.tokenize(example["tgt"]) diff --git a/onmt/inputters/inputter.py b/onmt/inputters/inputter.py index 77efd49f5a..1513a581f9 100644 --- a/onmt/inputters/inputter.py +++ b/onmt/inputters/inputter.py @@ -9,7 +9,7 @@ import torch import torchtext.data -from torchtext.data import Field +from torchtext.data import Field, RawField from torchtext.vocab import Vocab from torchtext.data.utils import RandomShuffler @@ -126,6 +126,9 @@ def get_fields( postprocessing=make_src, sequential=False) fields["src_map"] = src_map + src_ex_vocab = RawField() + fields["src_ex_vocab"] = src_ex_vocab + align = Field( use_vocab=False, dtype=torch.long, postprocessing=make_tgt, sequential=False) @@ -517,12 +520,12 @@ def _pool(data, batch_size, batch_size_fn, batch_size_multiple, for p in torchtext.data.batch( data, batch_size * pool_factor, batch_size_fn=batch_size_fn): - p_batch = batch_iter( + p_batch = list(batch_iter( sorted(p, key=sort_key), batch_size, batch_size_fn=batch_size_fn, - batch_size_multiple=batch_size_multiple) - for b in random_shuffler(list(p_batch)): + batch_size_multiple=batch_size_multiple)) + for b in random_shuffler(p_batch): yield b @@ -606,18 +609,22 @@ class MultipleDatasetIterator(object): respective weights, and yields a batch in the wanted proportions. 
""" def __init__(self, - iterables, + train_shards, + fields, device, opt): self.index = -1 - self.iterators = [iter(iterable) for iterable in iterables] - self.iterables = iterables + self.iterables = [] + for shard in train_shards: + self.iterables.append( + build_dataset_iter(shard, fields, opt, multi=True)) + self.init_iterators = True self.weights = opt.data_weights self.batch_size = opt.batch_size self.batch_size_fn = max_tok_len \ if opt.batch_type == "tokens" else None self.batch_size_multiple = 8 if opt.model_dtype == "fp16" else 1 - self.device = "cuda" if device >= 0 else "cpu" + self.device = device # Temporarily load one shard to retrieve sort_key for data_type temp_dataset = torch.load(self.iterables[0]._paths[0]) self.sort_key = temp_dataset.sort_key @@ -626,6 +633,9 @@ def __init__(self, del temp_dataset def _iter_datasets(self): + if self.init_iterators: + self.iterators = [iter(iterable) for iterable in self.iterables] + self.init_iterators = False for weight in self.weights: self.index = (self.index + 1) % len(self.iterators) for i in range(weight): @@ -636,18 +646,19 @@ def _iter_examples(self): yield next(iterator) def __iter__(self): - for minibatch in _pool( - self._iter_examples(), - self.batch_size, - self.batch_size_fn, - self.batch_size_multiple, - self.sort_key, - self.random_shuffler, - self.pool_factor): - minibatch = sorted(minibatch, key=self.sort_key, reverse=True) - yield torchtext.data.Batch(minibatch, - self.iterables[0].dataset, - self.device) + while True: + for minibatch in _pool( + self._iter_examples(), + self.batch_size, + self.batch_size_fn, + self.batch_size_multiple, + self.sort_key, + self.random_shuffler, + self.pool_factor): + minibatch = sorted(minibatch, key=self.sort_key, reverse=True) + yield torchtext.data.Batch(minibatch, + self.iterables[0].dataset, + self.device) class DatasetLazyIter(object): @@ -679,9 +690,9 @@ def __init__(self, dataset_paths, fields, batch_size, batch_size_fn, self.pool_factor = pool_factor def _iter_dataset(self, path): + logger.info('Loading dataset from %s' % path) cur_dataset = torch.load(path) - logger.info('Loading dataset from %s, number of examples: %d' % - (path, len(cur_dataset))) + logger.info('number of examples: %d' % len(cur_dataset)) cur_dataset.fields = self.fields cur_iter = OrderedIterator( dataset=cur_dataset, @@ -700,10 +711,12 @@ def _iter_dataset(self, path): self.dataset = cur_iter.dataset yield batch - cur_dataset.examples = None - gc.collect() - del cur_dataset - gc.collect() + # NOTE: This is causing some issues for consumer/producer, + # as we may still have some of those examples in some queue + # cur_dataset.examples = None + # gc.collect() + # del cur_dataset + # gc.collect() def __iter__(self): num_batches = 0 @@ -758,6 +771,8 @@ def build_dataset_iter(corpus_type, fields, opt, is_train=True, multi=False): """ dataset_paths = list(sorted( glob.glob(opt.data + '.' 
+ corpus_type + '.[0-9]*.pt'))) + assert dataset_paths != [], \ + "Check data %s - %s" % (opt.data, corpus_type) if not dataset_paths: return None if multi: @@ -784,3 +799,8 @@ def build_dataset_iter(corpus_type, fields, opt, is_train=True, multi=False): repeat=not opt.single_pass, num_batches_multiple=max(opt.accum_count) * opt.world_size, yield_raw_example=multi) + + +def build_dataset_iter_multiple(train_shards, fields, opt): + return MultipleDatasetIterator( + train_shards, fields, "cuda" if opt.gpu_ranks else "cpu", opt) diff --git a/onmt/modules/copy_generator.py b/onmt/modules/copy_generator.py index 8f8e687ba0..e24e57ab67 100644 --- a/onmt/modules/copy_generator.py +++ b/onmt/modules/copy_generator.py @@ -5,7 +5,7 @@ from onmt.utils.loss import LossComputeBase -def collapse_copy_scores(scores, batch, tgt_vocab, src_vocabs, +def collapse_copy_scores(scores, batch, tgt_vocab, src_vocabs=None, batch_dim=1, batch_offset=None): """ Given scores from an expanded dictionary @@ -16,9 +16,14 @@ def collapse_copy_scores(scores, batch, tgt_vocab, src_vocabs, for b in range(scores.size(batch_dim)): blank = [] fill = [] - batch_id = batch_offset[b] if batch_offset is not None else b - index = batch.indices.data[batch_id] - src_vocab = src_vocabs[index] + + if src_vocabs is None: + src_vocab = batch.src_ex_vocab[b] + else: + batch_id = batch_offset[b] if batch_offset is not None else b + index = batch.indices.data[batch_id] + src_vocab = src_vocabs[index] + for i in range(1, len(src_vocab)): sw = src_vocab.itos[i] ti = tgt_vocab.stoi[sw] @@ -216,7 +221,7 @@ def _compute_loss(self, batch, output, target, copy_attn, align): # and is used only for stats scores_data = collapse_copy_scores( self._unbottle(scores.clone(), batch.batch_size), - batch, self.tgt_vocab, batch.dataset.src_vocabs) + batch, self.tgt_vocab, None) scores_data = self._bottle(scores_data) # this block does not depend on the loss value computed above diff --git a/onmt/opts.py b/onmt/opts.py index f763e4f254..4fe4928d66 100644 --- a/onmt/opts.py +++ b/onmt/opts.py @@ -343,6 +343,8 @@ def train_opts(parser): help="IP of master for torch.distributed training.") group.add('--master_port', '-master_port', default=10000, type=int, help="Port of master for torch.distributed training.") + group.add('--queue_size', '-queue_size', default=400, type=int, + help="Size of queue for each process in producer/consumer") group.add('--seed', '-seed', type=int, default=-1, help="Random seed used for the experiments " diff --git a/onmt/train_single.py b/onmt/train_single.py index e55c89955a..e65002b9e9 100755 --- a/onmt/train_single.py +++ b/onmt/train_single.py @@ -5,7 +5,7 @@ import torch from onmt.inputters.inputter import build_dataset_iter, \ - load_old_vocab, old_style_vocab, MultipleDatasetIterator + load_old_vocab, old_style_vocab, build_dataset_iter_multiple from onmt.model_builder import build_model from onmt.utils.optimizers import Optimizer from onmt.utils.misc import set_random_seed @@ -39,7 +39,7 @@ def configure_process(opt, device_id): set_random_seed(opt.seed, device_id >= 0) -def main(opt, device_id): +def main(opt, device_id, batch_queue=None, semaphore=None): # NOTE: It's important that ``opt`` has been validated and updated # at this point. 
configure_process(opt, device_id) @@ -51,7 +51,6 @@ def main(opt, device_id): logger.info('Loading checkpoint from %s' % opt.train_from) checkpoint = torch.load(opt.train_from, map_location=lambda storage, loc: storage) - model_opt = ArgumentParser.ckpt_model_opts(checkpoint["opt"]) ArgumentParser.update_model_opts(model_opt) ArgumentParser.validate_model_opts(model_opt) @@ -98,15 +97,27 @@ def main(opt, device_id): trainer = build_trainer( opt, device_id, model, fields, optim, model_saver=model_saver) - train_iterables = [] - if len(opt.data_ids) > 1: - for train_id in opt.data_ids: - shard_base = "train_" + train_id - iterable = build_dataset_iter(shard_base, fields, opt, multi=True) - train_iterables.append(iterable) - train_iter = MultipleDatasetIterator(train_iterables, device_id, opt) + if batch_queue is None: + if len(opt.data_ids) > 1: + train_shards = [] + for train_id in opt.data_ids: + shard_base = "train_" + train_id + train_shards.append(shard_base) + train_iter = build_dataset_iter_multiple(train_shards, fields, opt) + else: + train_iter = build_dataset_iter("train", fields, opt) + else: - train_iter = build_dataset_iter("train", fields, opt) + assert semaphore is not None, \ + "Using batch_queue requires semaphore as well" + + def _train_iter(): + while True: + batch = batch_queue.get() + semaphore.release() + yield batch + + train_iter = _train_iter() valid_iter = build_dataset_iter( "valid", fields, opt, is_train=False) @@ -119,6 +130,7 @@ def main(opt, device_id): if opt.single_pass and train_steps > 0: logger.warning("Option single_pass is enabled, ignoring train_steps.") train_steps = 0 + trainer.train( train_iter, train_steps, diff --git a/onmt/trainer.py b/onmt/trainer.py index e990887755..2747d31622 100644 --- a/onmt/trainer.py +++ b/onmt/trainer.py @@ -10,7 +10,6 @@ """ from copy import deepcopy -import itertools import torch import traceback @@ -221,14 +220,9 @@ def train(self, report_stats = onmt.utils.Statistics() self._start_report_manager(start_time=total_stats.start_time) - if self.n_gpu > 1: - train_iter = itertools.islice( - train_iter, self.gpu_rank, None, self.n_gpu) - for i, (batches, normalization) in enumerate( self._accum_batches(train_iter)): step = self.optim.training_step - # UPDATE DROPOUT self._maybe_update_dropout(step) diff --git a/preprocess.py b/preprocess.py index 76a8b7b104..66ddf61789 100755 --- a/preprocess.py +++ b/preprocess.py @@ -107,7 +107,8 @@ def build_save_dataset(corpus_type, fields, src_reader, tgt_reader, opt): f_iter, all_data): has_vocab = (sub_n == 'src' and src_vocab) or \ (sub_n == 'tgt' and tgt_vocab) - if sub_f.sequential and not has_vocab: + if (hasattr(sub_f, 'sequential') + and sub_f.sequential and not has_vocab): val = fd counters[sub_n].update(val) if maybe_id: diff --git a/train.py b/train.py index 542b996926..d159deb2af 100755 --- a/train.py +++ b/train.py @@ -7,9 +7,14 @@ import onmt.opts as opts import onmt.utils.distributed -from onmt.utils.logging import logger +from onmt.utils.misc import set_random_seed +from onmt.utils.logging import init_logger, logger from onmt.train_single import main as single_main from onmt.utils.parse import ArgumentParser +from onmt.inputters.inputter import build_dataset_iter, \ + load_old_vocab, old_style_vocab, build_dataset_iter_multiple + +from itertools import cycle def main(opt): @@ -17,23 +22,61 @@ def main(opt): ArgumentParser.update_model_opts(opt) ArgumentParser.validate_model_opts(opt) + # Load checkpoint if we resume from a previous training. 
+ if opt.train_from: + logger.info('Loading checkpoint from %s' % opt.train_from) + checkpoint = torch.load(opt.train_from, + map_location=lambda storage, loc: storage) + logger.info('Loading vocab from checkpoint at %s.' % opt.train_from) + vocab = checkpoint['vocab'] + else: + vocab = torch.load(opt.data + '.vocab.pt') + + # check for code where vocab is saved instead of fields + # (in the future this will be done in a smarter way) + if old_style_vocab(vocab): + fields = load_old_vocab( + vocab, opt.model_type, dynamic_dict=opt.copy_attn) + else: + fields = vocab + + if len(opt.data_ids) > 1: + train_shards = [] + for train_id in opt.data_ids: + shard_base = "train_" + train_id + train_shards.append(shard_base) + train_iter = build_dataset_iter_multiple(train_shards, fields, opt) + else: + train_iter = build_dataset_iter("train", fields, opt) + nb_gpu = len(opt.gpu_ranks) if opt.world_size > 1: + queues = [] mp = torch.multiprocessing.get_context('spawn') + semaphore = mp.Semaphore(opt.world_size * opt.queue_size) # Create a thread to listen for errors in the child processes. error_queue = mp.SimpleQueue() error_handler = ErrorHandler(error_queue) # Train with multiprocessing. procs = [] for device_id in range(nb_gpu): + q = mp.Queue(opt.queue_size) + queues += [q] procs.append(mp.Process(target=run, args=( - opt, device_id, error_queue, ), daemon=True)) + opt, device_id, error_queue, q, semaphore), daemon=True)) procs[device_id].start() logger.info(" Starting process pid: %d " % procs[device_id].pid) error_handler.add_child(procs[device_id].pid) + producer = mp.Process(target=batch_producer, + args=(train_iter, queues, semaphore, opt,), + daemon=True) + producer.start() + error_handler.add_child(producer.pid) + for p in procs: p.join() + producer.terminate() elif nb_gpu == 1: # case 1 GPU only single_main(opt, 0) @@ -41,14 +84,58 @@ def main(opt): single_main(opt, -1) -def run(opt, device_id, error_queue): +def batch_producer(generator_to_serve, queues, semaphore, opt): + init_logger(opt.log_file) + set_random_seed(opt.seed, False) + # generator_to_serve = iter(generator_to_serve) + + def pred(x): + """ + Filters batches that belong only + to gpu_ranks of current node + """ + for rank in opt.gpu_ranks: + if x[0] % opt.world_size == rank: + return True + + generator_to_serve = filter( + pred, enumerate(generator_to_serve)) + + def next_batch(device_id): + new_batch = next(generator_to_serve) + semaphore.acquire() + return new_batch[1] + + b = next_batch(0) + + for device_id, q in cycle(enumerate(queues)): + b.dataset = None + if isinstance(b.src, tuple): + b.src = tuple([_.to(torch.device(device_id)) + for _ in b.src]) + else: + b.src = b.src.to(torch.device(device_id)) + b.tgt = b.tgt.to(torch.device(device_id)) + b.indices = b.indices.to(torch.device(device_id)) + b.alignment = b.alignment.to(torch.device(device_id)) \ + if hasattr(b, 'alignment') else None + b.src_map = b.src_map.to(torch.device(device_id)) \ + if hasattr(b, 'src_map') else None + + # hack to dodge unpicklable `dict_keys` + b.fields = list(b.fields) + q.put(b, False) + b = next_batch(device_id) + + +def run(opt, device_id, error_queue, batch_queue, semaphore): """ run process """ try: gpu_rank = onmt.utils.distributed.multi_init(opt, device_id) if gpu_rank != opt.gpu_ranks[device_id]: raise AssertionError("An error occurred in \ Distributed initialization") - single_main(opt, device_id) + single_main(opt, device_id, batch_queue, semaphore) except KeyboardInterrupt: pass # killed by parent, do nothing except Exception:
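The pattern the patch implements, stripped of the OpenNMT-py specifics, is a single producer process that iterates over the batch generator and round-robins batches into one bounded queue per GPU, with a shared semaphore capping the total number of batches in flight; each trainer process then simply pulls from its own queue and releases a semaphore slot per batch. The sketch below is illustrative only, not the PR's code: the names make_batches, produce, and consume are hypothetical stand-ins, and it assumes the yielded batch objects are picklable.

import itertools
import torch.multiprocessing as mp


def make_batches():
    # Stand-in for the real batch iterator (DatasetLazyIter /
    # MultipleDatasetIterator in the patch). Must yield picklable items.
    for i in itertools.count():
        yield {"batch_id": i}


def produce(queues, semaphore, world_size):
    # Single producer: route batch i to GPU (i % world_size), blocking on
    # the semaphore so at most world_size * queue_size batches are in
    # flight at any time.
    for i, batch in enumerate(make_batches()):
        semaphore.acquire()
        queues[i % world_size].put(batch)


def consume(device_id, queue, semaphore, steps):
    # One consumer per GPU: take a batch, free the producer's semaphore
    # slot, then run the training step on it.
    for _ in range(steps):
        batch = queue.get()
        semaphore.release()
        # the real training step would go here
        print("gpu %d got batch %d" % (device_id, batch["batch_id"]))


if __name__ == "__main__":
    world_size, queue_size, steps = 2, 4, 10
    ctx = mp.get_context("spawn")
    semaphore = ctx.Semaphore(world_size * queue_size)
    queues = [ctx.Queue(queue_size) for _ in range(world_size)]

    consumers = [
        ctx.Process(target=consume, args=(d, queues[d], semaphore, steps),
                    daemon=True)
        for d in range(world_size)
    ]
    for p in consumers:
        p.start()

    producer = ctx.Process(target=produce,
                           args=(queues, semaphore, world_size), daemon=True)
    producer.start()

    for p in consumers:
        p.join()
    producer.terminate()

In the patch itself the producer (batch_producer in train.py) additionally moves each batch's tensors to the target device before queueing, and the consumer side is wired into single_main via the new batch_queue/semaphore arguments rather than a standalone consume function.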