Multilevel text field (#1216)
* Move batching and field logic from inputter to dsets.

* Remove _feature_tokenize from inputter.

* Remove src_lengths from audio.

* Make lengths a torch.int instead of self.dtype in AudioSeqField.

* Don't output src_lengths in AudioDataset.make_examples since they're no longer necessary.

* Add temp fix for checking if data is text in make_features.

* First pass at a multi-level field design.

* Remove some unused code.

* Remove batch.src_is_text attr.

* Clean up by adding __iter__ to MultiField.

* Fix extract embeddings.

* Clean build_vocab (including bad indentation level), incorporate #1199, rename TextMultiField attrs.

* Remove make_features.

* Update semantics in direct calls to batch.src, batch.tgt (see the tensor-layout sketch after this list).

* Test for old-style text fields while checking for old-style vocab.
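
For orientation, a minimal sketch of the tensor layout that batch.src and batch.tgt carry for text after this change, written with plain PyTorch. The shapes follow the "seq_len x batch_size x len(self.fields)" comment in TextMultiField.process below; the names and sizes are illustrative and not part of the diff.

import torch

# Stand-ins for one batch: the base (word) level plus one token-feature level,
# stacked along a trailing dimension (previously concatenated by make_features).
seq_len, batch_size = 5, 3
word_ids = torch.randint(0, 100, (seq_len, batch_size))
feat_ids = torch.randint(0, 10, (seq_len, batch_size))

data = torch.stack([word_ids, feat_ids], 2)    # seq_len x batch_size x n_fields
lengths = torch.tensor([5, 4, 2])              # returned when include_lengths

assert data.shape == (seq_len, batch_size, 2)
assert torch.equal(data[:, :, 0], word_ids)    # word level
assert torch.equal(data[:, :, 1], feat_ids)    # feature level
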
flauted authored and vince62s committed Jan 25, 2019
1 parent 5d6f23b commit 93930ab
Showing 13 changed files with 182 additions and 118 deletions.
5 changes: 2 additions & 3 deletions onmt/inputters/__init__.py
@@ -3,7 +3,7 @@
Inputters implement the logic of transforming raw data to vectorized inputs,
e.g., from a line of text to a sequence of embeddings.
"""
from onmt.inputters.inputter import make_features, \
from onmt.inputters.inputter import \
load_old_vocab, get_fields, OrderedIterator, \
build_dataset, build_vocab, old_style_vocab
from onmt.inputters.dataset_base import DatasetBase
@@ -12,8 +12,7 @@
from onmt.inputters.audio_dataset import AudioDataset


__all__ = ['DatasetBase', 'make_features',
'load_old_vocab', 'get_fields',
__all__ = ['DatasetBase', 'load_old_vocab', 'get_fields',
'build_dataset', 'old_style_vocab',
'build_vocab', 'OrderedIterator',
'TextDataset', 'ImageDataset', 'AudioDataset']
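
As a quick sanity check of the new public surface, a hedged sketch assuming an environment where OpenNMT-py with this commit is installed (not part of the diff):

import onmt.inputters as inputters

# make_features is gone; feature stacking now happens inside the fields.
assert not hasattr(inputters, "make_features")
assert "load_old_vocab" in inputters.__all__
assert "get_fields" in inputters.__all__
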
3 changes: 2 additions & 1 deletion onmt/inputters/dataset_base.py
@@ -74,8 +74,9 @@ def __init__(self, fields, src_examples_iter, tgt_examples_iter,
if dynamic_dict:
src_field = fields['src'][0][1]
tgt_field = fields['tgt'][0][1]
# this assumes src_field and tgt_field are both text
src_vocab, ex_dict = self._dynamic_dict(
ex_dict, src_field, tgt_field)
ex_dict, src_field.base_field, tgt_field.base_field)
self.src_vocabs.append(src_vocab)
ex_fields = {k: v for k, v in fields.items() if k in ex_dict}
ex = Example.fromdict(ex_dict, ex_fields)
133 changes: 75 additions & 58 deletions onmt/inputters/inputter.py
@@ -12,7 +12,8 @@
from torchtext.data import Field
from torchtext.vocab import Vocab

from onmt.inputters.text_dataset import TextDataset, text_fields
from onmt.inputters.text_dataset import TextDataset, text_fields,\
TextMultiField
from onmt.inputters.image_dataset import ImageDataset, image_fields
from onmt.inputters.audio_dataset import AudioDataset, audio_fields
from onmt.utils.logging import logger
@@ -126,6 +127,15 @@ def load_old_vocab(vocab, data_type="text", dynamic_dict=False):
returns: a dictionary whose keys are the field names and whose values
are lists of (name, Field) pairs
"""
if _old_style_field_list(vocab): # upgrade to multifield
fields = vocab
for base_name, vals in fields.items():
if ((base_name == 'src' and data_type == 'text') or
base_name == 'tgt'):
assert not isinstance(vals[0][1], TextMultiField)
fields[base_name] = [(base_name, TextMultiField(
vals[0][0], vals[0][1], vals[1:]))]
return fields
vocab = dict(vocab)
n_src_features = sum('src_feat_' in k for k in vocab)
n_tgt_features = sum('tgt_feat_' in k for k in vocab)
@@ -134,12 +144,17 @@ def load_old_vocab(vocab, data_type="text", dynamic_dict=False):
)
for k, vals in fields.items():
for n, f in vals:
if n in vocab:
f.vocab = vocab[n]
try:
f_iter = iter(f)
except TypeError:
f_iter = [(n, f)]
for sub_n, sub_f in f_iter:
if sub_n in vocab:
sub_f.vocab = vocab[sub_n]
return fields


def old_style_vocab(vocab):
def _old_style_vocab(vocab):
"""
vocab: some object loaded from a *.vocab.pt file
returns: whether the object is a list of pairs where the second object
@@ -153,29 +168,14 @@ def old_style_vocab(vocab):
any(isinstance(v[1], Vocab) for v in vocab)


def make_features(batch, side):
"""
batch: a batch object
side: 'src' or 'tgt'
returns the tensor with features concatenated, and the lengths (if present)
or None.
"""
assert side in ['src', 'tgt']
if isinstance(batch.__dict__[side], tuple):
data, lengths = batch.__dict__[side]
else:
data = batch.__dict__[side]
lengths = None
def _old_style_field_list(vocab):
# if tgt isn't using TextMultiField, then no text field is.
return not _old_style_vocab(vocab) and not isinstance(
vocab['tgt'][0][1], TextMultiField)

if batch.src_is_text or side == 'tgt': # this is temporary, see #1196
# cat together layers, producing a 3d output tensor for src text
# and for tgt (which is assumed to be text)
feat_start = side + "_feat_"
feat_names = sorted(k for k in batch.__dict__ if feat_start in k)
levels = [data] + [batch.__dict__[k] for k in feat_names]
data = torch.cat([level.unsqueeze(2) for level in levels], 2)

return data, lengths
def old_style_vocab(vocab):
return _old_style_vocab(vocab) or _old_style_field_list(vocab)


def filter_example(ex, use_src_len=True, use_tgt_len=True,
@@ -250,6 +250,24 @@ def _build_field_vocab(field, counter, **kwargs):
field.vocab = field.vocab_cls(counter, specials=specials, **kwargs)


def _load_vocab(vocab_path, name, counters):
# counters changes in place
vocab = _read_vocab_file(vocab_path, name)
vocab_size = len(vocab)
logger.info('Loaded %s vocab has %d tokens.' % (name, vocab_size))
for i, token in enumerate(vocab):
# keep the order of tokens specified in the vocab file by
# adding them to the counter with decreasing counting values
counters[name][token] = vocab_size - i
return vocab, vocab_size


def _build_fv_from_multifield(multifield, counters, build_fv_args):
for name, field in multifield:
_build_field_vocab(field, counters[name], **build_fv_args[name])
logger.info(" * %s vocab size: %d." % (name, len(field.vocab)))


def build_vocab(train_dataset_files, fields, data_type, share_vocab,
src_vocab_path, src_vocab_size, src_words_min_frequency,
tgt_vocab_path, tgt_vocab_size, tgt_words_min_frequency):
@@ -271,26 +289,18 @@ def build_vocab(train_dataset_files, fields, data_type, share_vocab,
Returns:
Dict of Fields
"""
counters = {k: Counter() for k, v in chain.from_iterable(fields.values())}
counters = defaultdict(Counter)

# Load vocabulary
if src_vocab_path:
src_vocab = _read_vocab_file(src_vocab_path, "src")
src_vocab_size = len(src_vocab)
logger.info('Loaded source vocab has %d tokens.' % src_vocab_size)
for i, token in enumerate(src_vocab):
# keep the order of tokens specified in the vocab file by
# adding them to the counter with decreasing counting values
counters['src'][token] = src_vocab_size - i
src_vocab, src_vocab_size = _load_vocab(
src_vocab_path, "src", counters)
else:
src_vocab = None

if tgt_vocab_path:
tgt_vocab = _read_vocab_file(tgt_vocab_path, "tgt")
tgt_vocab_size = len(tgt_vocab)
logger.info('Loaded source vocab has %d tokens.' % tgt_vocab_size)
for i, token in enumerate(tgt_vocab):
counters['tgt'][token] = tgt_vocab_size - i
tgt_vocab, tgt_vocab_size = _load_vocab(
tgt_vocab_path, "tgt", counters)
else:
tgt_vocab = None

@@ -299,11 +309,20 @@ def build_vocab(train_dataset_files, fields, data_type, share_vocab,
logger.info(" * reloading %s." % path)
for ex in dataset.examples:
for name, field in chain.from_iterable(fields.values()):
has_vocab = (name == 'src' and src_vocab) or \
(name == 'tgt' and tgt_vocab)
if field.sequential and not has_vocab:
val = getattr(ex, name, None)
counters[name].update(val)
try:
f_iter = iter(field)
except TypeError:
f_iter = [(name, field)]
all_data = [getattr(ex, name, None)]
else:
all_data = getattr(ex, name)
for (sub_n, sub_f), fd in zip(
f_iter, all_data):
has_vocab = (sub_n == 'src' and src_vocab) or \
(sub_n == 'tgt' and tgt_vocab)
if sub_f.sequential and not has_vocab:
val = fd
counters[sub_n].update(val)

# Drop the none-using from memory but keep the last
if i < len(train_dataset_files) - 1:
@@ -314,23 +333,27 @@ def build_vocab(train_dataset_files, fields, data_type, share_vocab,
del dataset
gc.collect()

for name, field in fields["tgt"]:
_build_field_vocab(field, counters[name])
logger.info(" * %s vocab size: %d." % (name, len(field.vocab)))
build_fv_args = defaultdict(dict)
build_fv_args["src"] = dict(
max_size=src_vocab_size, min_freq=src_words_min_frequency)
build_fv_args["tgt"] = dict(
max_size=tgt_vocab_size, min_freq=tgt_words_min_frequency)
assert len(fields["tgt"]) == 1
tgt_multifield = fields["tgt"][0][1]
_build_fv_from_multifield(tgt_multifield, counters, build_fv_args)
if data_type == 'text':
for name, field in fields["src"]:
_build_field_vocab(field, counters[name])
logger.info(" * %s vocab size: %d." % (name, len(field.vocab)))
assert len(fields["src"]) == 1
src_multifield = fields["src"][0][1]
_build_fv_from_multifield(src_multifield, counters, build_fv_args)
if share_vocab:
# `tgt_vocab_size` is ignored when sharing vocabularies
logger.info(" * merging src and tgt vocab...")
src_field = fields['src'][0][1]
tgt_field = fields['tgt'][0][1]
src_field = src_multifield.base_field
tgt_field = tgt_multifield.base_field
_merge_field_vocabs(
src_field, tgt_field, vocab_size=src_vocab_size,
min_freq=src_words_min_frequency)
logger.info(" * merged vocab size: %d." % len(src_field.vocab))

return fields # is the return necessary?
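
An aside on the build_fv_args pattern above: because it is a defaultdict(dict), feature sub-fields, which never get explicit entries, fall back to empty kwargs and therefore to torchtext's default vocab limits. A toy illustration with made-up sizes (not part of the diff):

from collections import defaultdict

build_fv_args = defaultdict(dict)
build_fv_args["src"] = dict(max_size=50000, min_freq=1)
build_fv_args["tgt"] = dict(max_size=50000, min_freq=1)

# _build_fv_from_multifield looks kwargs up by sub-field name, so a feature
# field such as "src_feat_0" simply gets {} (no size cap, default min_freq).
print(build_fv_args["src_feat_0"])   # -> {}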


@@ -388,12 +411,6 @@ def _pool(data, random_shuffler):
self.batch_size_fn):
self.batches.append(sorted(b, key=self.sort_key))

def __iter__(self):
# temporary fix: See #1196
for batch in super(OrderedIterator, self).__iter__():
batch.src_is_text = isinstance(self.dataset, TextDataset)
yield batch


class DatasetLazyIter(object):
"""
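
A runnable sketch of the sub-field iteration used in load_old_vocab and build_vocab above, assuming the legacy torchtext Field API and that onmt.inputters.text_dataset is importable as in this commit; the field names are illustrative. iter() succeeds on a TextMultiField (its __getitem__ yields (name, Field) pairs) and raises TypeError on a plain Field, which is then wrapped as a one-element list.

from torchtext.data import Field

from onmt.inputters.text_dataset import TextMultiField

fields = {
    "src": [("src", TextMultiField(
        "src", Field(), [("src_feat_0", Field())]))],
    "indices": [("indices", Field(use_vocab=False, sequential=False))],
}
for name, vals in fields.items():
    for n, f in vals:
        try:
            f_iter = iter(f)       # TextMultiField: iterate its sub-fields
        except TypeError:
            f_iter = [(n, f)]      # plain Field: treat it as a singleton
        print(name, [sub_n for sub_n, _ in f_iter])
# src ['src', 'src_feat_0']
# indices ['indices']
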
47 changes: 43 additions & 4 deletions onmt/inputters/text_dataset.py
@@ -1,7 +1,8 @@
# -*- coding: utf-8 -*-
from functools import partial

from torchtext.data import Field
import torch
from torchtext.data import Field, RawField

from onmt.inputters.dataset_base import DatasetBase

Expand All @@ -24,8 +25,8 @@ class TextDataset(DatasetBase):
@staticmethod
def sort_key(ex):
if hasattr(ex, "tgt"):
return len(ex.src), len(ex.tgt)
return len(ex.src)
return len(ex.src[0]), len(ex.tgt[0])
return len(ex.src[0])

@classmethod
def make_examples(cls, sequences, side):
@@ -57,6 +58,42 @@ def _feature_tokenize(
return tokens


class TextMultiField(RawField):
def __init__(self, base_name, base_field, feats_fields):
super(TextMultiField, self).__init__()
self.fields = [(base_name, base_field)]
for name, ff in sorted(feats_fields, key=lambda kv: kv[0]):
self.fields.append((name, ff))

@property
def base_field(self):
return self.fields[0][1]

def process(self, batch, device=None):
# batch (list(list(list))): batch_size x len(self.fields) x seq_len
batch_by_feat = list(zip(*batch))
base_data = self.base_field.process(batch_by_feat[0], device=device)
if self.base_field.include_lengths:
# lengths: batch_size
base_data, lengths = base_data

feats = [ff.process(batch_by_feat[i], device=device)
for i, (_, ff) in enumerate(self.fields[1:], 1)]
levels = [base_data] + feats
# data: seq_len x batch_size x len(self.fields)
data = torch.stack(levels, 2)
if self.base_field.include_lengths:
return data, lengths
else:
return data

def preprocess(self, x):
return [f.preprocess(x) for _, f in self.fields]

def __getitem__(self, item):
return self.fields[item]


def text_fields(base_name, **kwargs):
"""Create text fields.
Args:
@@ -90,4 +127,6 @@ def text_fields(base_name, **kwargs):
pad_token=pad, tokenize=tokenize,
include_lengths=use_len)
fields_.append((name, feat))
return fields_
assert fields_[0][0] == base_name # sanity check
field = TextMultiField(fields_[0][0], fields_[0][1], fields_[1:])
return [(base_name, field)]
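
To make the new field concrete, a small usage sketch, assuming the legacy torchtext Field API imported above; the data is a toy, pre-tokenized batch and is not part of the diff.

from torchtext.data import Field

from onmt.inputters.text_dataset import TextMultiField

word = Field(pad_token="<blank>", include_lengths=True)   # base (word) level
feat = Field(pad_token="<blank>")                         # one token feature
multi = TextMultiField("src", word, [("src_feat_0", feat)])

# Two toy examples, each already preprocessed into [word tokens, feat tokens].
examples = [[["a", "b", "c"], ["X", "Y", "Z"]],
            [["d", "e"], ["X", "X"]]]
word.build_vocab([ex[0] for ex in examples])
feat.build_vocab([ex[1] for ex in examples])

data, lengths = multi.process(examples)
print(data.shape)    # torch.Size([3, 2, 2]): seq_len x batch_size x n_fields
print(lengths)       # tensor([3, 2])
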
32 changes: 19 additions & 13 deletions onmt/model_builder.py
@@ -25,21 +25,20 @@
from onmt.utils.logging import logger


def build_embeddings(opt, word_field, feat_fields, for_encoder=True):
def build_embeddings(opt, text_field, for_encoder=True):
"""
Args:
opt: the option in current environment.
word_dict(Vocab): words dictionary.
feature_dicts([Vocab], optional): a list of feature dictionary.
text_field(TextMultiField): word and feats field.
for_encoder(bool): build Embeddings for encoder or decoder?
"""
emb_dim = opt.src_word_vec_size if for_encoder else opt.tgt_word_vec_size

word_padding_idx = word_field.vocab.stoi[word_field.pad_token]
num_word_embeddings = len(word_field.vocab)
pad_indices = [f.vocab.stoi[f.pad_token] for _, f in text_field]
word_padding_idx, feat_pad_indices = pad_indices[0], pad_indices[1:]

feat_pad_indices = [ff.vocab.stoi[ff.pad_token] for ff in feat_fields]
num_feat_embeddings = [len(ff.vocab) for ff in feat_fields]
num_embs = [len(f.vocab) for _, f in text_field]
num_word_embeddings, num_feat_embeddings = num_embs[0], num_embs[1:]

fix_word_vecs = opt.fix_word_vecs_enc if for_encoder \
else opt.fix_word_vecs_dec
@@ -193,7 +192,9 @@ def build_base_model(model_opt, fields, gpu, checkpoint=None):
# Build encoder.
if model_opt.model_type == "text":
src_fields = [f for n, f in fields['src']]
src_emb = build_embeddings(model_opt, src_fields[0], src_fields[1:])
assert len(src_fields) == 1
src_field = src_fields[0]
src_emb = build_embeddings(model_opt, src_field)
encoder = build_encoder(model_opt, src_emb)
elif model_opt.model_type == "img":
# why is build_encoder not used here?
@@ -226,13 +227,15 @@ def build_base_model(model_opt, fields, gpu, checkpoint=None):

# Build decoder.
tgt_fields = [f for n, f in fields['tgt']]
assert len(tgt_fields) == 1
tgt_field = tgt_fields[0]
tgt_emb = build_embeddings(
model_opt, tgt_fields[0], tgt_fields[1:], for_encoder=False)
model_opt, tgt_field, for_encoder=False)

# Share the embedding matrix - preprocess with share_vocab required.
if model_opt.share_embeddings:
# src/tgt vocab should be the same if `-share_vocab` is specified.
assert src_fields[0].vocab == tgt_fields[0].vocab, \
assert src_field.base_field.vocab == tgt_field.base_field.vocab, \
"preprocess with -share_vocab if you use share_embeddings"

tgt_emb.word_lut.weight = src_emb.word_lut.weight
@@ -250,14 +253,17 @@ def build_base_model(model_opt, fields, gpu, checkpoint=None):
else:
gen_func = nn.LogSoftmax(dim=-1)
generator = nn.Sequential(
nn.Linear(model_opt.dec_rnn_size, len(fields["tgt"][0][1].vocab)),
nn.Linear(model_opt.dec_rnn_size,
len(fields["tgt"][0][1].base_field.vocab)),
gen_func
)
if model_opt.share_decoder_embeddings:
generator[0].weight = decoder.embeddings.word_lut.weight
else:
vocab_size = len(fields["tgt"][0][1].vocab)
pad_idx = fields["tgt"][0][1].vocab.stoi[fields["tgt"][0][1].pad_token]
assert len(fields["tgt"]) == 1
tgt_base_field = fields["tgt"][0][1].base_field
vocab_size = len(tgt_base_field.vocab)
pad_idx = tgt_base_field.vocab.stoi[tgt_base_field.pad_token]
generator = CopyGenerator(model_opt.dec_rnn_size, vocab_size, pad_idx)

# Load the model states from checkpoint or initialize them.
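
Finally, a hedged sketch of the .base_field access that build_embeddings and build_base_model now rely on, using a toy target field; it assumes legacy torchtext and this commit's TextMultiField and is not part of the diff.

from torchtext.data import Field

from onmt.inputters.text_dataset import TextMultiField

tgt = Field(pad_token="<blank>")
tgt.build_vocab([["hello", "world"]])
tgt_multifield = TextMultiField("tgt", tgt, [])

# Vocab size and padding index go through the base (word) field...
base = tgt_multifield.base_field
print(len(base.vocab), base.vocab.stoi[base.pad_token])    # -> 4 1

# ...while per-level pad indices and vocab sizes come from iterating all
# sub-fields, as build_embeddings does.
pad_indices = [f.vocab.stoi[f.pad_token] for _, f in tgt_multifield]
num_embs = [len(f.vocab) for _, f in tgt_multifield]
print(pad_indices, num_embs)                               # -> [1] [4]
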
(Diff for the remaining 8 changed files is not shown.)
