Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Preprocessing: faster build vocab + multiple weighted datasets #1413

Merged
merged 15 commits into from
May 16, 2019
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 5 additions & 5 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ script:
# test nmt preprocessing
- python preprocess.py -train_src data/src-train.txt -train_tgt data/tgt-train.txt -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt -save_data /tmp/data -src_vocab_size 1000 -tgt_vocab_size 1000 && rm -rf /tmp/data*.pt
# test im2text preprocessing
- python preprocess.py -data_type img -shard_size 3 -src_dir /tmp/im2text/images -train_src /tmp/im2text/src-train.txt -train_tgt /tmp/im2text/tgt-train.txt -valid_src /tmp/im2text/src-val.txt -valid_tgt /tmp/im2text/tgt-val.txt -save_data /tmp/im2text/data && rm -rf /tmp/im2text/data*.pt
- python preprocess.py -data_type img -shard_size 100 -src_dir /tmp/im2text/images -train_src /tmp/im2text/src-train.txt -train_tgt /tmp/im2text/tgt-train.txt -valid_src /tmp/im2text/src-val.txt -valid_tgt /tmp/im2text/tgt-val.txt -save_data /tmp/im2text/data && rm -rf /tmp/im2text/data*.pt
# test speech2text preprocessing
- python preprocess.py -data_type audio -shard_size 300 -src_dir /tmp/speech/an4_dataset -train_src /tmp/speech/src-train.txt -train_tgt /tmp/speech/tgt-train.txt -valid_src /tmp/speech/src-val.txt -valid_tgt /tmp/speech/tgt-val.txt -save_data /tmp/speech/data && rm -rf /tmp/speech/data*.pt
# test nmt translation
Expand All @@ -43,14 +43,14 @@ script:
# test speech2text translation
- head /tmp/speech/src-val.txt > /tmp/speech/src-val-head.txt; head /tmp/speech/tgt-val.txt > /tmp/speech/tgt-val-head.txt; python translate.py -data_type audio -src_dir /tmp/speech/an4_dataset -model /tmp/test_model_speech.pt -src /tmp/speech/src-val-head.txt -tgt /tmp/speech/tgt-val-head.txt -verbose -out /tmp/speech/trans; diff /tmp/speech/tgt-val-head.txt /tmp/speech/trans
# test nmt preprocessing and training
- head data/src-val.txt > /tmp/src-val.txt; head data/tgt-val.txt > /tmp/tgt-val.txt; python preprocess.py -train_src /tmp/src-val.txt -train_tgt /tmp/tgt-val.txt -valid_src /tmp/src-val.txt -valid_tgt /tmp/tgt-val.txt -save_data /tmp/q -src_vocab_size 1000 -tgt_vocab_size 1000; python train.py -data /tmp/q -rnn_size 2 -batch_size 10 -word_vec_size 5 -report_every 5 -rnn_size 10 -train_steps 10 && rm -rf /tmp/q*.pt
- head -500 data/src-val.txt > /tmp/src-val.txt; head -500 data/tgt-val.txt > /tmp/tgt-val.txt; python preprocess.py -train_src /tmp/src-val.txt -train_tgt /tmp/tgt-val.txt -valid_src /tmp/src-val.txt -valid_tgt /tmp/tgt-val.txt -save_data /tmp/q -src_vocab_size 1000 -tgt_vocab_size 1000; python train.py -data /tmp/q -rnn_size 2 -batch_size 2 -word_vec_size 5 -report_every 5 -rnn_size 10 -train_steps 10 && rm -rf /tmp/q*.pt
# test nmt preprocessing w/ sharding and training w/copy
- head data/src-val.txt > /tmp/src-val.txt; head data/tgt-val.txt > /tmp/tgt-val.txt; python preprocess.py -train_src /tmp/src-val.txt -train_tgt /tmp/tgt-val.txt -valid_src /tmp/src-val.txt -valid_tgt /tmp/tgt-val.txt -shard_size 1 -dynamic_dict -save_data /tmp/q -src_vocab_size 1000 -tgt_vocab_size 1000; python train.py -data /tmp/q -rnn_size 2 -batch_size 10 -word_vec_size 5 -report_every 5 -rnn_size 10 -copy_attn -train_steps 10 && rm -rf /tmp/q*.pt
- head -50 data/src-val.txt > /tmp/src-val.txt; head -50 data/tgt-val.txt > /tmp/tgt-val.txt; python preprocess.py -train_src /tmp/src-val.txt -train_tgt /tmp/tgt-val.txt -valid_src /tmp/src-val.txt -valid_tgt /tmp/tgt-val.txt -shard_size 25 -dynamic_dict -save_data /tmp/q -src_vocab_size 1000 -tgt_vocab_size 1000; python train.py -data /tmp/q -rnn_size 2 -batch_size 2 -word_vec_size 5 -report_every 5 -rnn_size 10 -copy_attn -train_steps 10 && rm -rf /tmp/q*.pt

