From fae4d620ff94113e9c0cb2cd4e71e46635b79aa9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Fran=C3=A7ois=20Hernandez?= Date: Thu, 16 May 2019 16:27:11 +0200 Subject: [PATCH] Preprocessing: faster build vocab + multiple weighted datasets (#1413) * handle multiple training corpora and enable weighting * move fields vocab building logic in function * fix device handling MultipleDatasetIterator * fix multi/yield_raw_batch parameter DatasetLazyIter * update FAQ.md * add -pool_factor option * reduce pool_factor for travis runs --- .travis.yml | 10 +- docs/source/FAQ.md | 34 ++++- onmt/inputters/inputter.py | 227 +++++++++++++++++++++++++++------- onmt/opts.py | 23 +++- onmt/tests/test_preprocess.py | 12 +- onmt/train_single.py | 14 ++- onmt/utils/parse.py | 13 +- preprocess.py | 164 ++++++++++++++++-------- 8 files changed, 374 insertions(+), 123 deletions(-) diff --git a/.travis.yml b/.travis.yml index 3fd9c3fc75..f6a4d51c01 100644 --- a/.travis.yml +++ b/.travis.yml @@ -31,7 +31,7 @@ script: # test nmt preprocessing - python preprocess.py -train_src data/src-train.txt -train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt -save_data /tmp/data -src_vocab_size 1000 -tgt_vocab_size 1000 && rm -rf /tmp/data*.pt # test im2text preprocessing - - python preprocess.py -data_type img -shard_size 3 -src_dir /tmp/im2text/images -train_src /tmp/im2text/src-train.txt -train_tgt /tmp/im2text/tgt-train.txt -valid_src /tmp/im2text/src-val.txt -valid_tgt /tmp/im2text/tgt-val.txt -save_data /tmp/im2text/data && rm -rf /tmp/im2text/data*.pt + - python preprocess.py -data_type img -shard_size 100 -src_dir /tmp/im2text/images -train_src /tmp/im2text/src-train.txt -train_tgt /tmp/im2text/tgt-train.txt -valid_src /tmp/im2text/src-val.txt -valid_tgt /tmp/im2text/tgt-val.txt -save_data /tmp/im2text/data && rm -rf /tmp/im2text/data*.pt # test speech2text preprocessing - python preprocess.py -data_type audio -shard_size 300 -src_dir /tmp/speech/an4_dataset -train_src /tmp/speech/src-train.txt -train_tgt /tmp/speech/tgt-train.txt -valid_src /tmp/speech/src-val.txt -valid_tgt /tmp/speech/tgt-val.txt -save_data /tmp/speech/data && rm -rf /tmp/speech/data*.pt # test nmt translation @@ -43,14 +43,14 @@ script: # test speech2text translation - head /tmp/speech/src-val.txt > /tmp/speech/src-val-head.txt; head /tmp/speech/tgt-val.txt > /tmp/speech/tgt-val-head.txt; python translate.py -data_type audio -src_dir /tmp/speech/an4_dataset -model /tmp/test_model_speech.pt -src /tmp/speech/src-val-head.txt -tgt /tmp/speech/tgt-val-head.txt -verbose -out /tmp/speech/trans; diff /tmp/speech/tgt-val-head.txt /tmp/speech/trans # test nmt preprocessing and training - - head data/src-val.txt > /tmp/src-val.txt; head data/tgt-val.txt > /tmp/tgt-val.txt; python preprocess.py -train_src /tmp/src-val.txt -train_tgt /tmp/tgt-val.txt -valid_src /tmp/src-val.txt -valid_tgt /tmp/tgt-val.txt -save_data /tmp/q -src_vocab_size 1000 -tgt_vocab_size 1000; python train.py -data /tmp/q -rnn_size 2 -batch_size 10 -word_vec_size 5 -report_every 5 -rnn_size 10 -train_steps 10 && rm -rf /tmp/q*.pt + - head -500 data/src-val.txt > /tmp/src-val.txt; head -500 data/tgt-val.txt > /tmp/tgt-val.txt; python preprocess.py -train_src /tmp/src-val.txt -train_tgt /tmp/tgt-val.txt -valid_src /tmp/src-val.txt -valid_tgt /tmp/tgt-val.txt -save_data /tmp/q -src_vocab_size 1000 -tgt_vocab_size 1000; python train.py -data /tmp/q -rnn_size 2 -batch_size 2 -word_vec_size 5 -report_every 5 -rnn_size 10 -train_steps 10 && rm -rf /tmp/q*.pt # test nmt 
preprocessing w/ sharding and training w/copy - - head data/src-val.txt > /tmp/src-val.txt; head data/tgt-val.txt > /tmp/tgt-val.txt; python preprocess.py -train_src /tmp/src-val.txt -train_tgt /tmp/tgt-val.txt -valid_src /tmp/src-val.txt -valid_tgt /tmp/tgt-val.txt -shard_size 1 -dynamic_dict -save_data /tmp/q -src_vocab_size 1000 -tgt_vocab_size 1000; python train.py -data /tmp/q -rnn_size 2 -batch_size 10 -word_vec_size 5 -report_every 5 -rnn_size 10 -copy_attn -train_steps 10 && rm -rf /tmp/q*.pt + - head -50 data/src-val.txt > /tmp/src-val.txt; head -50 data/tgt-val.txt > /tmp/tgt-val.txt; python preprocess.py -train_src /tmp/src-val.txt -train_tgt /tmp/tgt-val.txt -valid_src /tmp/src-val.txt -valid_tgt /tmp/tgt-val.txt -shard_size 25 -dynamic_dict -save_data /tmp/q -src_vocab_size 1000 -tgt_vocab_size 1000; python train.py -data /tmp/q -rnn_size 2 -batch_size 2 -word_vec_size 5 -report_every 5 -rnn_size 10 -copy_attn -train_steps 10 -pool_factor 10 && rm -rf /tmp/q*.pt # test im2text preprocessing and training - - head /tmp/im2text/src-val.txt > /tmp/im2text/src-val-head.txt; head /tmp/im2text/tgt-val.txt > /tmp/im2text/tgt-val-head.txt; python preprocess.py -data_type img -src_dir /tmp/im2text/images -train_src /tmp/im2text/src-val-head.txt -train_tgt /tmp/im2text/tgt-val-head.txt -valid_src /tmp/im2text/src-val-head.txt -valid_tgt /tmp/im2text/tgt-val-head.txt -save_data /tmp/im2text/q; python train.py -model_type img -data /tmp/im2text/q -rnn_size 2 -batch_size 10 -word_vec_size 5 -report_every 5 -rnn_size 10 -train_steps 10 && rm -rf /tmp/im2text/q*.pt + - head -50 /tmp/im2text/src-val.txt > /tmp/im2text/src-val-head.txt; head -50 /tmp/im2text/tgt-val.txt > /tmp/im2text/tgt-val-head.txt; python preprocess.py -data_type img -src_dir /tmp/im2text/images -train_src /tmp/im2text/src-val-head.txt -train_tgt /tmp/im2text/tgt-val-head.txt -valid_src /tmp/im2text/src-val-head.txt -valid_tgt /tmp/im2text/tgt-val-head.txt -save_data /tmp/im2text/q -tgt_seq_length 100; python train.py -model_type img -data /tmp/im2text/q -rnn_size 2 -batch_size 2 -word_vec_size 5 -report_every 5 -rnn_size 10 -train_steps 10 -pool_factor 10 && rm -rf /tmp/im2text/q*.pt # test speech2text preprocessing and training - - head /tmp/speech/src-val.txt > /tmp/speech/src-val-head.txt; head /tmp/speech/tgt-val.txt > /tmp/speech/tgt-val-head.txt; python preprocess.py -data_type audio -src_dir /tmp/speech/an4_dataset -train_src /tmp/speech/src-val-head.txt -train_tgt /tmp/speech/tgt-val-head.txt -valid_src /tmp/speech/src-val-head.txt -valid_tgt /tmp/speech/tgt-val-head.txt -save_data /tmp/speech/q; python train.py -model_type audio -data /tmp/speech/q -rnn_size 2 -batch_size 10 -word_vec_size 5 -report_every 5 -rnn_size 10 -train_steps 10 && rm -rf /tmp/speech/q*.pt + - head -100 /tmp/speech/src-val.txt > /tmp/speech/src-val-head.txt; head -100 /tmp/speech/tgt-val.txt > /tmp/speech/tgt-val-head.txt; python preprocess.py -data_type audio -src_dir /tmp/speech/an4_dataset -train_src /tmp/speech/src-val-head.txt -train_tgt /tmp/speech/tgt-val-head.txt -valid_src /tmp/speech/src-val-head.txt -valid_tgt /tmp/speech/tgt-val-head.txt -save_data /tmp/speech/q; python train.py -model_type audio -data /tmp/speech/q -rnn_size 2 -batch_size 2 -word_vec_size 5 -report_every 5 -rnn_size 10 -train_steps 10 -pool_factor 10 && rm -rf /tmp/speech/q*.pt # test nmt translation - python translate.py -model onmt/tests/test_model2.pt -src data/morph/src.valid -verbose -batch_size 10 -beam_size 10 -tgt data/morph/tgt.valid -out /tmp/trans; 
diff data/morph/tgt.valid /tmp/trans # test nmt translation with random sampling diff --git a/docs/source/FAQ.md b/docs/source/FAQ.md index 9a45a87630..c5ebeb201f 100644 --- a/docs/source/FAQ.md +++ b/docs/source/FAQ.md @@ -2,9 +2,9 @@ ## How do I use Pretrained embeddings (e.g. GloVe)? -Using vocabularies from OpenNMT-py preprocessing outputs, `embeddings_to_torch.py` to generate encoder and decoder embeddings initialized with GloVe’s values. +Using vocabularies from OpenNMT-py preprocessing outputs, `embeddings_to_torch.py` to generate encoder and decoder embeddings initialized with GloVe's values. -the script is a slightly modified version of ylhsieh’s one2. +the script is a slightly modified version of ylhsieh's one2. Usage: @@ -105,4 +105,34 @@ if you use a regular network card (1 Gbps) then we suggest to use a higher accum You can specify several models in the translate.py command line: -model model1_seed1 model2_seed2 Bear in mind that your models must share the same target vocabulary. +## How can I weight different corpora at training? +### Preprocessing + +We introduced `-train_ids`, which is a list of IDs that will be given to the preprocessed shards. + +E.g. if we have two corpora, `parallel.en`/`parallel.de` and `from_backtranslation.en`/`from_backtranslation.de`, we can pass the following to the `preprocess.py` command: +``` +... +-train_src parallel.en from_backtranslation.en \ +-train_tgt parallel.de from_backtranslation.de \ +-train_ids A B \ +-save_data my_data \ +... +``` +and it will dump `my_data.train_A.X.pt` based on `parallel.en`//`parallel.de` and `my_data.train_B.X.pt` based on `from_backtranslation.en`//`from_backtranslation.de`. + +### Training + +We introduced `-data_ids` based on the same principle as above, as well as `-data_weights`, which is the list of weights each corpus should have. +E.g. +``` +... +-data my_data \ +-data_ids A B \ +-data_weights 1 7 \ +... +``` +will mean that we'll look for `my_data.train_A.*.pt` and `my_data.train_B.*.pt`, and that when building batches, we'll take 1 example from corpus A, then 7 examples from corpus B, and so on. + +**Warning**: This means that we'll load as many shards as we have `-data_ids`, in order to produce batches containing data from every corpus. It may be a good idea to reduce the `-shard_size` at preprocessing. diff --git a/onmt/inputters/inputter.py b/onmt/inputters/inputter.py index b591701114..5050431346 100644 --- a/onmt/inputters/inputter.py +++ b/onmt/inputters/inputter.py @@ -11,6 +11,7 @@ import torchtext.data from torchtext.data import Field from torchtext.vocab import Vocab +from torchtext.data.utils import RandomShuffler from onmt.inputters.text_dataset import text_fields, TextMultiField from onmt.inputters.image_dataset import image_fields @@ -308,6 +309,42 @@ def _build_fv_from_multifield(multifield, counters, build_fv_args, logger.info(" * %s vocab size: %d."
% (name, len(field.vocab))) +def _build_fields_vocab(fields, counters, data_type, share_vocab, + vocab_size_multiple, + src_vocab_size, src_words_min_frequency, + tgt_vocab_size, tgt_words_min_frequency): + build_fv_args = defaultdict(dict) + build_fv_args["src"] = dict( + max_size=src_vocab_size, min_freq=src_words_min_frequency) + build_fv_args["tgt"] = dict( + max_size=tgt_vocab_size, min_freq=tgt_words_min_frequency) + tgt_multifield = fields["tgt"] + _build_fv_from_multifield( + tgt_multifield, + counters, + build_fv_args, + size_multiple=vocab_size_multiple if not share_vocab else 1) + if data_type == 'text': + src_multifield = fields["src"] + _build_fv_from_multifield( + src_multifield, + counters, + build_fv_args, + size_multiple=vocab_size_multiple if not share_vocab else 1) + if share_vocab: + # `tgt_vocab_size` is ignored when sharing vocabularies + logger.info(" * merging src and tgt vocab...") + src_field = src_multifield.base_field + tgt_field = tgt_multifield.base_field + _merge_field_vocabs( + src_field, tgt_field, vocab_size=src_vocab_size, + min_freq=src_words_min_frequency, + vocab_size_multiple=vocab_size_multiple) + logger.info(" * merged vocab size: %d." % len(src_field.vocab)) + + return fields + + def build_vocab(train_dataset_files, fields, data_type, share_vocab, src_vocab_path, src_vocab_size, src_words_min_frequency, tgt_vocab_path, tgt_vocab_size, tgt_words_min_frequency, @@ -392,34 +429,12 @@ def build_vocab(train_dataset_files, fields, data_type, share_vocab, del dataset gc.collect() - build_fv_args = defaultdict(dict) - build_fv_args["src"] = dict( - max_size=src_vocab_size, min_freq=src_words_min_frequency) - build_fv_args["tgt"] = dict( - max_size=tgt_vocab_size, min_freq=tgt_words_min_frequency) - tgt_multifield = fields["tgt"] - _build_fv_from_multifield( - tgt_multifield, - counters, - build_fv_args, - size_multiple=vocab_size_multiple if not share_vocab else 1) - if data_type == 'text': - src_multifield = fields["src"] - _build_fv_from_multifield( - src_multifield, - counters, - build_fv_args, - size_multiple=vocab_size_multiple if not share_vocab else 1) - if share_vocab: - # `tgt_vocab_size` is ignored when sharing vocabularies - logger.info(" * merging src and tgt vocab...") - src_field = src_multifield.base_field - tgt_field = tgt_multifield.base_field - _merge_field_vocabs( - src_field, tgt_field, vocab_size=src_vocab_size, - min_freq=src_words_min_frequency, - vocab_size_multiple=vocab_size_multiple) - logger.info(" * merged vocab size: %d." % len(src_field.vocab)) + fields = _build_fields_vocab( + fields, counters, data_type, + share_vocab, vocab_size_multiple, + src_vocab_size, src_words_min_frequency, + tgt_vocab_size, tgt_words_min_frequency) + return fields # is the return necessary? 
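For reference, both call sites drive the new helper the same way: the refactored `build_vocab` above and the per-corpus path added to `preprocess.py` further down. A minimal sketch, assuming the usual `get_fields` helper from `onmt.inputters.inputter`; the vocabulary sizes and toy token counts below are placeholders, not values from the patch:
```
from collections import Counter, defaultdict
from onmt.inputters.inputter import get_fields, _build_fields_vocab

# counters maps each (sub)field name ("src", "tgt", ...) to a Counter of
# token frequencies accumulated while iterating over the training shards.
counters = defaultdict(Counter)
counters["src"].update("a toy source sentence".split())
counters["tgt"].update("a toy target sentence".split())

fields = get_fields("text", 0, 0)  # no extra src/tgt features
fields = _build_fields_vocab(
    fields, counters, data_type="text",
    share_vocab=False, vocab_size_multiple=1,
    src_vocab_size=50000, src_words_min_frequency=0,
    tgt_vocab_size=50000, tgt_words_min_frequency=0)
```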
@@ -497,29 +512,52 @@ def batch_size_fn(new, count, sofar): yield minibatch +def _pool(data, batch_size, batch_size_fn, batch_size_multiple, + sort_key, random_shuffler, pool_factor): + for p in torchtext.data.batch( + data, batch_size * pool_factor, + batch_size_fn=batch_size_fn): + p_batch = batch_iter( + sorted(p, key=sort_key), + batch_size, + batch_size_fn=batch_size_fn, + batch_size_multiple=batch_size_multiple) + for b in random_shuffler(list(p_batch)): + yield b + + class OrderedIterator(torchtext.data.Iterator): def __init__(self, dataset, batch_size, + pool_factor=1, batch_size_multiple=1, + yield_raw_example=False, **kwargs): super(OrderedIterator, self).__init__(dataset, batch_size, **kwargs) self.batch_size_multiple = batch_size_multiple + self.yield_raw_example = yield_raw_example + self.dataset = dataset + self.pool_factor = pool_factor def create_batches(self): if self.train: - def _pool(data, random_shuffler): - for p in torchtext.data.batch(data, self.batch_size * 100): - p_batch = batch_iter( - sorted(p, key=self.sort_key), - self.batch_size, - batch_size_fn=self.batch_size_fn, - batch_size_multiple=self.batch_size_multiple) - for b in random_shuffler(list(p_batch)): - yield b - - self.batches = _pool(self.data(), self.random_shuffler) + if self.yield_raw_example: + self.batches = batch_iter( + self.data(), + 1, + batch_size_fn=None, + batch_size_multiple=1) + else: + self.batches = _pool( + self.data(), + self.batch_size, + self.batch_size_fn, + self.batch_size_multiple, + self.sort_key, + self.random_shuffler, + self.pool_factor) else: self.batches = [] for b in batch_iter( @@ -529,6 +567,88 @@ def _pool(data, random_shuffler): batch_size_multiple=self.batch_size_multiple): self.batches.append(sorted(b, key=self.sort_key)) + def __iter__(self): + """ + Extended version of the definition in torchtext.data.Iterator. + Added yield_raw_example behaviour to yield a torchtext.data.Example + instead of a torchtext.data.Batch object. + """ + while True: + self.init_epoch() + for idx, minibatch in enumerate(self.batches): + # fast-forward if loaded from state + if self._iterations_this_epoch > idx: + continue + self.iterations += 1 + self._iterations_this_epoch += 1 + if self.sort_within_batch: + # NOTE: `rnn.pack_padded_sequence` requires that a + # minibatch be sorted by decreasing order, which + # requires reversing relative to typical sort keys + if self.sort: + minibatch.reverse() + else: + minibatch.sort(key=self.sort_key, reverse=True) + if self.yield_raw_example: + yield minibatch[0] + else: + yield torchtext.data.Batch( + minibatch, + self.dataset, + self.device) + if not self.repeat: + return + + +class MultipleDatasetIterator(object): + """ + This takes a list of iterable objects (DatasetLazyIter) and their + respective weights, and yields a batch in the wanted proportions. 
+ """ + def __init__(self, + iterables, + device, + opt): + self.index = -1 + self.iterators = [iter(iterable) for iterable in iterables] + self.iterables = iterables + self.weights = opt.data_weights + self.batch_size = opt.batch_size + self.batch_size_fn = max_tok_len \ + if opt.batch_type == "tokens" else None + self.batch_size_multiple = 8 if opt.model_dtype == "fp16" else 1 + self.device = "cuda" if device >= 0 else "cpu" + # Temporarily load one shard to retrieve sort_key for data_type + temp_dataset = torch.load(self.iterables[0]._paths[0]) + self.sort_key = temp_dataset.sort_key + self.random_shuffler = RandomShuffler() + self.pool_factor = opt.pool_factor + del temp_dataset + + def _iter_datasets(self): + for weight in self.weights: + self.index = (self.index + 1) % len(self.iterators) + for i in range(weight): + yield self.iterators[self.index] + + def _iter_examples(self): + for iterator in cycle(self._iter_datasets()): + yield next(iterator) + + def __iter__(self): + for minibatch in _pool( + self._iter_examples(), + self.batch_size, + self.batch_size_fn, + self.batch_size_multiple, + self.sort_key, + self.random_shuffler, + self.pool_factor): + minibatch = sorted(minibatch, key=self.sort_key, reverse=True) + yield torchtext.data.Batch(minibatch, + self.iterables[0].dataset, + self.device) + class DatasetLazyIter(object): """Yield data from sharded dataset files. @@ -544,8 +664,8 @@ class DatasetLazyIter(object): """ def __init__(self, dataset_paths, fields, batch_size, batch_size_fn, - batch_size_multiple, device, is_train, repeat=True, - num_batches_multiple=1): + batch_size_multiple, device, is_train, pool_factor, + repeat=True, num_batches_multiple=1, yield_raw_example=False): self._paths = dataset_paths self.fields = fields self.batch_size = batch_size @@ -555,6 +675,8 @@ def __init__(self, dataset_paths, fields, batch_size, batch_size_fn, self.is_train = is_train self.repeat = repeat self.num_batches_multiple = num_batches_multiple + self.yield_raw_example = yield_raw_example + self.pool_factor = pool_factor def _iter_dataset(self, path): cur_dataset = torch.load(path) @@ -564,15 +686,18 @@ def _iter_dataset(self, path): cur_iter = OrderedIterator( dataset=cur_dataset, batch_size=self.batch_size, + pool_factor=self.pool_factor, batch_size_multiple=self.batch_size_multiple, batch_size_fn=self.batch_size_fn, device=self.device, train=self.is_train, sort=False, sort_within_batch=True, - repeat=False + repeat=False, + yield_raw_example=self.yield_raw_example ) for batch in cur_iter: + self.dataset = cur_iter.dataset yield batch cur_dataset.examples = None @@ -625,7 +750,7 @@ def max_tok_len(new, count, sofar): return max(src_elements, tgt_elements) -def build_dataset_iter(corpus_type, fields, opt, is_train=True): +def build_dataset_iter(corpus_type, fields, opt, is_train=True, multi=False): """ This returns user-defined train/validate data iterator for the trainer to iterate over. We implement simple ordered iterator strategy here, @@ -635,9 +760,15 @@ def build_dataset_iter(corpus_type, fields, opt, is_train=True): glob.glob(opt.data + '.' 
+ corpus_type + '*.pt'))) if not dataset_paths: return None - batch_size = opt.batch_size if is_train else opt.valid_batch_size - batch_fn = max_tok_len if is_train and opt.batch_type == "tokens" else None - batch_size_multiple = 8 if opt.model_dtype == "fp16" else 1 + if multi: + batch_size = 1 + batch_fn = None + batch_size_multiple = 1 + else: + batch_size = opt.batch_size if is_train else opt.valid_batch_size + batch_fn = max_tok_len \ + if is_train and opt.batch_type == "tokens" else None + batch_size_multiple = 8 if opt.model_dtype == "fp16" else 1 device = "cuda" if opt.gpu_ranks else "cpu" @@ -649,5 +780,7 @@ def build_dataset_iter(corpus_type, fields, opt, is_train=True): batch_size_multiple, device, is_train, + opt.pool_factor, repeat=not opt.single_pass, - num_batches_multiple=max(opt.accum_count) * opt.world_size) + num_batches_multiple=max(opt.accum_count) * opt.world_size, + yield_raw_example=multi) diff --git a/onmt/opts.py b/onmt/opts.py index 85cee88b1f..f763e4f254 100644 --- a/onmt/opts.py +++ b/onmt/opts.py @@ -189,10 +189,12 @@ def preprocess_opts(parser): help="Type of the source input. " "Options are [text|img|audio].") - group.add('--train_src', '-train_src', required=True, - help="Path to the training source data") - group.add('--train_tgt', '-train_tgt', required=True, - help="Path to the training target data") + group.add('--train_src', '-train_src', required=True, nargs='+', + help="Path(s) to the training source data") + group.add('--train_tgt', '-train_tgt', required=True, nargs='+', + help="Path(s) to the training target data") + group.add('--train_ids', '-train_ids', nargs='+', default=[None], + help="IDs to name training shards, used for corpus weighting") group.add('--valid_src', '-valid_src', help="Path to the validation source data") group.add('--valid_tgt', '-valid_tgt', @@ -308,6 +310,12 @@ def train_opts(parser): help='Path prefix to the ".train.pt" and ' '".valid.pt" file path from preprocess.py') + group.add('--data_ids', '-data_ids', nargs='+', default=[None], + help="IDs of the corpora, in case there are several (should match preprocess -train_ids).") + group.add('--data_weights', '-data_weights', type=int, nargs='+', + default=[1], help="""Weights of different corpora, + should follow the same order as in -data_ids.""") + group.add('--save_model', '-save_model', default='model', help="Model filename (the model will be saved as " "_N.pt where N is the number " @@ -382,6 +390,13 @@ def train_opts(parser): choices=["sents", "tokens"], help="Batch grouping for batch_size. Standard " "is sents. Tokens will do dynamic batching") + group.add('--pool_factor', '-pool_factor', type=int, default=8192, + help="""Factor used in data loading and batch creation. + It will load the equivalent of `pool_factor` batches, + sort them by the corresponding `sort_key` to produce + homogeneous batches and reduce padding, and yield + the produced batches in a shuffled way.
+ Inspired by torchtext's pool mechanism.""") group.add('--normalization', '-normalization', default='sents', choices=["sents", "tokens"], help='Normalization method of the gradient.') diff --git a/onmt/tests/test_preprocess.py b/onmt/tests/test_preprocess.py index 2d423e5e29..ff63b1d758 100644 --- a/onmt/tests/test_preprocess.py +++ b/onmt/tests/test_preprocess.py @@ -49,11 +49,9 @@ def dataset_build(self, opt): src_reader = onmt.inputters.str2reader[opt.data_type].from_opt(opt) tgt_reader = onmt.inputters.str2reader["text"].from_opt(opt) - train_data_files = preprocess.build_save_dataset( + preprocess.build_save_dataset( 'train', fields, src_reader, tgt_reader, opt) - preprocess.build_save_vocab(train_data_files, fields, opt) - preprocess.build_save_dataset( 'valid', fields, src_reader, tgt_reader, opt) @@ -136,8 +134,8 @@ def test_method(self): ] test_databuild_common = [('data_type', 'img'), ('src_dir', '/tmp/im2text/images'), - ('train_src', '/tmp/im2text/src-train-head.txt'), - ('train_tgt', '/tmp/im2text/tgt-train-head.txt'), + ('train_src', ['/tmp/im2text/src-train-head.txt']), + ('train_tgt', ['/tmp/im2text/tgt-train-head.txt']), ('valid_src', '/tmp/im2text/src-val-head.txt'), ('valid_tgt', '/tmp/im2text/tgt-val-head.txt'), ] @@ -164,8 +162,8 @@ def test_method(self): ] test_databuild_common = [('data_type', 'audio'), ('src_dir', '/tmp/speech/an4_dataset'), - ('train_src', '/tmp/speech/src-train-head.txt'), - ('train_tgt', '/tmp/speech/tgt-train-head.txt'), + ('train_src', ['/tmp/speech/src-train-head.txt']), + ('train_tgt', ['/tmp/speech/tgt-train-head.txt']), ('valid_src', '/tmp/speech/src-val-head.txt'), ('valid_tgt', '/tmp/speech/tgt-val-head.txt'), ('sample_rate', 16000), diff --git a/onmt/train_single.py b/onmt/train_single.py index 33f9a86f06..fb7ef8ee3e 100755 --- a/onmt/train_single.py +++ b/onmt/train_single.py @@ -5,7 +5,7 @@ import torch from onmt.inputters.inputter import build_dataset_iter, \ - load_old_vocab, old_style_vocab + load_old_vocab, old_style_vocab, MultipleDatasetIterator from onmt.model_builder import build_model from onmt.utils.optimizers import Optimizer from onmt.utils.misc import set_random_seed @@ -98,7 +98,17 @@ def main(opt, device_id): trainer = build_trainer( opt, device_id, model, fields, optim, model_saver=model_saver) - train_iter = build_dataset_iter("train", fields, opt) + train_iterables = [] + for train_id in opt.data_ids: + if train_id: + shard_base = "train_" + train_id + else: + shard_base = "train" + iterable = build_dataset_iter(shard_base, fields, opt, multi=True) + train_iterables.append(iterable) + + train_iter = MultipleDatasetIterator(train_iterables, device_id, opt) + valid_iter = build_dataset_iter( "valid", fields, opt, is_train=False) diff --git a/onmt/utils/parse.py b/onmt/utils/parse.py index 42492d1c93..d01ba74c01 100644 --- a/onmt/utils/parse.py +++ b/onmt/utils/parse.py @@ -100,6 +100,8 @@ def validate_train_opts(cls, opt): raise AssertionError( "-gpu_ranks should have master(=0) rank " "unless -world_size is greater than len(gpu_ranks).") + assert len(opt.data_ids) == len(opt.data_weights), \ + "Please check -data_ids and -data_weights options!" assert len(opt.dropout) == len(opt.dropout_steps), \ "Number of dropout values must match number of accum_steps" @@ -118,9 +120,14 @@ def validate_preprocess_args(cls, opt): "-shuffle is not implemented. Please shuffle \ your data before pre-processing." 
- assert os.path.isfile(opt.train_src) \ - and os.path.isfile(opt.train_tgt), \ - "Please check path of your train src and tgt files!" + assert len(opt.train_src) == len(opt.train_tgt), \ + "Please provide same number of src and tgt train files!" + + assert len(opt.train_src) == len(opt.train_ids), \ + "Please provide proper -train_ids for your data!" + + for file in opt.train_src + opt.train_tgt: + assert os.path.isfile(file), "Please check path of %s" % file assert not opt.valid_src or os.path.isfile(opt.valid_src), \ "Please check path of your valid src file!" diff --git a/preprocess.py b/preprocess.py index 65f86382bb..746a89dd15 100755 --- a/preprocess.py +++ b/preprocess.py @@ -9,18 +9,21 @@ import gc import torch from functools import partial +from collections import Counter, defaultdict from onmt.utils.logging import init_logger, logger from onmt.utils.misc import split_corpus import onmt.inputters as inputters import onmt.opts as opts from onmt.utils.parse import ArgumentParser +from onmt.inputters.inputter import _build_fields_vocab,\ + _load_vocab def check_existing_pt_files(opt): """ Check if there are existing .pt files to avoid overwriting them """ pattern = opt.save_data + '.{}*.pt' - for t in ['train', 'valid', 'vocab']: + for t in ['train', 'valid']: path = pattern.format(t) if glob.glob(path): sys.stderr.write("Please backup existing pt files: %s, " @@ -32,51 +35,107 @@ def build_save_dataset(corpus_type, fields, src_reader, tgt_reader, opt): assert corpus_type in ['train', 'valid'] if corpus_type == 'train': - src = opt.train_src - tgt = opt.train_tgt + counters = defaultdict(Counter) + srcs = opt.train_src + tgts = opt.train_tgt + ids = opt.train_ids else: - src = opt.valid_src - tgt = opt.valid_tgt - - logger.info("Reading source and target files: %s %s." % (src, tgt)) - - src_shards = split_corpus(src, opt.shard_size) - tgt_shards = split_corpus(tgt, opt.shard_size) - shard_pairs = zip(src_shards, tgt_shards) - dataset_paths = [] - if (corpus_type == "train" or opt.filter_valid) and tgt is not None: - filter_pred = partial( - inputters.filter_example, use_src_len=opt.data_type == "text", - max_src_len=opt.src_seq_length, max_tgt_len=opt.tgt_seq_length) - else: - filter_pred = None - for i, (src_shard, tgt_shard) in enumerate(shard_pairs): - assert len(src_shard) == len(tgt_shard) - logger.info("Building shard %d." % i) - dataset = inputters.Dataset( - fields, - readers=[src_reader, tgt_reader] if tgt_reader else [src_reader], - data=([("src", src_shard), ("tgt", tgt_shard)] - if tgt_reader else [("src", src_shard)]), - dirs=[opt.src_dir, None] if tgt_reader else [opt.src_dir], - sort_key=inputters.str2sortkey[opt.data_type], - filter_pred=filter_pred - ) - - data_path = "{:s}.{:s}.{:d}.pt".format(opt.save_data, corpus_type, i) - dataset_paths.append(data_path) - - logger.info(" * saving %sth %s data shard to %s." - % (i, corpus_type, data_path)) - - dataset.save(data_path) - - del dataset.examples - gc.collect() - del dataset - gc.collect() - - return dataset_paths + srcs = [opt.valid_src] + tgts = [opt.valid_tgt] + ids = [None] + + for src, tgt, maybe_id in zip(srcs, tgts, ids): + logger.info("Reading source and target files: %s %s." 
% (src, tgt)) + + src_shards = split_corpus(src, opt.shard_size) + tgt_shards = split_corpus(tgt, opt.shard_size) + shard_pairs = zip(src_shards, tgt_shards) + dataset_paths = [] + if (corpus_type == "train" or opt.filter_valid) and tgt is not None: + filter_pred = partial( + inputters.filter_example, use_src_len=opt.data_type == "text", + max_src_len=opt.src_seq_length, max_tgt_len=opt.tgt_seq_length) + else: + filter_pred = None + + if corpus_type == "train": + if opt.src_vocab: + try: + logger.info("Using existing vocabulary...") + src_vocab = torch.load(opt.src_vocab) + except torch.serialization.pickle.UnpicklingError: + logger.info("Building vocab from text file...") + + src_vocab, src_vocab_size = _load_vocab( + opt.src_vocab, "src", counters, + opt.src_words_min_frequency) + else: + src_vocab = None + + if opt.tgt_vocab: + tgt_vocab, tgt_vocab_size = _load_vocab( + opt.tgt_vocab, "tgt", counters, + opt.tgt_words_min_frequency) + else: + tgt_vocab = None + + for i, (src_shard, tgt_shard) in enumerate(shard_pairs): + assert len(src_shard) == len(tgt_shard) + logger.info("Building shard %d." % i) + dataset = inputters.Dataset( + fields, + readers=([src_reader, tgt_reader] + if tgt_reader else [src_reader]), + data=([("src", src_shard), ("tgt", tgt_shard)] + if tgt_reader else [("src", src_shard)]), + dirs=([opt.src_dir, None] + if tgt_reader else [opt.src_dir]), + sort_key=inputters.str2sortkey[opt.data_type], + filter_pred=filter_pred + ) + if corpus_type == "train": + for ex in dataset.examples: + for name, field in fields.items(): + try: + f_iter = iter(field) + except TypeError: + f_iter = [(name, field)] + all_data = [getattr(ex, name, None)] + else: + all_data = getattr(ex, name) + for (sub_n, sub_f), fd in zip( + f_iter, all_data): + has_vocab = (sub_n == 'src' and src_vocab) or \ + (sub_n == 'tgt' and tgt_vocab) + if sub_f.sequential and not has_vocab: + val = fd + counters[sub_n].update(val) + if maybe_id: + shard_base = corpus_type + "_" + maybe_id + else: + shard_base = corpus_type + data_path = "{:s}.{:s}.{:d}.pt".\ + format(opt.save_data, shard_base, i) + dataset_paths.append(data_path) + + logger.info(" * saving %sth %s data shard to %s." + % (i, shard_base, data_path)) + + dataset.save(data_path) + + del dataset.examples + gc.collect() + del dataset + gc.collect() + + if corpus_type == "train": + fields = _build_fields_vocab( + fields, counters, opt.data_type, + opt.share_vocab, opt.vocab_size_multiple, + opt.src_vocab_size, opt.src_words_min_frequency, + opt.tgt_vocab_size, opt.tgt_words_min_frequency) + vocab_path = opt.save_data + '.vocab.pt' + torch.save(fields, vocab_path) def build_save_vocab(train_dataset, fields, opt): @@ -86,7 +145,6 @@ def build_save_vocab(train_dataset, fields, opt): opt.tgt_vocab, opt.tgt_vocab_size, opt.tgt_words_min_frequency, vocab_size_multiple=opt.vocab_size_multiple ) - vocab_path = opt.save_data + '.vocab.pt' torch.save(fields, vocab_path) @@ -110,9 +168,12 @@ def main(opt): init_logger(opt.log_file) logger.info("Extracting features...") - src_nfeats = count_features(opt.train_src) if opt.data_type == 'text' \ - else 0 - tgt_nfeats = count_features(opt.train_tgt) # tgt always text so far + src_nfeats = 0 + tgt_nfeats = 0 + for src, tgt in zip(opt.train_src, opt.train_tgt): + src_nfeats += count_features(src) if opt.data_type == 'text' \ + else 0 + tgt_nfeats += count_features(tgt) # tgt always text so far logger.info(" * number of source features: %d." % src_nfeats) logger.info(" * number of target features: %d." 
% tgt_nfeats) @@ -129,16 +190,13 @@ def main(opt): tgt_reader = inputters.str2reader["text"].from_opt(opt) logger.info("Building & saving training data...") - train_dataset_files = build_save_dataset( + build_save_dataset( 'train', fields, src_reader, tgt_reader, opt) if opt.valid_src and opt.valid_tgt: logger.info("Building & saving validation data...") build_save_dataset('valid', fields, src_reader, tgt_reader, opt) - logger.info("Building & saving vocabulary...") - build_save_vocab(train_dataset_files, fields, opt) - def _get_parser(): parser = ArgumentParser(description='preprocess.py')
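To make the corpus-weighting behaviour concrete: `MultipleDatasetIterator` draws raw examples from the corpora in a weighted round-robin before pooling and batching them. A minimal standalone sketch of that interleaving (the corpus generators and names below are illustrative only):
```
from itertools import islice

def weighted_examples(iterables, weights):
    # Round-robin over the corpora: on each turn, draw weights[i] examples
    # from corpus i (this mirrors MultipleDatasetIterator._iter_datasets /
    # _iter_examples in the patch above).
    iterators = [iter(it) for it in iterables]
    index = -1
    while True:
        for weight in weights:
            index = (index + 1) % len(iterators)
            for _ in range(weight):
                yield next(iterators[index])

# With -data_weights 1 7: one example from corpus A, then seven from corpus B.
corpus_a = ("A%d" % i for i in range(100))
corpus_b = ("B%d" % i for i in range(100))
print(list(islice(weighted_examples([corpus_a, corpus_b], [1, 7]), 10)))
# ['A0', 'B0', 'B1', 'B2', 'B3', 'B4', 'B5', 'B6', 'A1', 'B7']
```
At training time this example stream is then passed through `_pool` (sort by `sort_key`, batch, shuffle). Since one shard per corpus stays loaded, reducing `-shard_size` at preprocessing keeps memory usage down, as the FAQ warning notes.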