Skip to content

Commit

Permalink
Simplify fields structure (OpenNMT#1299)
Browse files Browse the repository at this point in the history
  • Loading branch information
flauted authored and vince62s committed Feb 15, 2019
1 parent a0617b8 commit bd7096d
Show file tree
Hide file tree
Showing 13 changed files with 113 additions and 100 deletions.
4 changes: 2 additions & 2 deletions onmt/inputters/audio_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,6 @@ def numericalize(self, arr, device=None):
return arr


def audio_fields(base_name, **kwargs):
def audio_fields(**kwargs):
audio = AudioSeqField(pad_index=0, batch_first=True, include_lengths=True)
return [(base_name, audio)]
return audio
16 changes: 10 additions & 6 deletions onmt/inputters/dataset_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class Dataset(TorchtextDataset):
torchtext's iterators then know how to use these examples to make batches.
Args:
fields (dict[str, List[Tuple[str, Field]]]): a dict with the structure
fields (dict[str, Field]): a dict with the structure
returned by :func:`onmt.inputters.get_fields()`. Usually
that means the dataset side, ``"src"`` or ``"tgt"``. Keys match
the keys of items yielded by the ``readers``, while values
Expand Down Expand Up @@ -119,18 +119,22 @@ def __init__(self, fields, readers, data, dirs, sort_key,
examples = []
for ex_dict in starmap(_join_dicts, zip(*read_iters)):
if can_copy:
src_field = fields['src'][0][1]
tgt_field = fields['tgt'][0][1]
src_field = fields['src']
tgt_field = fields['tgt']
# this assumes src_field and tgt_field are both text
src_ex_vocab, ex_dict = _dynamic_dict(
ex_dict, src_field.base_field, tgt_field.base_field)
self.src_vocabs.append(src_ex_vocab)
ex_fields = {k: v for k, v in fields.items() if k in ex_dict}
ex_fields = {k: [(k, v)] for k, v in fields.items() if
k in ex_dict}
ex = Example.fromdict(ex_dict, ex_fields)
examples.append(ex)

# the dataset's self.fields should have the same attributes as examples
fields = dict(chain.from_iterable(ex_fields.values()))
# fields needs to have only keys that examples have as attrs
fields = []
for _, nf_list in ex_fields.items():
assert len(nf_list) == 1
fields.append(nf_list[0])

super(Dataset, self).__init__(examples, fields, filter_pred)

Expand Down
4 changes: 2 additions & 2 deletions onmt/inputters/image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,8 +99,8 @@ def batch_img(data, vocab):
return imgs


def image_fields(base_name, **kwargs):
def image_fields(**kwargs):
img = Field(
use_vocab=False, dtype=torch.float,
postprocessing=batch_img, sequential=False)
return [(base_name, img)]
return img
109 changes: 62 additions & 47 deletions onmt/inputters/inputter.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,18 +88,15 @@ def get_fields(
:class:`TextDataReader` - see there for more details).
Returns:
A dictionary. The keys are strings whose names correspond to the
keys of the dictionaries yielded by the make_examples methods of
various dataset classes. The values are lists of (name, Field)
pairs, where the name is a string which will become the name of
an attribute of an example.
A dict mapping names to fields. These names need to match
the dataset example attributes.
"""

assert src_data_type in ['text', 'img', 'audio'], \
"Data type not implemented"
assert not dynamic_dict or src_data_type == 'text', \
'it is not possible to use dynamic_dict with non-text input'
fields = {'src': [], 'tgt': []}
fields = {}

fields_getters = {"text": text_fields,
"img": image_fields,
Expand All @@ -108,30 +105,30 @@ def get_fields(
src_field_kwargs = {"n_feats": n_src_feats,
"include_lengths": True,
"pad": pad, "bos": None, "eos": None,
"truncate": src_truncate}
fields["src"] = fields_getters[src_data_type](
'src', **src_field_kwargs)
"truncate": src_truncate,
"base_name": "src"}
fields["src"] = fields_getters[src_data_type](**src_field_kwargs)

tgt_field_kwargs = {"n_feats": n_tgt_feats,
"include_lengths": False,
"pad": pad, "bos": bos, "eos": eos,
"truncate": tgt_truncate}
fields['tgt'] = fields_getters["text"](
'tgt', **tgt_field_kwargs)
"truncate": tgt_truncate,
"base_name": "tgt"}
fields["tgt"] = fields_getters["text"](**tgt_field_kwargs)

indices = Field(use_vocab=False, dtype=torch.long, sequential=False)
fields["indices"] = [('indices', indices)]
fields["indices"] = indices

if dynamic_dict:
src_map = Field(
use_vocab=False, dtype=torch.float,
postprocessing=make_src, sequential=False)
fields["src_map"] = [("src_map", src_map)]
fields["src_map"] = src_map

align = Field(
use_vocab=False, dtype=torch.long,
postprocessing=make_tgt, sequential=False)
fields["alignment"] = [('alignment', align)]
fields["alignment"] = align

return fields

Expand All @@ -144,42 +141,52 @@ def load_old_vocab(vocab, data_type="text", dynamic_dict=False):
format formerly saved in *.vocab.pt files. Or, text data
not using a :class:`TextMultiField`.
data_type (str): text, img, or audio
dynamic_dict (str): Used for copy attention.
dynamic_dict (bool): Used for copy attention.
Returns:
a dictionary whose keys are the field names and whose values
are lists of (name, Field) pairs, using :class:`TextMultiField`s
as appropriate.
a dictionary whose keys are the field names and whose values are Fields.
"""

if _old_style_vocab(vocab):
# List[Tuple[str, Vocab]] -> List[Tuple[str, Field]]
# -> dict[str, Field]
vocab = dict(vocab)
n_src_features = sum('src_feat_' in k for k in vocab)
n_tgt_features = sum('tgt_feat_' in k for k in vocab)
fields = get_fields(
data_type, n_src_features, n_tgt_features,
dynamic_dict=dynamic_dict)
for n, f in fields.items():
try:
f_iter = iter(f)
except TypeError:
f_iter = [(n, f)]
for sub_n, sub_f in f_iter:
if sub_n in vocab:
sub_f.vocab = vocab[sub_n]
return fields

if _old_style_field_list(vocab): # upgrade to multifield
# Dict[str, List[Tuple[str, Field]]]
# doesn't change structure - don't return early.
fields = vocab
for base_name, vals in fields.items():
if ((base_name == 'src' and data_type == 'text') or
base_name == 'tgt'):
assert not isinstance(vals[0][1], TextMultiField)
fields[base_name] = [(base_name, TextMultiField(
vals[0][0], vals[0][1], vals[1:]))]
return fields
vocab = dict(vocab)
n_src_features = sum('src_feat_' in k for k in vocab)
n_tgt_features = sum('tgt_feat_' in k for k in vocab)
fields = get_fields(
data_type, n_src_features, n_tgt_features, dynamic_dict=dynamic_dict
)
for k, vals in fields.items():
for n, f in vals:
try:
f_iter = iter(f)
except TypeError:
f_iter = [(n, f)]
for sub_n, sub_f in f_iter:
if sub_n in vocab:
sub_f.vocab = vocab[sub_n]

if _old_style_nesting(vocab):
# Dict[str, List[Tuple[str, Field]]] -> List[Tuple[str, Field]]
# -> dict[str, Field]
fields = dict(list(chain.from_iterable(vocab.values())))

return fields


def _old_style_vocab(vocab):
"""Detect old-style vocabs.
"""Detect old-style vocabs (``List[Tuple[str, torchtext.data.Vocab]]``).
Args:
vocab: some object loaded from a *.vocab.pt file
Expand All @@ -197,9 +204,18 @@ def _old_style_vocab(vocab):
any(isinstance(v[1], Vocab) for v in vocab)


def _old_style_nesting(vocab):
"""Detect old-style nesting (``dict[str, List[Tuple[str, Field]]]``)."""
return isinstance(vocab, dict) and \
any(isinstance(v, list) for v in vocab.values())


def _old_style_field_list(vocab):
"""Detect old-style text fields.
Not old style vocab, old nesting, and text-type fields not using
``TextMultiField``.
Args:
vocab: some object loaded from a *.vocab.pt file
Expand All @@ -209,13 +225,14 @@ def _old_style_field_list(vocab):
"""

# if tgt isn't using TextMultiField, then no text field is.
return not _old_style_vocab(vocab) and not isinstance(
vocab['tgt'][0][1], TextMultiField)
return (not _old_style_vocab(vocab)) and _old_style_nesting(vocab) and \
(not isinstance(vocab['tgt'][0][1], TextMultiField))


def old_style_vocab(vocab):
""":func:`_old_style_vocab()` OR :func:`_old_style_field_list()`."""
return _old_style_vocab(vocab) or _old_style_field_list(vocab)
"""The vocab/fields need to be updated."""
return _old_style_vocab(vocab) or _old_style_field_list(vocab) or \
_old_style_nesting(vocab)


def filter_example(ex, use_src_len=True, use_tgt_len=True,
Expand Down Expand Up @@ -299,7 +316,7 @@ def build_vocab(train_dataset_files, fields, data_type, share_vocab,
Args:
train_dataset_files: a list of train dataset pt files.
fields (dict[str, List[Tuple[str, Field]]]): fields to build vocab for.
fields (dict[str, Field]): fields to build vocab for.
data_type (str): A supported data type string.
share_vocab (bool): share source and target vocabulary?
src_vocab_path (str): Path to src vocabulary file.
Expand Down Expand Up @@ -336,7 +353,7 @@ def build_vocab(train_dataset_files, fields, data_type, share_vocab,
dataset = torch.load(path)
logger.info(" * reloading %s." % path)
for ex in dataset.examples:
for name, field in chain.from_iterable(fields.values()):
for name, field in fields.items():
try:
f_iter = iter(field)
except TypeError:
Expand Down Expand Up @@ -366,16 +383,14 @@ def build_vocab(train_dataset_files, fields, data_type, share_vocab,
max_size=src_vocab_size, min_freq=src_words_min_frequency)
build_fv_args["tgt"] = dict(
max_size=tgt_vocab_size, min_freq=tgt_words_min_frequency)
assert len(fields["tgt"]) == 1
tgt_multifield = fields["tgt"][0][1]
tgt_multifield = fields["tgt"]
_build_fv_from_multifield(
tgt_multifield,
counters,
build_fv_args,
size_multiple=vocab_size_multiple if not share_vocab else 1)
if data_type == 'text':
assert len(fields["src"]) == 1
src_multifield = fields["src"][0][1]
src_multifield = fields["src"]
_build_fv_from_multifield(
src_multifield,
counters,
Expand Down Expand Up @@ -506,7 +521,7 @@ class DatasetLazyIter(object):
Args:
dataset_paths: a list containing the locations of dataset files.
fields (dict[str, List[Tuple[str, Field]]]): fields dict for the
fields (dict[str, Field]): fields dict for the
datasets.
batch_size (int): batch size.
batch_size_fn: custom batch process function.
Expand Down
7 changes: 4 additions & 3 deletions onmt/inputters/text_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def __getitem__(self, item):
return self.fields[item]


def text_fields(base_name, **kwargs):
def text_fields(**kwargs):
"""Create text fields.
Args:
Expand All @@ -164,11 +164,12 @@ def text_fields(base_name, **kwargs):
truncate (bool or NoneType, optional): Defaults to ``None``.
Returns:
List[Tuple[str, TextMultiField]]
TextMultiField
"""

n_feats = kwargs["n_feats"]
include_lengths = kwargs["include_lengths"]
base_name = kwargs["base_name"]
pad = kwargs.get("pad", "<blank>")
bos = kwargs.get("bos", "<s>")
eos = kwargs.get("eos", "</s>")
Expand All @@ -190,4 +191,4 @@ def text_fields(base_name, **kwargs):
fields_.append((name, feat))
assert fields_[0][0] == base_name # sanity check
field = TextMultiField(fields_[0][0], fields_[0][1], fields_[1:])
return [(base_name, field)]
return field
23 changes: 10 additions & 13 deletions onmt/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,25 +103,26 @@ def load_test_model(opt, model_path=None):


def build_base_model(model_opt, fields, gpu, checkpoint=None, gpu_id=None):
"""
"""Build a model from opts.
Args:
model_opt: the option loaded from checkpoint. It's important that
the opts have been updated and validated. See
:class:`onmt.utils.parse.ArgumentParser`.
fields: `Field` objects for the model.
fields (dict[str, torchtext.data.Field]):
`Field` objects for the model.
gpu (bool): whether to use gpu.
checkpoint: the model generated by train phase, or a resumed snapshot
model from a stopped training.
gpu_id (int or NoneType): Which GPU to use.
Returns:
the NMTModel.
"""

# Build embeddings.
if model_opt.model_type == "text":
src_fields = [f for n, f in fields['src']]
assert len(src_fields) == 1
src_field = src_fields[0]
src_field = fields["src"]
src_emb = build_embeddings(model_opt, src_field)
else:
src_emb = None
Expand All @@ -130,11 +131,8 @@ def build_base_model(model_opt, fields, gpu, checkpoint=None, gpu_id=None):
encoder = build_encoder(model_opt, src_emb)

# Build decoder.
tgt_fields = [f for n, f in fields['tgt']]
assert len(tgt_fields) == 1
tgt_field = tgt_fields[0]
tgt_emb = build_embeddings(
model_opt, tgt_field, for_encoder=False)
tgt_field = fields["tgt"]
tgt_emb = build_embeddings(model_opt, tgt_field, for_encoder=False)

# Share the embedding matrix - preprocess with share_vocab required.
if model_opt.share_embeddings:
Expand Down Expand Up @@ -163,15 +161,14 @@ def build_base_model(model_opt, fields, gpu, checkpoint=None, gpu_id=None):
gen_func = nn.LogSoftmax(dim=-1)
generator = nn.Sequential(
nn.Linear(model_opt.dec_rnn_size,
len(fields["tgt"][0][1].base_field.vocab)),
len(fields["tgt"].base_field.vocab)),
Cast(torch.float32),
gen_func
)
if model_opt.share_decoder_embeddings:
generator[0].weight = decoder.embeddings.word_lut.weight
else:
assert len(fields["tgt"]) == 1
tgt_base_field = fields["tgt"][0][1].base_field
tgt_base_field = fields["tgt"].base_field
vocab_size = len(tgt_base_field.vocab)
pad_idx = tgt_base_field.vocab.stoi[tgt_base_field.pad_token]
generator = CopyGenerator(model_opt.dec_rnn_size, vocab_size, pad_idx)
Expand Down
2 changes: 1 addition & 1 deletion onmt/tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def __init__(self, *args, **kwargs):
self.opt = opt

def get_field(self):
src = onmt.inputters.get_fields("text", 0, 0)["src"][0][1]
src = onmt.inputters.get_fields("text", 0, 0)["src"]
src.base_field.build_vocab([])
return src

Expand Down
Loading

0 comments on commit bd7096d

Please sign in to comment.