# test im2text preprocessing and training
- head /tmp/im2text/src-val.txt > /tmp/im2text/src-val-head.txt; head /tmp/im2text/tgt-val.txt > /tmp/im2text/tgt-val-head.txt; python preprocess.py -data_type img -src_dir /tmp/im2text/images -train_src /tmp/im2text/src-val-head.txt -train_tgt /tmp/im2text/tgt-val-head.txt -valid_src /tmp/im2text/src-val-head.txt -valid_tgt /tmp/im2text/tgt-val-head.txt -save_data /tmp/im2text/q; python train.py -model_type img -data /tmp/im2text/q -rnn_size 2 -batch_size 10 -word_vec_size 5 -report_every 5 -rnn_size 10 -train_steps 10 && rm -rf /tmp/im2text/q*.pt
- head -50 /tmp/im2text/src-val.txt > /tmp/im2text/src-val-head.txt; head -50 /tmp/im2text/tgt-val.txt > /tmp/im2text/tgt-val-head.txt; python preprocess.py -data_type img -src_dir /tmp/im2text/images -train_src /tmp/im2text/src-val-head.txt -train_tgt /tmp/im2text/tgt-val-head.txt -valid_src /tmp/im2text/src-val-head.txt -valid_tgt /tmp/im2text/tgt-val-head.txt -save_data /tmp/im2text/q -tgt_seq_length 100; python train.py -model_type img -data /tmp/im2text/q -rnn_size 2 -batch_size 2 -word_vec_size 5 -report_every 5 -rnn_size 10 -train_steps 10 && rm -rf /tmp/im2text/q*.pt
# test speech2text preprocessing and training
- head /tmp/speech/src-val.txt > /tmp/speech/src-val-head.txt; head /tmp/speech/tgt-val.txt > /tmp/speech/tgt-val-head.txt; python preprocess.py -data_type audio -src_dir /tmp/speech/an4_dataset -train_src /tmp/speech/src-val-head.txt -train_tgt /tmp/speech/tgt-val-head.txt -valid_src /tmp/speech/src-val-head.txt -valid_tgt /tmp/speech/tgt-val-head.txt -save_data /tmp/speech/q; python train.py -model_type audio -data /tmp/speech/q -rnn_size 2 -batch_size 10 -word_vec_size 5 -report_every 5 -rnn_size 10 -train_steps 10 && rm -rf /tmp/speech/q*.pt
- head -100 /tmp/speech/src-val.txt > /tmp/speech/src-val-head.txt; head -100 /tmp/speech/tgt-val.txt > /tmp/speech/tgt-val-head.txt; python preprocess.py -data_type audio -src_dir /tmp/speech/an4_dataset -train_src /tmp/speech/src-val-head.txt -train_tgt /tmp/speech/tgt-val-head.txt -valid_src /tmp/speech/src-val-head.txt -valid_tgt /tmp/speech/tgt-val-head.txt -save_data /tmp/speech/q; python train.py -model_type audio -data /tmp/speech/q -rnn_size 2 -batch_size 2 -word_vec_size 5 -report_every 5 -rnn_size 10 -train_steps 10 && rm -rf /tmp/speech/q*.pt
# test nmt translation
- python translate.py -model onmt/tests/test_model2.pt -src data/morph/src.valid -verbose -batch_size 10 -beam_size 10 -tgt data/morph/tgt.valid -out /tmp/trans; diff data/morph/tgt.valid /tmp/trans
# test nmt translation with random sampling
Expand Down
209 changes: 163 additions & 46 deletions onmt/inputters/inputter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
import codecs
import math
import random

