Use multiprocessing to speed up preprocessing (OpenNMT#1569)
* parallelize shard preprocessing with multiprocessing (initial WIP)
* add a maybe_load_vocab helper function in preprocess.py
* do not build a counter for the audio src side
* more detailed check for existing shards
* better logs
francoishernandez authored and vince62s committed Sep 23, 2019
1 parent 8cd68bc commit cd97a85
Showing 2 changed files with 164 additions and 98 deletions.
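In short, preprocess.py now hands each shard to a worker pool and merges the per-shard token counters back in the parent process. A minimal, self-contained sketch of that pattern follows (toy data and names, not the OpenNMT-py code itself):

from collections import Counter
from functools import partial
from multiprocessing import Pool


def process_one_shard(lowercase, shard):
    """Hypothetical per-shard job: count tokens in a list of sentences."""
    i, sentences = shard
    counter = Counter()
    for sent in sentences:
        tokens = sent.lower().split() if lowercase else sent.split()
        counter.update(tokens)
    return counter


if __name__ == "__main__":
    # Two toy "shards"; in the real script these come from split_corpus().
    shards = [(0, ["a b a", "c a"]), (1, ["b b", "a c c"])]
    merged = Counter()
    with Pool(2) as p:
        # partial() fixes the shared parameters, imap() feeds one shard per call.
        func = partial(process_one_shard, True)
        for sub_counter in p.imap(func, shards):
            merged.update(sub_counter)
    print(merged)

Pool.imap yields results in shard order while letting up to -num_threads workers run concurrently, and only the lightweight Counter objects travel back to the parent, which mirrors the process_one_shard / p.imap structure in the diff below.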
3 changes: 3 additions & 0 deletions onmt/opts.py
@@ -224,6 +224,9 @@ def preprocess_opts(parser):
"shard_size>0 means segment dataset into multiple shards, "
"each shard has shard_size samples")

group.add('--num_threads', '-num_threads', type=int, default=1,
help="Number of shards to build in parallel.")

group.add('--overwrite', '-overwrite', action="store_true",
help="Overwrite existing shards if any.")

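Usage note (illustrative, not part of the diff): with OpenNMT-py's usual preprocess flags, the new option would be passed alongside sharding, e.g.

python preprocess.py -train_src data/src-train.txt -train_tgt data/tgt-train.txt \
    -valid_src data/src-val.txt -valid_tgt data/tgt-val.txt \
    -save_data data/demo -shard_size 100000 -num_threads 4

File paths, shard size, and thread count here are placeholders.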
259 changes: 161 additions & 98 deletions preprocess.py
@@ -5,10 +5,8 @@
"""
import codecs
import glob
import sys
import gc
import torch
from functools import partial
from collections import Counter, defaultdict

from onmt.utils.logging import init_logger, logger
@@ -19,16 +17,112 @@
from onmt.inputters.inputter import _build_fields_vocab,\
_load_vocab

from functools import partial
from multiprocessing import Pool


def check_existing_pt_files(opt):
def check_existing_pt_files(opt, corpus_type, ids, existing_fields):
""" Check if there are existing .pt files to avoid overwriting them """
pattern = opt.save_data + '.{}*.pt'
for t in ['train', 'valid']:
path = pattern.format(t)
if glob.glob(path):
sys.stderr.write("Please backup existing pt files: %s, "
"to avoid overwriting them!\n" % path)
sys.exit(1)
existing_shards = []
for maybe_id in ids:
if maybe_id:
shard_base = corpus_type + "_" + maybe_id
else:
shard_base = corpus_type
pattern = opt.save_data + '.{}.*.pt'.format(shard_base)
if glob.glob(pattern):
if opt.overwrite:
maybe_overwrite = ("will be overwritten because "
"`-overwrite` option is set.")
else:
maybe_overwrite = ("won't be overwritten, pass the "
"`-overwrite` option if you want to.")
logger.warning("Shards for corpus {} already exist, {}"
.format(shard_base, maybe_overwrite))
existing_shards += [maybe_id]
return existing_shards


def process_one_shard(corpus_params, params):
corpus_type, fields, src_reader, tgt_reader, opt, existing_fields,\
src_vocab, tgt_vocab = corpus_params
i, (src_shard, tgt_shard, maybe_id, filter_pred) = params
# create one counter per shard
sub_sub_counter = defaultdict(Counter)
assert len(src_shard) == len(tgt_shard)
logger.info("Building shard %d." % i)
dataset = inputters.Dataset(
fields,
readers=([src_reader, tgt_reader]
if tgt_reader else [src_reader]),
data=([("src", src_shard), ("tgt", tgt_shard)]
if tgt_reader else [("src", src_shard)]),
dirs=([opt.src_dir, None]
if tgt_reader else [opt.src_dir]),
sort_key=inputters.str2sortkey[opt.data_type],
filter_pred=filter_pred
)
if corpus_type == "train" and existing_fields is None:
for ex in dataset.examples:
for name, field in fields.items():
if ((opt.data_type == "audio") and (name == "src")):
continue
try:
f_iter = iter(field)
except TypeError:
f_iter = [(name, field)]
all_data = [getattr(ex, name, None)]
else:
all_data = getattr(ex, name)
for (sub_n, sub_f), fd in zip(
f_iter, all_data):
has_vocab = (sub_n == 'src' and
src_vocab is not None) or \
(sub_n == 'tgt' and
tgt_vocab is not None)
if (hasattr(sub_f, 'sequential')
and sub_f.sequential and not has_vocab):
val = fd
sub_sub_counter[sub_n].update(val)
if maybe_id:
shard_base = corpus_type + "_" + maybe_id
else:
shard_base = corpus_type
data_path = "{:s}.{:s}.{:d}.pt".\
format(opt.save_data, shard_base, i)

logger.info(" * saving %sth %s data shard to %s."
% (i, shard_base, data_path))

dataset.save(data_path)

del dataset.examples
gc.collect()
del dataset
gc.collect()

return sub_sub_counter


def maybe_load_vocab(corpus_type, counters, opt):
src_vocab = None
tgt_vocab = None
existing_fields = None
if corpus_type == "train":
if opt.src_vocab != "":
try:
logger.info("Using existing vocabulary...")
existing_fields = torch.load(opt.src_vocab)
except torch.serialization.pickle.UnpicklingError:
logger.info("Building vocab from text file...")
src_vocab, src_vocab_size = _load_vocab(
opt.src_vocab, "src", counters,
opt.src_words_min_frequency)
if opt.tgt_vocab != "":
tgt_vocab, tgt_vocab_size = _load_vocab(
opt.tgt_vocab, "tgt", counters,
opt.tgt_words_min_frequency)
return src_vocab, tgt_vocab, existing_fields


def build_save_dataset(corpus_type, fields, src_reader, tgt_reader, opt):
@@ -39,97 +133,67 @@ def build_save_dataset(corpus_type, fields, src_reader, tgt_reader, opt):
srcs = opt.train_src
tgts = opt.train_tgt
ids = opt.train_ids
else:
elif corpus_type == 'valid':
counters = None
srcs = [opt.valid_src]
tgts = [opt.valid_tgt]
ids = [None]

for src, tgt, maybe_id in zip(srcs, tgts, ids):
logger.info("Reading source and target files: %s %s." % (src, tgt))

src_shards = split_corpus(src, opt.shard_size)
tgt_shards = split_corpus(tgt, opt.shard_size)
shard_pairs = zip(src_shards, tgt_shards)
dataset_paths = []
if (corpus_type == "train" or opt.filter_valid) and tgt is not None:
filter_pred = partial(
inputters.filter_example, use_src_len=opt.data_type == "text",
max_src_len=opt.src_seq_length, max_tgt_len=opt.tgt_seq_length)
else:
filter_pred = None

if corpus_type == "train":
existing_fields = None
if opt.src_vocab != "":
try:
logger.info("Using existing vocabulary...")
existing_fields = torch.load(opt.src_vocab)
except torch.serialization.pickle.UnpicklingError:
logger.info("Building vocab from text file...")
src_vocab, src_vocab_size = _load_vocab(
opt.src_vocab, "src", counters,
opt.src_words_min_frequency)
src_vocab, tgt_vocab, existing_fields = maybe_load_vocab(
corpus_type, counters, opt)

existing_shards = check_existing_pt_files(
opt, corpus_type, ids, existing_fields)

# every corpus has shards, no new one
if existing_shards == ids and not opt.overwrite:
return

def shard_iterator(srcs, tgts, ids, existing_shards,
existing_fields, corpus_type, opt):
"""
Builds a single iterator yielding every shard of every corpus.
"""
for src, tgt, maybe_id in zip(srcs, tgts, ids):
if maybe_id in existing_shards:
if opt.overwrite:
logger.warning("Overwrite shards for corpus {}"
.format(maybe_id))
else:
if corpus_type == "train":
assert existing_fields is not None,\
("A 'vocab.pt' file should be passed to "
"`-src_vocab` when adding a corpus to "
"a set of already existing shards.")
logger.warning("Ignore corpus {} because "
"shards already exist"
.format(maybe_id))
continue
if ((corpus_type == "train" or opt.filter_valid)
and tgt is not None):
filter_pred = partial(
inputters.filter_example,
use_src_len=opt.data_type == "text",
max_src_len=opt.src_seq_length,
max_tgt_len=opt.tgt_seq_length)
else:
src_vocab = None

if opt.tgt_vocab != "":
tgt_vocab, tgt_vocab_size = _load_vocab(
opt.tgt_vocab, "tgt", counters,
opt.tgt_words_min_frequency)
else:
tgt_vocab = None

for i, (src_shard, tgt_shard) in enumerate(shard_pairs):
assert len(src_shard) == len(tgt_shard)
logger.info("Building shard %d." % i)
dataset = inputters.Dataset(
fields,
readers=([src_reader, tgt_reader]
if tgt_reader else [src_reader]),
data=([("src", src_shard), ("tgt", tgt_shard)]
if tgt_reader else [("src", src_shard)]),
dirs=([opt.src_dir, None]
if tgt_reader else [opt.src_dir]),
sort_key=inputters.str2sortkey[opt.data_type],
filter_pred=filter_pred
)
if corpus_type == "train" and existing_fields is None:
for ex in dataset.examples:
for name, field in fields.items():
try:
f_iter = iter(field)
except TypeError:
f_iter = [(name, field)]
all_data = [getattr(ex, name, None)]
else:
all_data = getattr(ex, name)
for (sub_n, sub_f), fd in zip(
f_iter, all_data):
has_vocab = (sub_n == 'src' and
src_vocab is not None) or \
(sub_n == 'tgt' and
tgt_vocab is not None)
if (hasattr(sub_f, 'sequential')
and sub_f.sequential and not has_vocab):
val = fd
counters[sub_n].update(val)
if maybe_id:
shard_base = corpus_type + "_" + maybe_id
else:
shard_base = corpus_type
data_path = "{:s}.{:s}.{:d}.pt".\
format(opt.save_data, shard_base, i)
dataset_paths.append(data_path)

logger.info(" * saving %sth %s data shard to %s."
% (i, shard_base, data_path))

dataset.save(data_path)

del dataset.examples
gc.collect()
del dataset
gc.collect()
filter_pred = None
src_shards = split_corpus(src, opt.shard_size)
tgt_shards = split_corpus(tgt, opt.shard_size)
for i, (ss, ts) in enumerate(zip(src_shards, tgt_shards)):
yield (i, (ss, ts, maybe_id, filter_pred))

shard_iter = shard_iterator(srcs, tgts, ids, existing_shards,
existing_fields, corpus_type, opt)

with Pool(opt.num_threads) as p:
dataset_params = (corpus_type, fields, src_reader, tgt_reader,
opt, existing_fields, src_vocab, tgt_vocab)
func = partial(process_one_shard, dataset_params)
for sub_counter in p.imap(func, shard_iter):
if sub_counter is not None:
for key, value in sub_counter.items():
counters[key].update(value)

if corpus_type == "train":
vocab_path = opt.save_data + '.vocab.pt'
@@ -169,10 +233,9 @@ def count_features(path):
def main(opt):
ArgumentParser.validate_preprocess_args(opt)
torch.manual_seed(opt.seed)
if not(opt.overwrite):
check_existing_pt_files(opt)

init_logger(opt.log_file)

logger.info("Extracting features...")

src_nfeats = 0
