From b8a82e12128037173b3dcef4c1d2eb511ab60437 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?=
Date: Thu, 2 May 2019 17:58:33 +0200
Subject: [PATCH] handle multiple training corpora and enable weighting

---
 onmt/inputters/inputter.py |  98 ++++++++++++++-
 onmt/opts.py               |  16 ++-
 onmt/train_single.py       |  22 +++-
 onmt/utils/parse.py        |  13 +-
 preprocess.py              | 244 ++++++++++++++++++++-----------------
 5 files changed, 262 insertions(+), 131 deletions(-)

diff --git a/onmt/inputters/inputter.py b/onmt/inputters/inputter.py
index 38287ae6da..cf77ed263d 100644
--- a/onmt/inputters/inputter.py
+++ b/onmt/inputters/inputter.py
@@ -501,9 +501,12 @@ def __init__(self,
                  dataset,
                  batch_size,
                  batch_size_multiple=1,
+                 yield_raw_example=False,
                  **kwargs):
         super(OrderedIterator, self).__init__(dataset, batch_size, **kwargs)
         self.batch_size_multiple = batch_size_multiple
+        self.yield_raw_example = yield_raw_example
+        self.dataset = dataset
 
     def create_batches(self):
         if self.train:
@@ -527,6 +530,80 @@ def _pool(data, random_shuffler):
                     batch_size_multiple=self.batch_size_multiple):
                 self.batches.append(sorted(b, key=self.sort_key))
 
+    def __iter__(self):
+        """
+        Extended version of the definition in torchtext.data.Iterator.
+        Added yield_raw_example behaviour to yield a torchtext.data.Example
+        instead of a torchtext.data.Batch object.
+        """
+        while True:
+            self.init_epoch()
+            for idx, minibatch in enumerate(self.batches):
+                # fast-forward if loaded from state
+                if self._iterations_this_epoch > idx:
+                    continue
+                self.iterations += 1
+                self._iterations_this_epoch += 1
+                if self.sort_within_batch:
+                    # NOTE: `rnn.pack_padded_sequence` requires that a
+                    # minibatch be sorted by decreasing order, which
+                    # requires reversing relative to typical sort keys
+                    if self.sort:
+                        minibatch.reverse()
+                    else:
+                        minibatch.sort(key=self.sort_key, reverse=True)
+                if self.yield_raw_example:
+                    yield minibatch[0]
+                else:
+                    yield torchtext.data.Batch(
+                        minibatch,
+                        self.dataset,
+                        self.device)
+            if not self.repeat:
+                return
+
+
+class MultipleDatasetIterator(object):
+    """
+    This takes a list of iterable objects (DatasetLazyIter) and their
+    respective weights, and yields a batch in the wanted proportions.
+    """
+    def __init__(self,
+                 iterables,
+                 weights,
+                 batch_size,
+                 batch_size_fn,
+                 batch_size_multiple,
+                 device):
+        self.index = -1
+        self.iterators = [iter(iterable) for iterable in iterables]
+        self.iterables = iterables
+        self.weights = weights
+        self.batch_size = batch_size
+        self.batch_size_fn = batch_size_fn
+        self.batch_size_multiple = batch_size_multiple
+        self.device = device
+
+    def _iter_datasets(self):
+        for weight in self.weights:
+            self.index = (self.index + 1) % len(self.iterators)
+            for i in range(weight):
+                yield self.iterators[self.index]
+
+    def _iter_examples(self):
+        for iterator in cycle(self._iter_datasets()):
+            yield next(iterator)
+
+    def __iter__(self):
+        for minibatch in batch_iter(
+                self._iter_examples(),
+                self.batch_size,
+                batch_size_fn=self.batch_size_fn,
+                batch_size_multiple=self.batch_size_multiple):
+            dataset = self.iterables[0].dataset
+            minibatch = sorted(minibatch, key=dataset.sort_key, reverse=True)
+            yield torchtext.data.Batch(minibatch, dataset, self.device)
+
 
 class DatasetLazyIter(object):
     """Yield data from sharded dataset files.
@@ -543,7 +620,7 @@ class DatasetLazyIter(object):
 
     def __init__(self, dataset_paths, fields, batch_size, batch_size_fn,
                  batch_size_multiple, device, is_train, repeat=True,
-                 num_batches_multiple=1):
+                 num_batches_multiple=1, yield_raw_example=True):
         self._paths = dataset_paths
         self.fields = fields
         self.batch_size = batch_size
@@ -553,6 +630,7 @@ def __init__(self, dataset_paths, fields, batch_size, batch_size_fn,
         self.is_train = is_train
         self.repeat = repeat
         self.num_batches_multiple = num_batches_multiple
+        self.yield_raw_example = yield_raw_example
 
     def _iter_dataset(self, path):
         cur_dataset = torch.load(path)
@@ -568,9 +646,11 @@ def _iter_dataset(self, path):
             train=self.is_train,
             sort=False,
             sort_within_batch=True,
-            repeat=False
+            repeat=False,
+            yield_raw_example=self.yield_raw_example
         )
         for batch in cur_iter:
+            self.dataset = cur_iter.dataset
             yield batch
 
         cur_dataset.examples = None
@@ -623,7 +703,7 @@ def max_tok_len(new, count, sofar):
     return max(src_elements, tgt_elements)
 
 
-def build_dataset_iter(corpus_type, fields, opt, is_train=True):
+def build_dataset_iter(corpus_type, fields, opt, is_train=True, multi=False):
     """
     This returns user-defined train/validate data iterator for the trainer
    to iterate over. We implement simple ordered iterator strategy here,
@@ -633,9 +713,15 @@ def build_dataset_iter(corpus_type, fields, opt, is_train=True):
         glob.glob(opt.data + '.' + corpus_type + '*.pt')))
     if not dataset_paths:
         return None
-    batch_size = opt.batch_size if is_train else opt.valid_batch_size
-    batch_fn = max_tok_len if is_train and opt.batch_type == "tokens" else None
-    batch_size_multiple = 8 if opt.model_dtype == "fp16" else 1
+    if multi:
+        batch_size = 1
+        batch_fn = None
+        batch_size_multiple = 1
+    else:
+        batch_size = opt.batch_size if is_train else opt.valid_batch_size
+        batch_fn = max_tok_len if is_train and opt.batch_type == "tokens" \
+            else None
+        batch_size_multiple = 8 if opt.model_dtype == "fp16" else 1
 
     device = "cuda" if opt.gpu_ranks else "cpu"
 
diff --git a/onmt/opts.py b/onmt/opts.py
index a46a844def..58012c80e8 100644
--- a/onmt/opts.py
+++ b/onmt/opts.py
@@ -189,10 +189,12 @@ def preprocess_opts(parser):
               help="Type of the source input. "
                    "Options are [text|img|audio].")
 
-    group.add('--train_src', '-train_src', required=True,
-              help="Path to the training source data")
-    group.add('--train_tgt', '-train_tgt', required=True,
-              help="Path to the training target data")
+    group.add('--train_src', '-train_src', required=True, nargs='+',
+              help="Path(s) to the training source data")
+    group.add('--train_tgt', '-train_tgt', required=True, nargs='+',
+              help="Path(s) to the training target data")
+    group.add('--train_ids', '-train_ids', nargs='+', default=[None],
+              help="ids to name training shards, used for corpus weighting")
     group.add('--valid_src', '-valid_src',
               help="Path to the validation source data")
     group.add('--valid_tgt', '-valid_tgt',
@@ -308,6 +310,12 @@ def train_opts(parser):
               help='Path prefix to the ".train.pt" and '
                    '".valid.pt" file path from preprocess.py')
 
+    group.add('--data_ids', '-data_ids', nargs='+', default=[None],
+              help="In case there are several corpora.")
+    group.add('--data_weights', '-data_weights', type=int, nargs='+', default=[1],
+              help="""Weights of different corpora, should follow
+              the same order as in -data_ids.""")
+
     group.add('--save_model', '-save_model', default='model',
               help="Model filename (the model will be saved as "
                    "<save_model>_N.pt where N is the number "
diff --git a/onmt/train_single.py b/onmt/train_single.py
index 33f9a86f06..c7a091a663 100755
--- a/onmt/train_single.py
+++ b/onmt/train_single.py
@@ -4,8 +4,8 @@
 
 import torch
 
-from onmt.inputters.inputter import build_dataset_iter, \
-    load_old_vocab, old_style_vocab
+from onmt.inputters.inputter import build_dataset_iter, max_tok_len, \
+    load_old_vocab, old_style_vocab, MultipleDatasetIterator
 from onmt.model_builder import build_model
 from onmt.utils.optimizers import Optimizer
 from onmt.utils.misc import set_random_seed
@@ -98,7 +98,23 @@ def main(opt, device_id):
     trainer = build_trainer(
         opt, device_id, model, fields, optim, model_saver=model_saver)
 
-    train_iter = build_dataset_iter("train", fields, opt)
+    train_iterables = []
+    for train_id in opt.data_ids:
+        if train_id:
+            shard_base = "train_" + train_id
+        else:
+            shard_base = "train"
+        iterable = build_dataset_iter(shard_base, fields, opt, multi=True)
+        train_iterables.append(iterable)
+
+    train_iter = MultipleDatasetIterator(
+        train_iterables,
+        opt.data_weights,
+        opt.batch_size,
+        max_tok_len if opt.batch_type == "tokens" else None,
+        8 if opt.model_dtype == "fp16" else 1,
+        device_id)
+
     valid_iter = build_dataset_iter(
         "valid", fields, opt, is_train=False)
 
diff --git a/onmt/utils/parse.py b/onmt/utils/parse.py
index f0ad6911d2..dba523406f 100644
--- a/onmt/utils/parse.py
+++ b/onmt/utils/parse.py
@@ -99,6 +99,8 @@ def validate_train_opts(cls, opt):
             raise AssertionError(
                 "-gpu_ranks should have master(=0) rank "
                 "unless -world_size is greater than len(gpu_ranks).")
+        assert len(opt.data_ids) == len(opt.data_weights), \
+            "Please check -data_ids and -data_weights options!"
 
     @classmethod
     def validate_translate_opts(cls, opt):
@@ -114,9 +116,14 @@ def validate_preprocess_args(cls, opt):
             "-shuffle is not implemented. Please shuffle \
             your data before pre-processing."
 
-        assert os.path.isfile(opt.train_src) \
-            and os.path.isfile(opt.train_tgt), \
-            "Please check path of your train src and tgt files!"
+        assert len(opt.train_src) == len(opt.train_tgt), \
+            "Please provide same number of src and tgt train files!"
+
+        assert len(opt.train_src) == len(opt.train_ids), \
+            "Please provide proper -train_ids for your data!"
+
+        for file in opt.train_src + opt.train_tgt:
+            assert os.path.isfile(file), "Please check path of %s" % file
 
         assert not opt.valid_src or os.path.isfile(opt.valid_src), \
             "Please check path of your valid src file!"
diff --git a/preprocess.py b/preprocess.py
index e36c918dea..d0a109b9ef 100755
--- a/preprocess.py
+++ b/preprocess.py
@@ -36,127 +36,138 @@ def build_save_dataset(corpus_type, fields, src_reader, tgt_reader, opt):
     assert corpus_type in ['train', 'valid']
 
     if corpus_type == 'train':
-        src = opt.train_src
-        tgt = opt.train_tgt
+        srcs = opt.train_src
+        tgts = opt.train_tgt
+        ids = opt.train_ids
     else:
-        src = opt.valid_src
-        tgt = opt.valid_tgt
-
-    logger.info("Reading source and target files: %s %s." % (src, tgt))
-
-    src_shards = split_corpus(src, opt.shard_size)
-    tgt_shards = split_corpus(tgt, opt.shard_size)
-    shard_pairs = zip(src_shards, tgt_shards)
-    dataset_paths = []
-    if (corpus_type == "train" or opt.filter_valid) and tgt is not None:
-        filter_pred = partial(
-            inputters.filter_example, use_src_len=opt.data_type == "text",
-            max_src_len=opt.src_seq_length, max_tgt_len=opt.tgt_seq_length)
-    else:
-        filter_pred = None
-
-    if corpus_type == "train":
-        counters = defaultdict(Counter)
-        if opt.src_vocab:
-            try:
-                logger.info("Using existing vocabulary...")
-                vocab = torch.load(opt.src_vocab)
-                # return vocab to dump with standard name
-                return vocab
-            except torch.serialization.pickle.UnpicklingError:
-                logger.info("Building vocab from text file...")
-
-        # Load vocabulary
-        if opt.src_vocab:
-            src_vocab, src_vocab_size = _load_vocab(
-                opt.src_vocab, "src", counters)
+        srcs = [opt.valid_src]
+        tgts = [opt.valid_tgt]
+        ids = [None]
+
+    for src, tgt, maybe_id in zip(srcs, tgts, ids):
+        logger.info("Reading source and target files: %s %s." % (src, tgt))
+
+        src_shards = split_corpus(src, opt.shard_size)
+        tgt_shards = split_corpus(tgt, opt.shard_size)
+        shard_pairs = zip(src_shards, tgt_shards)
+        dataset_paths = []
+        if (corpus_type == "train" or opt.filter_valid) and tgt is not None:
+            filter_pred = partial(
+                inputters.filter_example, use_src_len=opt.data_type == "text",
+                max_src_len=opt.src_seq_length, max_tgt_len=opt.tgt_seq_length)
         else:
-            src_vocab = None
+            filter_pred = None
+
+        if corpus_type == "train":
+            counters = defaultdict(Counter)
+            if opt.src_vocab:
+                try:
+                    logger.info("Using existing vocabulary...")
+                    vocab = torch.load(opt.src_vocab)
+                    # return vocab to dump with standard name
+                    return vocab
+                except torch.serialization.pickle.UnpicklingError:
+                    logger.info("Building vocab from text file...")
+
+            # Load vocabulary
+            if opt.src_vocab:
+                src_vocab, src_vocab_size = _load_vocab(
+                    opt.src_vocab, "src", counters)
+            else:
+                src_vocab = None
+
+            if opt.tgt_vocab:
+                tgt_vocab, tgt_vocab_size = _load_vocab(
+                    opt.tgt_vocab, "tgt", counters)
+            else:
+                tgt_vocab = None
+
+        for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
+            assert len(src_shard) == len(tgt_shard)
+            logger.info("Building shard %d." % i)
+            dataset = inputters.Dataset(
+                fields,
+                readers=[src_reader, tgt_reader]
+                if tgt_reader else [src_reader],
+                data=([("src", src_shard), ("tgt", tgt_shard)]
+                      if tgt_reader else [("src", src_shard)]),
+                dirs=[opt.src_dir, None]
+                if tgt_reader else [opt.src_dir],
+                sort_key=inputters.str2sortkey[opt.data_type],
+                filter_pred=filter_pred
+            )
+            if corpus_type == "train":
+                for ex in dataset.examples:
+                    for name, field in fields.items():
+                        try:
+                            f_iter = iter(field)
+                        except TypeError:
+                            f_iter = [(name, field)]
+                            all_data = [getattr(ex, name, None)]
+                        else:
+                            all_data = getattr(ex, name)
+                        for (sub_n, sub_f), fd in zip(
+                                f_iter, all_data):
+                            has_vocab = (sub_n == 'src' and src_vocab) or \
+                                        (sub_n == 'tgt' and tgt_vocab)
+                            if sub_f.sequential and not has_vocab:
+                                val = fd
+                                counters[sub_n].update(val)
+            if maybe_id:
+                shard_base = corpus_type + "_" + maybe_id
+            else:
+                shard_base = corpus_type
+            data_path = "{:s}.{:s}.{:d}.pt".\
+                format(opt.save_data, shard_base, i)
+            dataset_paths.append(data_path)
+
+            logger.info(" * saving %sth %s data shard to %s."
+                        % (i, shard_base, data_path))
+
+            dataset.save(data_path)
+
+            del dataset.examples
+            gc.collect()
+            del dataset
+            gc.collect()
-        if opt.tgt_vocab:
-            tgt_vocab, tgt_vocab_size = _load_vocab(
-                opt.tgt_vocab, "tgt", counters)
-        else:
-            tgt_vocab = None
-
-    for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
-        assert len(src_shard) == len(tgt_shard)
-        logger.info("Building shard %d." % i)
-        dataset = inputters.Dataset(
-            fields,
-            readers=[src_reader, tgt_reader] if tgt_reader else [src_reader],
-            data=([("src", src_shard), ("tgt", tgt_shard)]
-                  if tgt_reader else [("src", src_shard)]),
-            dirs=[opt.src_dir, None] if tgt_reader else [opt.src_dir],
-            sort_key=inputters.str2sortkey[opt.data_type],
-            filter_pred=filter_pred
-        )
         if corpus_type == "train":
-            for ex in dataset.examples:
-                for name, field in fields.items():
-                    try:
-                        f_iter = iter(field)
-                    except TypeError:
-                        f_iter = [(name, field)]
-                        all_data = [getattr(ex, name, None)]
-                    else:
-                        all_data = getattr(ex, name)
-                    for (sub_n, sub_f), fd in zip(
-                            f_iter, all_data):
-                        has_vocab = (sub_n == 'src' and src_vocab) or \
-                                    (sub_n == 'tgt' and tgt_vocab)
-                        if sub_f.sequential and not has_vocab:
-                            val = fd
-                            counters[sub_n].update(val)
-
-        data_path = "{:s}.{:s}.{:d}.pt".format(opt.save_data, corpus_type, i)
-        dataset_paths.append(data_path)
-
-        logger.info(" * saving %sth %s data shard to %s."
-                    % (i, corpus_type, data_path))
-
-        dataset.save(data_path)
-
-        del dataset.examples
-        gc.collect()
-        del dataset
-        gc.collect()
-
-    if corpus_type == "train":
-        build_fv_args = defaultdict(dict)
-        build_fv_args["src"] = dict(
-            max_size=opt.src_vocab_size, min_freq=opt.src_words_min_frequency)
-        build_fv_args["tgt"] = dict(
-            max_size=opt.tgt_vocab_size, min_freq=opt.tgt_words_min_frequency)
-        tgt_multifield = fields["tgt"]
-        _build_fv_from_multifield(
-            tgt_multifield,
-            counters,
-            build_fv_args,
-            size_multiple=opt.vocab_size_multiple
-            if not opt.share_vocab else 1)
-        if opt.data_type == 'text':
-            src_multifield = fields["src"]
+            build_fv_args = defaultdict(dict)
+            build_fv_args["src"] = dict(
+                max_size=opt.src_vocab_size,
+                min_freq=opt.src_words_min_frequency)
+            build_fv_args["tgt"] = dict(
+                max_size=opt.tgt_vocab_size,
+                min_freq=opt.tgt_words_min_frequency)
+            tgt_multifield = fields["tgt"]
             _build_fv_from_multifield(
-                src_multifield,
+                tgt_multifield,
                 counters,
                 build_fv_args,
                 size_multiple=opt.vocab_size_multiple
                 if not opt.share_vocab else 1)
-            if opt.share_vocab:
-                # `tgt_vocab_size` is ignored when sharing vocabularies
-                logger.info(" * merging src and tgt vocab...")
-                src_field = src_multifield.base_field
-                tgt_field = tgt_multifield.base_field
-                _merge_field_vocabs(
-                    src_field, tgt_field, vocab_size=opt.src_vocab_size,
-                    min_freq=opt.src_words_min_frequency,
-                    vocab_size_multiple=opt.vocab_size_multiple)
-                logger.info(" * merged vocab size: %d."
-                            % len(src_field.vocab))
-
-        vocab_path = opt.save_data + '.vocab.pt'
-        torch.save(fields, vocab_path)
+            if opt.data_type == 'text':
+                src_multifield = fields["src"]
+                _build_fv_from_multifield(
+                    src_multifield,
+                    counters,
+                    build_fv_args,
+                    size_multiple=opt.vocab_size_multiple
+                    if not opt.share_vocab else 1)
+                if opt.share_vocab:
+                    # `tgt_vocab_size` is ignored when sharing vocabularies
+                    logger.info(" * merging src and tgt vocab...")
+                    src_field = src_multifield.base_field
+                    tgt_field = tgt_multifield.base_field
+                    _merge_field_vocabs(
+                        src_field, tgt_field, vocab_size=opt.src_vocab_size,
+                        min_freq=opt.src_words_min_frequency,
+                        vocab_size_multiple=opt.vocab_size_multiple)
+                    logger.info(" * merged vocab size: %d."
+                                % len(src_field.vocab))
+
+            vocab_path = opt.save_data + '.vocab.pt'
+            torch.save(fields, vocab_path)
 
 
 def build_save_vocab(train_dataset, fields, opt):
@@ -190,9 +201,12 @@ def main(opt):
     init_logger(opt.log_file)
     logger.info("Extracting features...")
 
-    src_nfeats = count_features(opt.train_src) if opt.data_type == 'text' \
-        else 0
-    tgt_nfeats = count_features(opt.train_tgt)  # tgt always text so far
+    src_nfeats = 0
+    tgt_nfeats = 0
+    for src, tgt in zip(opt.train_src, opt.train_tgt):
+        src_nfeats += count_features(src) if opt.data_type == 'text' \
+            else 0
+        tgt_nfeats += count_features(tgt)  # tgt always text so far
     logger.info(" * number of source features: %d." % src_nfeats)
     logger.info(" * number of target features: %d." % tgt_nfeats)
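
Illustration of the weighting scheme (a minimal standalone sketch, not part of the patch; weighted_round_robin and the toy corpora are hypothetical names): MultipleDatasetIterator._iter_datasets walks the weights list, advancing the corpus index once per weight and yielding that corpus's iterator `weight` times, and _iter_examples cycles over this schedule pulling one example per step, so examples reach batch_iter in roughly the -data_weights proportions.

# Standalone sketch of the weighted round-robin used by
# MultipleDatasetIterator (illustrative names only).
from itertools import cycle


def weighted_round_robin(iterables, weights):
    """Yield items from `iterables` in proportion to `weights`."""
    iterators = [iter(it) for it in iterables]
    index = -1

    def iter_sources():
        # One pass over the weights defines one scheduling cycle:
        # corpus k is yielded weights[k] times per cycle.
        nonlocal index
        for weight in weights:
            index = (index + 1) % len(iterators)
            for _ in range(weight):
                yield iterators[index]

    # cycle() replays the schedule; next() pulls one example per step.
    for iterator in cycle(iter_sources()):
        yield next(iterator)


if __name__ == "__main__":
    corpus_a = cycle(["A"])  # stands in for a DatasetLazyIter
    corpus_b = cycle(["B"])
    stream = weighted_round_robin([corpus_a, corpus_b], weights=[2, 1])
    print([next(stream) for _ in range(6)])  # ['A', 'A', 'B', 'A', 'A', 'B']

In terms of the options added above, each corpus is preprocessed with its own entry in -train_src/-train_tgt and a name in -train_ids, and training then selects the proportions through matching -data_ids and -data_weights, e.g. weights 2 and 1 draw two examples from the first corpus for every one from the second while batches are filled.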