from collections import Counter, defaultdict
from itertools import chain, cycle
Expand Down Expand Up @@ -308,6 +309,42 @@ def _build_fv_from_multifield(multifield, counters, build_fv_args,
logger.info(" * %s vocab size: %d." % (name, len(field.vocab)))


def _build_fields_vocab(fields, counters, data_type, share_vocab,
                        vocab_size_multiple,
                        src_vocab_size, src_words_min_frequency,
                        tgt_vocab_size, tgt_words_min_frequency):
    """Build the vocab of each field in ``fields`` from token ``counters``.

    The tgt-side vocab is always built; the src-side vocab is built only
    for ``data_type == 'text'`` (img/audio sources have no token vocab).
    When ``share_vocab`` is set, the src and tgt base-field vocabs are
    merged into a single shared vocabulary.

    Args:
        fields: dict of (possibly multi-) fields, keyed by side name.
        counters: dict of ``Counter`` objects with token frequencies.
        data_type: one of 'text', 'img', 'audio'.
        share_vocab: whether to merge src and tgt vocabularies.
        vocab_size_multiple: pad vocab sizes to a multiple of this.
        src_vocab_size / tgt_vocab_size: max vocab sizes per side.
        src_words_min_frequency / tgt_words_min_frequency: frequency cutoffs.

    Returns:
        The same ``fields`` dict, with vocabs attached.
    """
    fv_kwargs = defaultdict(dict)
    fv_kwargs["src"] = {"max_size": src_vocab_size,
                        "min_freq": src_words_min_frequency}
    fv_kwargs["tgt"] = {"max_size": tgt_vocab_size,
                        "min_freq": tgt_words_min_frequency}
    # When sharing, size-rounding is deferred to the merge step below.
    multiple = 1 if share_vocab else vocab_size_multiple
    tgt_multifield = fields["tgt"]
    _build_fv_from_multifield(
        tgt_multifield, counters, fv_kwargs, size_multiple=multiple)
    if data_type == 'text':
        src_multifield = fields["src"]
        _build_fv_from_multifield(
            src_multifield, counters, fv_kwargs, size_multiple=multiple)
        if share_vocab:
            # `tgt_vocab_size` is ignored when sharing vocabularies
            logger.info(" * merging src and tgt vocab...")
            src_field = src_multifield.base_field
            tgt_field = tgt_multifield.base_field
            _merge_field_vocabs(
                src_field, tgt_field, vocab_size=src_vocab_size,
                min_freq=src_words_min_frequency,
                vocab_size_multiple=vocab_size_multiple)
            logger.info(" * merged vocab size: %d." % len(src_field.vocab))

    return fields


def build_vocab(train_dataset_files, fields, data_type, share_vocab,
src_vocab_path, src_vocab_size, src_words_min_frequency,
tgt_vocab_path, tgt_vocab_size, tgt_words_min_frequency,
Expand Down Expand Up @@ -392,34 +429,12 @@ def build_vocab(train_dataset_files, fields, data_type, share_vocab,
del dataset
gc.collect()

build_fv_args = defaultdict(dict)
build_fv_args["src"] = dict(
max_size=src_vocab_size, min_freq=src_words_min_frequency)
build_fv_args["tgt"] = dict(
max_size=tgt_vocab_size, min_freq=tgt_words_min_frequency)
tgt_multifield = fields["tgt"]
_build_fv_from_multifield(
tgt_multifield,
counters,
build_fv_args,
size_multiple=vocab_size_multiple if not share_vocab else 1)
if data_type == 'text':
src_multifield = fields["src"]
_build_fv_from_multifield(
src_multifield,
counters,
build_fv_args,
size_multiple=vocab_size_multiple if not share_vocab else 1)
if share_vocab:
# `tgt_vocab_size` is ignored when sharing vocabularies
logger.info(" * merging src and tgt vocab...")
src_field = src_multifield.base_field
tgt_field = tgt_multifield.base_field
_merge_field_vocabs(
src_field, tgt_field, vocab_size=src_vocab_size,
min_freq=src_words_min_frequency,
vocab_size_multiple=vocab_size_multiple)
logger.info(" * merged vocab size: %d." % len(src_field.vocab))
fields = _build_fields_vocab(
fields, counters, data_type,
share_vocab, vocab_size_multiple,
src_vocab_size, src_words_min_frequency,
tgt_vocab_size, tgt_words_min_frequency)

return fields # is the return necessary?


Expand Down Expand Up @@ -497,29 +512,42 @@ def batch_size_fn(new, count, sofar):
yield minibatch


def _pool(data, batch_size, batch_size_fn, batch_size_multiple,
          sort_key, random_shuffler):
    """Yield minibatches from a stream of examples.

    Examples are read in large buckets of ``batch_size * 500``; each bucket
    is sorted by ``sort_key`` (so similarly-sized examples are batched
    together, minimizing padding), split into minibatches via
    ``batch_iter``, and the bucket's minibatches are shuffled before being
    yielded.

    Args:
        data: iterable of examples.
        batch_size: target minibatch size.
        batch_size_fn: callable measuring batch size for torchtext
            (e.g. token-based sizing), or None for example counting.
        batch_size_multiple: round batch sizes to a multiple of this.
        sort_key: callable returning an example's sort key.
        random_shuffler: callable that shuffles a list of batches; it may
            return a new list (torchtext's ``RandomShuffler``) or shuffle
            in place and return None (``random.shuffle``).
    """
    for p in torchtext.data.batch(
            data, batch_size * 500,
            batch_size_fn=batch_size_fn):
        p_batch = list(batch_iter(
            sorted(p, key=sort_key),
            batch_size,
            batch_size_fn=batch_size_fn,
            batch_size_multiple=batch_size_multiple))
        # Fix: `random_shuffler` was accepted but never called, so batch
        # order within a bucket was deterministic. Support both in-place
        # shufflers (which return None) and functional ones.
        shuffled = random_shuffler(p_batch)
        if shuffled is None:
            shuffled = p_batch
        for b in shuffled:
            yield b


class OrderedIterator(torchtext.data.Iterator):

    def __init__(self,
                 dataset,
                 batch_size,
                 batch_size_multiple=1,
                 yield_raw_example=False,
                 **kwargs):
        """Iterator that batches similarly-sized examples together.

        Args:
            dataset: the torchtext dataset to iterate over.
            batch_size: target minibatch size.
            batch_size_multiple: round batch sizes to a multiple of this
                (e.g. 8 for fp16 efficiency).
            yield_raw_example: if True, ``__iter__`` yields single raw
                ``Example`` objects instead of ``Batch`` objects.
            **kwargs: forwarded to ``torchtext.data.Iterator``.
        """
        super(OrderedIterator, self).__init__(dataset, batch_size, **kwargs)
        self.batch_size_multiple = batch_size_multiple
        self.yield_raw_example = yield_raw_example
        # NOTE(review): the parent __init__ presumably already stores
        # `dataset`; this re-assignment looks redundant — confirm.
        self.dataset = dataset

def create_batches(self):
if self.train:
def _pool(data, random_shuffler):
for p in torchtext.data.batch(data, self.batch_size * 100):
p_batch = batch_iter(
sorted(p, key=self.sort_key),
self.batch_size,
batch_size_fn=self.batch_size_fn,
batch_size_multiple=self.batch_size_multiple)
for b in random_shuffler(list(p_batch)):
yield b

self.batches = _pool(self.data(), self.random_shuffler)
self.batches = _pool(
self.data(),
self.batch_size,
self.batch_size_fn,
self.batch_size_multiple,
self.sort_key,
self.random_shuffler)
else:
self.batches = []
for b in batch_iter(
Expand All @@ -529,6 +557,85 @@ def _pool(data, random_shuffler):
batch_size_multiple=self.batch_size_multiple):
self.batches.append(sorted(b, key=self.sort_key))

    def __iter__(self):
        """
        Extended version of the definition in torchtext.data.Iterator.
        Added yield_raw_example behaviour to yield a torchtext.data.Example
        instead of a torchtext.data.Batch object.

        Loops over epochs until exhausted (or forever when ``self.repeat``),
        resuming mid-epoch when restored from a saved state via the
        ``_iterations_this_epoch`` counter.
        """
        while True:
            self.init_epoch()
            for idx, minibatch in enumerate(self.batches):
                # fast-forward if loaded from state
                if self._iterations_this_epoch > idx:
                    continue
                self.iterations += 1
                self._iterations_this_epoch += 1
                if self.sort_within_batch:
                    # NOTE: `rnn.pack_padded_sequence` requires that a
                    # minibatch be sorted by decreasing order, which
                    # requires reversing relative to typical sort keys
                    if self.sort:
                        minibatch.reverse()
                    else:
                        minibatch.sort(key=self.sort_key, reverse=True)
                if self.yield_raw_example:
                    # Yield the bare Example; assumes batch_size is 1 in
                    # this mode (extra examples would be dropped) — TODO
                    # confirm against callers.
                    yield minibatch[0]
                else:
                    yield torchtext.data.Batch(
                        minibatch,
                        self.dataset,
                        self.device)
            if not self.repeat:
                return


class MultipleDatasetIterator(object):
    """
    This takes a list of iterable objects (DatasetLazyIter) and their
    respective weights, and yields a batch in the wanted proportions.

    Examples are drawn round-robin from the underlying iterators, with
    each iterator contributing `weight` consecutive examples per cycle;
    the pooled stream is then bucketed, sorted and batched by `_pool`.
    """
    def __init__(self,
                 iterables,
                 device,
                 opt):
        """
        Args:
            iterables: list of DatasetLazyIter, one per corpus.
            device: GPU ordinal; a negative value selects CPU.
            opt: options namespace providing data_weights, batch_size,
                batch_type and model_dtype.
        """
        self.index = -1
        self.iterators = [iter(iterable) for iterable in iterables]
        self.iterables = iterables
        self.weights = opt.data_weights
        self.batch_size = opt.batch_size
        self.batch_size_fn = max_tok_len \
            if opt.batch_type == "tokens" else None
        # Round batch sizes to multiples of 8 for fp16 tensor-core use.
        self.batch_size_multiple = 8 if opt.model_dtype == "fp16" else 1
        self.device = "cuda" if device >= 0 else "cpu"
        # Temporarily load one shard to retrieve sort_key for data_type
        temp_dataset = torch.load(self.iterables[0]._paths[0])
        self.sort_key = temp_dataset.sort_key
        del temp_dataset

    def _iter_datasets(self):
        # One pass over the weights: advance to the next iterator and
        # yield it `weight` times, so weights set the per-cycle mix.
        # `self.index` persists across calls, continuing the rotation.
        for weight in self.weights:
            self.index = (self.index + 1) % len(self.iterators)
            for i in range(weight):
                yield self.iterators[self.index]

    def _iter_examples(self):
        # Endless example stream in the weighted round-robin order.
        # NOTE(review): `next(iterator)` will raise StopIteration if an
        # underlying iterator is non-repeating (e.g. single_pass) — confirm
        # callers always construct repeating DatasetLazyIters here.
        for iterator in cycle(self._iter_datasets()):
            yield next(iterator)

    def __iter__(self):
        for minibatch in _pool(
                self._iter_examples(),
                self.batch_size,
                self.batch_size_fn,
                self.batch_size_multiple,
                self.sort_key,
                random.shuffle):
            # Decreasing-length order, as required by pack_padded_sequence.
            minibatch = sorted(minibatch, key=self.sort_key, reverse=True)
            yield torchtext.data.Batch(minibatch,
                                       self.iterables[0].dataset,
                                       self.device)


class DatasetLazyIter(object):
"""Yield data from sharded dataset files.
Expand All @@ -545,7 +652,7 @@ class DatasetLazyIter(object):

def __init__(self, dataset_paths, fields, batch_size, batch_size_fn,
batch_size_multiple, device, is_train, repeat=True,
num_batches_multiple=1):
num_batches_multiple=1, yield_raw_example=False):
self._paths = dataset_paths
self.fields = fields
self.batch_size = batch_size
Expand All @@ -555,6 +662,7 @@ def __init__(self, dataset_paths, fields, batch_size, batch_size_fn,
self.is_train = is_train
self.repeat = repeat
self.num_batches_multiple = num_batches_multiple
self.yield_raw_example = yield_raw_example

def _iter_dataset(self, path):
cur_dataset = torch.load(path)
Expand All @@ -570,9 +678,11 @@ def _iter_dataset(self, path):
train=self.is_train,
sort=False,
sort_within_batch=True,
repeat=False
repeat=False,
yield_raw_example=self.yield_raw_example
)
for batch in cur_iter:
self.dataset = cur_iter.dataset
yield batch

cur_dataset.examples = None
Expand Down Expand Up @@ -625,7 +735,7 @@ def max_tok_len(new, count, sofar):
return max(src_elements, tgt_elements)


def build_dataset_iter(corpus_type, fields, opt, is_train=True):
def build_dataset_iter(corpus_type, fields, opt, is_train=True, multi=False):
"""
This returns user-defined train/validate data iterator for the trainer
to iterate over. We implement simple ordered iterator strategy here,
Expand All @@ -635,9 +745,15 @@ def build_dataset_iter(corpus_type, fields, opt, is_train=True):
glob.glob(opt.data + '.' + corpus_type + '*.pt')))
if not dataset_paths:
return None
batch_size = opt.batch_size if is_train else opt.valid_batch_size
batch_fn = max_tok_len if is_train and opt.batch_type == "tokens" else None
batch_size_multiple = 8 if opt.model_dtype == "fp16" else 1
if multi:
batch_size = 1
batch_fn = None
batch_size_multiple = 1
else:
batch_size = opt.batch_size if is_train else opt.valid_batch_size
batch_fn = max_tok_len \
if is_train and opt.batch_type == "tokens"else None
vince62s marked this conversation as resolved.
Show resolved Hide resolved
batch_size_multiple = 8 if opt.model_dtype == "fp16" else 1

device = "cuda" if opt.gpu_ranks else "cpu"

Expand All @@ -650,4 +766,5 @@ def build_dataset_iter(corpus_type, fields, opt, is_train=True):
device,
is_train,
repeat=not opt.single_pass,
num_batches_multiple=max(opt.accum_count) * opt.world_size)
num_batches_multiple=max(opt.accum_count) * opt.world_size,
yield_raw_example=multi)
16 changes: 12 additions & 4 deletions onmt/opts.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,10 +189,12 @@ def preprocess_opts(parser):
help="Type of the source input. "
"Options are [text|img|audio].")

group.add('--train_src', '-train_src', required=True,
help="Path to the training source data")
group.add('--train_tgt', '-train_tgt', required=True,
help="Path to the training target data")
group.add('--train_src', '-train_src', required=True, nargs='+',
help="Path(s) to the training source data")
group.add('--train_tgt', '-train_tgt', required=True, nargs='+',
help="Path(s) to the training target data")
group.add('--train_ids', '-train_ids', nargs='+', default=[None],
help="ids to name training shards, used for corpus weighting")
group.add('--valid_src', '-valid_src',
help="Path to the validation source data")
group.add('--valid_tgt', '-valid_tgt',
Expand Down Expand Up @@ -308,6 +310,12 @@ def train_opts(parser):
help='Path prefix to the ".train.pt" and '
'".valid.pt" file path from preprocess.py')

group.add('--data_ids', '-data_ids', nargs='+', default=[None],
help="In case there are several corpora.")
group.add('--data_weights', '-data_weights', type=int, nargs='+',
default=[1], help="""Weights of different corpora,
should follow the same order as in -data_ids.""")

group.add('--save_model', '-save_model', default='model',
help="Model filename (the model will be saved as "
"<save_model>_N.pt where N is the number "
Expand Down
Loading