Skip to content

Commit

Permalink
Move batching and field logic from inputter to dsets (OpenNMT#1196)
Browse files Browse the repository at this point in the history
* Move batching and field logic from inputter to dsets.
* Remove _feature_tokenize from inputter.
* Remove src_lengths from audio.
* Make lengths a torch.int instead of self.dtype in AudioSeqField.
* Don't output src_lengths in AudioDataset.make_examples since they're no longer necessary
* Add temp fix for checking if data is text in make_features.
  • Loading branch information
flauted authored and vince62s committed Jan 23, 2019
1 parent 272ff2a commit bd465cc
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 86 deletions.
61 changes: 59 additions & 2 deletions onmt/inputters/audio_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from tqdm import tqdm

import torch
from torchtext.data import Field

from onmt.inputters.dataset_base import DatasetBase

Expand Down Expand Up @@ -115,5 +116,61 @@ def make_examples(
window_stride, window, normalize_audio
)

yield {side: spect, side + '_path': line.strip(),
side + '_lengths': spect.size(1), 'indices': i}
yield {side: spect, side + '_path': line.strip(), 'indices': i}


class AudioSeqField(Field):
    """A torchtext Field for variable-length audio spectrograms.

    Configures ``torchtext.data.Field`` for 2-D float spectrogram
    tensors shaped ``(nfft, len)``: no vocabulary, no tokenization,
    and custom padding/numericalization that preserve the feature
    (nfft) dimension.
    """

    def __init__(self, preprocessing=None, postprocessing=None,
                 include_lengths=False, batch_first=False, pad_index=0,
                 is_target=False):
        super(AudioSeqField, self).__init__(
            sequential=True, use_vocab=False, init_token=None,
            eos_token=None, fix_length=False, dtype=torch.float,
            preprocessing=preprocessing, postprocessing=postprocessing,
            lower=False, tokenize=None, include_lengths=include_lengths,
            batch_first=batch_first, pad_token=pad_index, unk_token=None,
            pad_first=False, truncate_first=False, stop_words=None,
            is_target=is_target
        )

    def pad(self, minibatch):
        """Pad a batch of spectrograms to the longest time length.

        Args:
            minibatch: iterable of 2-D tensors shaped ``(nfft, len)``;
                all entries are assumed to share the same ``nfft``
                (the first entry's size is used for the whole batch).

        Returns:
            A ``(batch_size, 1, nfft, max_len)`` tensor filled with
            ``pad_token``, or a ``(tensor, lengths)`` pair when
            ``include_lengths`` is set.
        """
        assert not self.pad_first and not self.truncate_first \
            and not self.fix_length and self.sequential
        minibatch = list(minibatch)
        lengths = [x.size(1) for x in minibatch]
        max_len = max(lengths)
        nfft = minibatch[0].size(0)
        sounds = torch.full((len(minibatch), 1, nfft, max_len), self.pad_token)
        for i, (spect, len_) in enumerate(zip(minibatch, lengths)):
            sounds[i, :, :, 0:len_] = spect
        if self.include_lengths:
            return (sounds, lengths)
        return sounds

    def numericalize(self, arr, device=None):
        """Convert a padded batch (output of :meth:`pad`) to model input.

        Args:
            arr: padded tensor, or ``(tensor, lengths)`` when
                ``include_lengths`` is set.
            device: device for the lengths tensor.

        Raises:
            ValueError: if ``include_lengths`` is set but ``arr`` is
                not a ``(data batch, batch lengths)`` tuple.
        """
        assert self.use_vocab is False
        if self.include_lengths and not isinstance(arr, tuple):
            raise ValueError("Field has include_lengths set to True, but "
                             "input data is not a tuple of "
                             "(data batch, batch lengths).")
        if isinstance(arr, tuple):
            arr, lengths = arr
            # lengths are element counts, so use an integer dtype rather
            # than the field's float dtype
            lengths = torch.tensor(lengths, dtype=torch.int, device=device)

        if self.postprocessing is not None:
            arr = self.postprocessing(arr, None)

        if self.sequential and not self.batch_first:
            # BUG FIX: Tensor.permute returns a new view and is not
            # in-place; the original discarded the result, leaving the
            # batch in batch-first layout.
            arr = arr.permute(3, 0, 1, 2)
        if self.sequential:
            arr = arr.contiguous()

        if self.include_lengths:
            return arr, lengths
        return arr


def audio_fields(base_name, **kwargs):
    """Return the (name, field) list for an audio source side.

    Extra keyword arguments are accepted (and ignored) so all the
    ``*_fields`` constructors share one calling convention.
    """
    field = AudioSeqField(
        pad_index=0, batch_first=True, include_lengths=True)
    return [(base_name, field)]
20 changes: 20 additions & 0 deletions onmt/inputters/image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

import os

import torch
from torchtext.data import Field

from onmt.inputters.dataset_base import DatasetBase

# domain specific dependencies
Expand Down Expand Up @@ -62,3 +65,20 @@ def make_examples(
and img.size(2) <= truncate[1]):
continue
yield {side: img, side + '_path': filename, 'indices': i}


def batch_img(data, vocab):
    """Pad a list of CxHxW image tensors into a single batch tensor.

    Smaller images are padded with 1 (white) up to the batch-wide
    maximum height and width. ``vocab`` is unused; it is part of the
    torchtext postprocessing signature.
    """
    channels = data[0].size(0)
    max_h = max(img.size(1) for img in data)
    max_w = max(img.size(2) for img in data)
    batch = torch.ones(len(data), channels, max_h, max_w)
    for idx, img in enumerate(data):
        batch[idx, :, :img.size(1), :img.size(2)] = img
    return batch


def image_fields(base_name, **kwargs):
    """Return the (name, field) list for an image source side.

    Extra keyword arguments are accepted (and ignored) so all the
    ``*_fields`` constructors share one calling convention.
    """
    field = Field(
        use_vocab=False,
        dtype=torch.float,
        postprocessing=batch_img,
        sequential=False)
    return [(base_name, field)]
112 changes: 28 additions & 84 deletions onmt/inputters/inputter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from torchtext.data import Field
from torchtext.vocab import Vocab

from onmt.inputters.text_dataset import TextDataset
from onmt.inputters.image_dataset import ImageDataset
from onmt.inputters.audio_dataset import AudioDataset
from onmt.inputters.text_dataset import TextDataset, text_fields
from onmt.inputters.image_dataset import ImageDataset, image_fields
from onmt.inputters.audio_dataset import AudioDataset, audio_fields
from onmt.utils.logging import logger

import gc
Expand Down Expand Up @@ -51,37 +51,6 @@ def make_tgt(data, vocab):
return alignment


def make_img(data, vocab):
    """Batch CxHxW image tensors, padding with 1 to the max H and W.

    ``vocab`` is unused; it is part of the torchtext postprocessing
    signature.
    """
    n = len(data)
    c = data[0].size(0)
    max_h = max(t.size(1) for t in data)
    max_w = max(t.size(2) for t in data)
    imgs = torch.full((n, c, max_h, max_w), 1.0)
    for idx, img in enumerate(data):
        h, w = img.size(1), img.size(2)
        imgs[idx, :, :h, :w] = img
    return imgs


def make_audio(data, vocab):
    """Batch spectrograms, zero-padding along the time axis.

    ``data`` holds 2-D tensors shaped ``(nfft, len)`` that share the
    same ``nfft``; ``vocab`` is unused (torchtext postprocessing
    signature).
    """
    nfft = data[0].size(0)
    max_t = max(spect.size(1) for spect in data)
    sounds = torch.zeros(len(data), 1, nfft, max_t)
    for idx, spect in enumerate(data):
        sounds[idx, :, :, :spect.size(1)] = spect
    return sounds


# mix this with partial
def _feature_tokenize(
string, layer=0, tok_delim=None, feat_delim=None, truncate=None):
tokens = string.split(tok_delim)
if truncate is not None:
tokens = tokens[:truncate]
if feat_delim is not None:
tokens = [t.split(feat_delim)[layer] for t in tokens]
return tokens


def get_fields(
src_data_type,
n_src_feats,
Expand Down Expand Up @@ -110,51 +79,23 @@ def get_fields(
'it is not possible to use dynamic_dict with non-text input'
fields = {'src': [], 'tgt': []}

if src_data_type == 'text':
feat_delim = u"│" if n_src_feats > 0 else None
for i in range(n_src_feats + 1):
name = "src_feat_" + str(i - 1) if i > 0 else "src"
tokenize = partial(
_feature_tokenize,
layer=i,
truncate=src_truncate,
feat_delim=feat_delim)
use_len = i == 0
feat = Field(
pad_token=pad, tokenize=tokenize, include_lengths=use_len)
fields['src'].append((name, feat))
elif src_data_type == 'img':
img = Field(
use_vocab=False, dtype=torch.float,
postprocessing=make_img, sequential=False)
fields["src"].append(('src', img))
else:
audio = Field(
use_vocab=False, dtype=torch.float,
postprocessing=make_audio, sequential=False)
fields["src"].append(('src', audio))

if src_data_type == 'audio':
# only audio has src_lengths
length = Field(use_vocab=False, dtype=torch.long, sequential=False)
fields["src_lengths"] = [("src_lengths", length)]

# below this: things defined no matter what the data source type is
feat_delim = u"│" if n_tgt_feats > 0 else None
for i in range(n_tgt_feats + 1):
name = "tgt_feat_" + str(i - 1) if i > 0 else "tgt"
tokenize = partial(
_feature_tokenize,
layer=i,
truncate=tgt_truncate,
feat_delim=feat_delim)

feat = Field(
init_token=bos,
eos_token=eos,
pad_token=pad,
tokenize=tokenize)
fields['tgt'].append((name, feat))
fields_getters = {"text": text_fields,
"img": image_fields,
"audio": audio_fields}

src_field_kwargs = {"n_feats": n_src_feats,
"include_lengths": True,
"pad": pad, "bos": None, "eos": None,
"truncate": src_truncate}
fields["src"] = fields_getters[src_data_type](
'src', **src_field_kwargs)

tgt_field_kwargs = {"n_feats": n_tgt_feats,
"include_lengths": False,
"pad": pad, "bos": bos, "eos": eos,
"truncate": tgt_truncate}
fields['tgt'] = fields_getters["text"](
'tgt', **tgt_field_kwargs)

indices = Field(use_vocab=False, dtype=torch.long, sequential=False)
fields["indices"] = [('indices', indices)]
Expand Down Expand Up @@ -220,12 +161,9 @@ def make_features(batch, side):
data, lengths = batch.__dict__[side]
else:
data = batch.__dict__[side]
if side == 'src' and hasattr(batch, 'src_lengths'):
lengths = batch.src_lengths
else:
lengths = None
lengths = None

if isinstance(batch.__dict__[side], tuple) or side == 'tgt':
if batch.src_is_text or side == 'tgt': # this is temporary, see #1196
# cat together layers, producing a 3d output tensor for src text
# and for tgt (which is assumed to be text)
feat_start = side + "_feat_"
Expand Down Expand Up @@ -446,6 +384,12 @@ def _pool(data, random_shuffler):
self.batch_size_fn):
self.batches.append(sorted(b, key=self.sort_key))

def __iter__(self):
# temporary fix: See #1196
for batch in super(OrderedIterator, self).__iter__():
batch.src_is_text = isinstance(self.dataset, TextDataset)
yield batch


class DatasetLazyIter(object):
"""
Expand Down
50 changes: 50 additions & 0 deletions onmt/inputters/text_dataset.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# -*- coding: utf-8 -*-
from functools import partial

from torchtext.data import Field

from onmt.inputters.dataset_base import DatasetBase

Expand Down Expand Up @@ -41,3 +44,50 @@ def make_examples(cls, sequences, side):
sequences = cls._read_file(sequences)
for i, seq in enumerate(sequences):
yield {side: seq, "indices": i}


# mix this with partial
def _feature_tokenize(
string, layer=0, tok_delim=None, feat_delim=None, truncate=None):
tokens = string.split(tok_delim)
if truncate is not None:
tokens = tokens[:truncate]
if feat_delim is not None:
tokens = [t.split(feat_delim)[layer] for t in tokens]
return tokens


def text_fields(base_name, **kwargs):
    """Create text fields.

    Keyword Args:
        n_feats (int): number of extra feature layers.
        include_lengths (bool): attach lengths to the base field.
        pad (str, optional): Defaults to <blank>.
        bos (str or NoneType, optional): Defaults to <s>
        eos (str or NoneType, optional): Defaults to </s>
        truncate (bool or NoneType, optional): Defaults to None.

    Returns:
        list of (name, Field) pairs: the base field first, then one
        pair per feature layer.
    """
    n_feats = kwargs["n_feats"]
    include_lengths = kwargs["include_lengths"]
    pad_tok = kwargs.get("pad", "<blank>")
    bos_tok = kwargs.get("bos", "<s>")
    eos_tok = kwargs.get("eos", "</s>")
    max_len = kwargs.get("truncate", None)
    # features ride along inside each token behind a delimiter, so each
    # layer reuses the same tokenizer with a different ``layer`` index
    delim = u"│" if n_feats > 0 else None
    fields_ = []
    for layer in range(n_feats + 1):
        if layer == 0:
            fname = base_name
        else:
            fname = base_name + "_feat_" + str(layer - 1)
        tok_fn = partial(
            _feature_tokenize,
            layer=layer,
            truncate=max_len,
            feat_delim=delim)
        field = Field(
            init_token=bos_tok, eos_token=eos_tok,
            pad_token=pad_tok, tokenize=tok_fn,
            include_lengths=(layer == 0 and include_lengths))
        fields_.append((fname, field))
    return fields_

0 comments on commit bd465cc

Please sign in to comment.