Revert "Revert "Fix preprocess sharding and implement train side lazy…
Browse files Browse the repository at this point in the history
… dataset load""
srush authored Jan 14, 2018
1 parent 6d54f31 commit e72c92b
Showing 8 changed files with 221 additions and 122 deletions.
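Taken together, the restored change replaces building every dataset in memory with a shard-based flow: preprocess.py writes each shard to its own .pt file and returns the file paths, build_vocab streams over those files to count tokens, and the trainer loads one shard at a time. A minimal sketch of that flow in plain Python (save_shards, build_counter, the pickle files, and the naming scheme are illustrative stand-ins, not the project's API):

import os
import pickle
from collections import Counter

def save_shards(examples, shard_size, prefix):
    # Write examples to numbered shard files and return their paths.
    paths = []
    for index, start in enumerate(range(0, len(examples), shard_size), 1):
        path = "%s.train.%d.pt" % (prefix, index)
        with open(path, "wb") as f:
            pickle.dump(examples[start:start + shard_size], f)
        paths.append(path)
    return paths

def build_counter(shard_paths):
    # Stream over shard files so only one shard is in memory at a time.
    counter = Counter()
    for path in shard_paths:
        with open(path, "rb") as f:
            for tokens in pickle.load(f):
                counter.update(tokens)
    return counter

if __name__ == "__main__":
    paths = save_shards([["a", "b"], ["b", "c"], ["c", "d"]], shard_size=2, prefix="demo")
    print(build_counter(paths))   # e.g. Counter({'b': 2, 'c': 2, 'a': 1, 'd': 1})
    for p in paths:
        os.remove(p)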
6 changes: 6 additions & 0 deletions onmt/Trainer.py
@@ -199,6 +199,9 @@ def gradient_accumulation(truebatch_, total_stats_,
self.optim.step()

for i, batch_ in enumerate(self.train_iter):
cur_dataset = self.train_iter.get_cur_dataset()
self.train_loss.cur_dataset = cur_dataset

truebatch.append(batch_)
accum += 1
if self.normalization is "tokens":
@@ -243,6 +246,9 @@ def validate(self):
stats = Statistics()

for batch in self.valid_iter:
cur_dataset = self.valid_iter.get_cur_dataset()
self.valid_loss.cur_dataset = cur_dataset

src = onmt.io.make_features(batch, 'src', self.data_type)
if self.data_type == 'text':
_, src_lengths = batch.src
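Both loops now ask the iterator which dataset produced the current batch and hand it to the loss object, so the loss can reach that shard's src_vocabs without holding a dataset reference of its own. The iterator that exposes get_cur_dataset() is not part of the hunks shown here (presumably it lives in the remaining changed file); a rough sketch of what such a wrapper could look like, with LazyShardIterator and make_batches as assumptions rather than the project's actual names:

import torch

class LazyShardIterator(object):
    # Hypothetical train-side lazy loader: iterate batch by batch while
    # keeping only one shard's dataset in memory at a time.
    def __init__(self, shard_paths, make_batches):
        self.shard_paths = shard_paths
        self.make_batches = make_batches   # turns a loaded dataset into batches
        self.cur_dataset = None

    def get_cur_dataset(self):
        # Called by the trainer right before computing the loss for a batch.
        return self.cur_dataset

    def __iter__(self):
        for path in self.shard_paths:
            self.cur_dataset = torch.load(path)
            for batch in self.make_batches(self.cur_dataset):
                yield batch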
18 changes: 0 additions & 18 deletions onmt/io/DatasetBase.py
@@ -46,24 +46,6 @@ def load_fields(self, vocab_dict):
self.fields = dict([(k, f) for (k, f) in fields.items()
if k in self.examples[0].__dict__])

def collapse_copy_scores(self, scores, batch, tgt_vocab):
"""
Given scores from an expanded dictionary
corresponding to a batch, sums together copies,
with a dictionary word when it is ambiguous.
"""
offset = len(tgt_vocab)
for b in range(batch.batch_size):
index = batch.indices.data[b]
src_vocab = self.src_vocabs[index]
for i in range(1, len(src_vocab)):
sw = src_vocab.itos[i]
ti = tgt_vocab.stoi[sw]
if ti != 0:
scores[:, b, ti] += scores[:, b, offset + i]
scores[:, b, offset + i].fill_(1e-20)
return scores

@staticmethod
def coalesce_datasets(datasets):
"""Coalesce all dataset instances. """
62 changes: 48 additions & 14 deletions onmt/io/IO.py
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

from collections import Counter, defaultdict
from collections import Counter, defaultdict, OrderedDict
from itertools import count

import torch
@@ -216,12 +216,21 @@ def build_dataset(fields, data_type, src_path, tgt_path, src_dir=None,
return dataset


def build_vocab(train_datasets, data_type, share_vocab,
def _build_field_vocab(field, counter, **kwargs):
specials = list(OrderedDict.fromkeys(
tok for tok in [field.unk_token, field.pad_token, field.init_token,
field.eos_token]
if tok is not None))
field.vocab = field.vocab_cls(counter, specials=specials, **kwargs)


def build_vocab(train_dataset_files, fields, data_type, share_vocab,
src_vocab_size, src_words_min_frequency,
tgt_vocab_size, tgt_words_min_frequency):
"""
Args:
train_datasets: a list of train dataset.
train_dataset_files: a list of train dataset pt file.
fields (dict): fields to build vocab for.
data_type: "text", "img" or "audio"?
share_vocab(bool): share source and target vocabulary?
src_vocab_size(int): size of the source vocabulary.
@@ -234,23 +243,48 @@ def build_vocab(train_datasets, data_type, share_vocab,
Returns:
Dict of Fields
"""
# All datasets have same fields, get the first one is OK.
fields = train_datasets[0].fields

fields["tgt"].build_vocab(*train_datasets, max_size=tgt_vocab_size,
min_freq=tgt_words_min_frequency)
for j in range(train_datasets[0].n_tgt_feats):
fields["tgt_feat_" + str(j)].build_vocab(*train_datasets)
counter = {}
for k in fields:
counter[k] = Counter()

for path in train_dataset_files:
dataset = torch.load(path)
for ex in dataset.examples:
for k in fields:
val = getattr(ex, k, None)
if val is not None and not fields[k].sequential:
val = [val]
counter[k].update(val)

_build_field_vocab(fields["tgt"], counter["tgt"],
max_size=tgt_vocab_size,
min_freq=tgt_words_min_frequency)
print(" * tgt vocab size: %d." % len(fields["tgt"].vocab))

# All datasets have the same number of tgt features;
# getting the last one is OK.
for j in range(dataset.n_tgt_feats):
key = "tgt_feat_" + str(j)
_build_field_vocab(fields[key], counter[key])
print(" * %s vocab size: %d." % (key, len(fields[key].vocab)))

if data_type == 'text':
fields["src"].build_vocab(*train_datasets, max_size=src_vocab_size,
min_freq=src_words_min_frequency)
for j in range(train_datasets[0].n_src_feats):
fields["src_feat_" + str(j)].build_vocab(*train_datasets)
_build_field_vocab(fields["src"], counter["src"],
max_size=src_vocab_size,
min_freq=src_words_min_frequency)
print(" * src vocab size: %d." % len(fields["src"].vocab))

# All datasets have the same number of src features;
# getting the last one is OK.
for j in range(dataset.n_src_feats):
key = "src_feat_" + str(j)
_build_field_vocab(fields[key], counter[key])
print(" * %s vocab size: %d." % (key, len(fields[key].vocab)))

# Merge the input and output vocabularies.
if share_vocab:
# `tgt_vocab_size` is ignored when sharing vocabularies
print(" * merging src and tgt vocab...")
merged_vocab = merge_vocabs(
[fields["src"].vocab, fields["tgt"].vocab],
vocab_size=src_vocab_size)
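Instead of calling field.build_vocab(*train_datasets) on fully loaded datasets, build_vocab now keeps one collections.Counter per field, updates it while loading one shard .pt file at a time, and only then turns each counter into a vocabulary via _build_field_vocab. The counting pattern in isolation (Example and FIELDS below are toy stand-ins for onmt examples and torchtext fields):

from collections import Counter, namedtuple

Example = namedtuple("Example", ["src", "tgt", "indices"])
FIELDS = {"src": True, "tgt": True, "indices": False}   # field name -> sequential?

def count_tokens(shards):
    # Accumulate per-field counters; each shard would be torch.load(path).
    counters = {name: Counter() for name in FIELDS}
    for shard in shards:
        for ex in shard:
            for name, sequential in FIELDS.items():
                val = getattr(ex, name, None)
                if val is None:
                    continue
                if not sequential:        # wrap scalars so Counter.update works
                    val = [val]
                counters[name].update(val)
    return counters

shards = [
    [Example(["a", "b"], ["x"], 0), Example(["b"], ["y", "x"], 1)],
    [Example(["c", "a"], ["x"], 2)],
]
print(count_tokens(shards)["src"])   # Counter({'a': 2, 'b': 2, 'c': 1})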
21 changes: 20 additions & 1 deletion onmt/io/TextDataset.py
@@ -88,6 +88,25 @@ def sort_key(self, ex):
""" Sort using length of source sentences. """
return len(ex.src)

@staticmethod
def collapse_copy_scores(scores, batch, tgt_vocab, src_vocabs):
"""
Given scores from an expanded dictionary
corresponding to a batch, sums together copies,
with a dictionary word when it is ambiguous.
"""
offset = len(tgt_vocab)
for b in range(batch.batch_size):
index = batch.indices.data[b]
src_vocab = src_vocabs[index]
for i in range(1, len(src_vocab)):
sw = src_vocab.itos[i]
ti = tgt_vocab.stoi[sw]
if ti != 0:
scores[:, b, ti] += scores[:, b, offset + i]
scores[:, b, offset + i].fill_(1e-20)
return scores

@staticmethod
def make_text_examples_nfeats_tpl(path, truncate, side):
"""
@@ -228,7 +247,7 @@ def get_num_features(corpus_file, side):

return num_feats

# Below are helper functions for intra-class use only.self.
# Below are helper functions for intra-class use only.
def _dynamic_dict(self, examples_iter):
for example in examples_iter:
src = example["src"]
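The copy mechanism scores each example's source words in extra columns appended after the shared target vocabulary, at positions offset + i with offset = len(tgt_vocab). collapse_copy_scores folds the copy score of a source word back onto its target-vocabulary entry whenever the word also exists there (ti != 0, i.e. it is not <unk>) and flattens the duplicate column to a negligible value. A small self-contained illustration of that folding, with made-up vocabularies and a single prediction step:

import torch

tgt_vocab = {"<unk>": 0, "the": 1, "cat": 2}   # shared target vocab (stoi)
src_itos = ["<unk>", "cat", "zygote"]          # this example's source vocab (itos)

# Columns: [target vocab | copied source words] -> 3 + 3 scores for one step.
scores = torch.tensor([[0.10, 0.20, 0.30, 0.00, 0.25, 0.15]])
offset = len(tgt_vocab)

for i in range(1, len(src_itos)):              # skip the source-side <unk>
    ti = tgt_vocab.get(src_itos[i], 0)
    if ti != 0:                                # word also exists in the target vocab
        scores[:, ti] += scores[:, offset + i]
        scores[:, offset + i].fill_(1e-20)

print(scores)   # "cat" now holds 0.30 + 0.25; "zygote" keeps its copy-only 0.15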
10 changes: 6 additions & 4 deletions onmt/modules/CopyGenerator.py
@@ -135,12 +135,14 @@ class CopyGeneratorLossCompute(onmt.Loss.LossComputeBase):
"""
Copy Generator Loss Computation.
"""
def __init__(self, generator, tgt_vocab, dataset,
def __init__(self, generator, tgt_vocab,
force_copy, eps=1e-20):
super(CopyGeneratorLossCompute, self).__init__(
generator, tgt_vocab)

self.dataset = dataset
# We lazily load datasets when there are more than one, so postpone
# the setting of cur_dataset.
self.cur_dataset = None
self.force_copy = force_copy
self.criterion = CopyGeneratorCriterion(len(tgt_vocab), force_copy,
self.padding_idx)
@@ -177,9 +179,9 @@ def _compute_loss(self, batch, output, target, copy_attn, align):
loss = self.criterion(scores, align, target)

scores_data = scores.data.clone()
scores_data = self.dataset.collapse_copy_scores(
scores_data = onmt.io.TextDataset.collapse_copy_scores(
self._unbottle(scores_data, batch.batch_size),
batch, self.tgt_vocab)
batch, self.tgt_vocab, self.cur_dataset.src_vocabs)
scores_data = self._bottle(scores_data)

# Correct target copy token instead of <unk>
35 changes: 19 additions & 16 deletions preprocess.py
@@ -43,7 +43,7 @@ def parse_args():
def build_save_text_dataset_in_shards(src_corpus, tgt_corpus, fields,
corpus_type, opt):
'''
Divide the big corpus into shards, and build dataset seperately.
Divide the big corpus into shards, and build dataset separately.
This is currently only for data_type=='text'.
The reason we do this is to avoid taking up too much memory due
@@ -80,6 +80,9 @@ def build_save_text_dataset_in_shards(src_corpus, tgt_corpus, fields,
"tgt", opt.max_shard_size,
assoc_iter=src_iter)

print(' * divide corpus into shards and build dataset separately'
'(shard_size = %d bytes).' % opt.max_shard_size)

index = 0
while not src_iter.hit_end():
index += 1
@@ -91,18 +94,14 @@ def build_save_text_dataset_in_shards(src_corpus, tgt_corpus, fields,
dynamic_dict=opt.dynamic_dict)

# We save fields in vocab.pt separately, so make it empty.
saved_fields = dataset.fields
dataset.fields = []

pt_file = "{:s}.{:s}.{:d}.pt".format(
opt.save_data, corpus_type, index)
print(" * saving train data shard to %s." % pt_file)
torch.save(dataset, pt_file)

dataset.fields = saved_fields
ret_list.append(dataset)

if index == 1:
# Only one shard, strip the index in the filename.
os.rename(pt_file, opt.save_data + '.' + corpus_type + '.pt')
ret_list.append(pt_file)

return ret_list

@@ -141,17 +140,17 @@ def build_save_dataset(corpus_type, fields, opt):
window=opt.window)

# We save fields in vocab.pt separately, so make it empty.
saved_fields = dataset.fields
dataset.fields = []

pt_file = "{:s}.{:s}.pt".format(opt.save_data, corpus_type)
print(" * saving train dataset to %s." % pt_file)
torch.save(dataset, pt_file)
dataset.fields = saved_fields

return [dataset]
return [pt_file]


def build_save_vocab(train_dataset, opt):
fields = onmt.io.build_vocab(train_dataset, opt.data_type,
def build_save_vocab(train_dataset, fields, opt):
fields = onmt.io.build_vocab(train_dataset, fields, opt.data_type,
opt.share_vocab,
opt.src_vocab_size,
opt.src_words_min_frequency,
@@ -166,16 +165,20 @@ def build_save_vocab(train_dataset, opt):
def main():
opt = parse_args()

print('Preparing for training ...')
print("Extracting features...")
src_nfeats = onmt.io.get_num_features(opt.data_type, opt.train_src, 'src')
tgt_nfeats = onmt.io.get_num_features(opt.data_type, opt.train_tgt, 'tgt')
print(" * number of source features: %d." % src_nfeats)
print(" * number of target features: %d." % tgt_nfeats)

print("Loading Fields object...")
fields = onmt.io.get_fields(opt.data_type, src_nfeats, tgt_nfeats)

print("Building & saving training data...")
train_datasets = build_save_dataset('train', fields, opt)
train_dataset_files = build_save_dataset('train', fields, opt)

print("Building & saving vocabulary...")
build_save_vocab(train_datasets, opt)
build_save_vocab(train_dataset_files, fields, opt)

print("Building & saving validation data...")
build_save_dataset('valid', fields, opt)
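Because the shards are saved with dataset.fields emptied (the field objects and their vocabularies live only in vocab.pt), these functions now hand back file paths, and build_save_vocab takes the fields dict explicitly instead of fishing it out of a dataset. On the training side a shard only becomes usable once those fields are re-attached; a rough sketch of that counterpart, assuming preprocess saved the vocabularies to <save_data>.vocab.pt:

import torch

def load_shard(pt_path, vocab_path):
    # Load a stripped shard and re-attach the fields saved by preprocess.
    dataset = torch.load(pt_path)        # examples only; dataset.fields is empty
    vocab = torch.load(vocab_path)       # field vocabularies from vocab.pt
    dataset.load_fields(vocab)           # DatasetBase.load_fields(vocab_dict)
    return dataset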
4 changes: 2 additions & 2 deletions test/test_preprocess.py
@@ -39,9 +39,9 @@ def __init__(self, *args, **kwargs):
def dataset_build(self, opt):
fields = onmt.io.get_fields("text", 0, 0)

trains = preprocess.build_save_dataset('train', fields, opt)
train_data_files = preprocess.build_save_dataset('train', fields, opt)

preprocess.build_save_vocab(trains, opt)
preprocess.build_save_vocab(train_data_files, fields, opt)

preprocess.build_save_dataset('valid', fields, opt)
