Datareaders (OpenNMT#1221)
* Start abstracting out dataset readers.

* Remove make_examples in favor of Reader.read

* Uniform spacing around imports.

* _check_deps as a classmethod of the reader.

* Move reader.read calls into DatasetBase init.

* Add 'empty' data reader __init__ to reader's base; delete from TextDataReader.

* Make readers a class attribute of DatasetBase instead of passing as args.

* Revert "Make readers a class attribute of DatasetBase instead of passing as args."

This reverts commit cc8cc98.

* Add from_opt to readers; undo __init__ taking all the args.

* Add tests for data readers.
flauted authored and vince62s committed Feb 4, 2019
1 parent a850f4a commit c8dc587
Showing 13 changed files with 394 additions and 181 deletions.
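In sum: each data type now gets a DataReaderBase subclass that is built once from options via from_opt, and DatasetBase consumes the dicts its read method yields. A rough sketch of the resulting lifecycle (illustrative only: opt stands for a parsed option namespace, the file name is made up, and the text reader's diff is not among the files shown below):

# Illustrative sketch of the reader lifecycle this commit introduces.
# `opt` is a parsed argparse namespace; the path is hypothetical, and the
# text reader's exact read() signature is assumed from the base class.
from onmt.inputters import str2reader

reader = str2reader["text"].from_opt(opt)  # dispatch on data type, configure from options
for ex in reader.read("src-train.txt", side="src", src_dir=None):
    pass  # each example is a dict keyed by side, e.g. {"src": ..., "indices": i}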
15 changes: 10 additions & 5 deletions onmt/inputters/__init__.py
@@ -7,12 +7,17 @@
load_old_vocab, get_fields, OrderedIterator, \
build_dataset, build_vocab, old_style_vocab
from onmt.inputters.dataset_base import DatasetBase
from onmt.inputters.text_dataset import TextDataset
from onmt.inputters.image_dataset import ImageDataset
from onmt.inputters.audio_dataset import AudioDataset
from onmt.inputters.text_dataset import TextDataset, TextDataReader
from onmt.inputters.image_dataset import ImageDataset, ImageDataReader
from onmt.inputters.audio_dataset import AudioDataset, AudioDataReader
from onmt.inputters.datareader_base import DataReaderBase


__all__ = ['DatasetBase', 'load_old_vocab', 'get_fields',
str2reader = {
"text": TextDataReader, "img": ImageDataReader, "audio": AudioDataReader}

__all__ = ['DatasetBase', 'load_old_vocab', 'get_fields', 'DataReaderBase',
'build_dataset', 'old_style_vocab',
'build_vocab', 'OrderedIterator',
'TextDataset', 'ImageDataset', 'AudioDataset']
'TextDataset', 'ImageDataset', 'AudioDataset',
'TextDataReader', 'ImageDataReader', 'AudioDataReader']
100 changes: 52 additions & 48 deletions onmt/inputters/audio_dataset.py
@@ -6,6 +6,7 @@
from torchtext.data import Field

from onmt.inputters.dataset_base import DatasetBase
from onmt.inputters.datareader_base import DataReaderBase

# imports of datatype-specific dependencies
try:
@@ -16,34 +17,53 @@
torchaudio, librosa, np = None, None, None


class AudioDataset(DatasetBase):
@staticmethod
def _check_deps():
class AudioDataReader(DataReaderBase):
"""
Args:
sample_rate (int): sample_rate.
window_size (float) : window size for spectrogram in seconds.
window_stride (float): window stride for spectrogram in seconds.
window (str): window type for spectrogram generation.
normalize_audio (bool): subtract spectrogram by mean and divide
by std or not.
truncate (int or NoneType): maximum audio length
(0 or None for unlimited).
"""

def __init__(self, sample_rate=0, window_size=0, window_stride=0,
window=None, normalize_audio=True, truncate=None):
self._check_deps()
self.sample_rate = sample_rate
self.window_size = window_size
self.window_stride = window_stride
self.window = window
self.normalize_audio = normalize_audio
self.truncate = truncate

@classmethod
def from_opt(cls, opt):
return cls(sample_rate=opt.sample_rate, window_size=opt.window_size,
window_stride=opt.window_stride, window=opt.window)

@classmethod
def _check_deps(cls):
if any([torchaudio is None, librosa is None, np is None]):
AudioDataset._raise_missing_dep(
cls._raise_missing_dep(
"torchaudio", "librosa", "numpy")

@staticmethod
def sort_key(ex):
""" Sort using duration time of the sound spectrogram. """
return ex.src.size(1)

@staticmethod
def extract_features(audio_path, sample_rate, truncate, window_size,
window_stride, window, normalize_audio):
def extract_features(self, audio_path):
# torchaudio loading options recently changed. It's probably
# straightforward to rewrite the audio handling to make use of
# up-to-date torchaudio, but in the meantime there is a legacy
# method which uses the old defaults
AudioDataset._check_deps()
sound, sample_rate_ = torchaudio.legacy.load(audio_path)
if truncate and truncate > 0:
if sound.size(0) > truncate:
sound = sound[:truncate]
if self.truncate and self.truncate > 0:
if sound.size(0) > self.truncate:
sound = sound[:self.truncate]

assert sample_rate_ == sample_rate, \
assert sample_rate_ == self.sample_rate, \
'Sample rate of %s != -sample_rate (%d vs %d)' \
% (audio_path, sample_rate_, sample_rate)
% (audio_path, sample_rate_, self.sample_rate)

sound = sound.numpy()
if len(sound.shape) > 1:
@@ -52,47 +72,28 @@ def extract_features(audio_path, sample_rate, truncate, window_size,
else:
sound = sound.mean(axis=1) # average multiple channels

n_fft = int(sample_rate * window_size)
n_fft = int(self.sample_rate * self.window_size)
win_length = n_fft
hop_length = int(sample_rate * window_stride)
hop_length = int(self.sample_rate * self.window_stride)
# STFT
d = librosa.stft(sound, n_fft=n_fft, hop_length=hop_length,
win_length=win_length, window=window)
win_length=win_length, window=self.window)
spect, _ = librosa.magphase(d)
spect = np.log1p(spect)
spect = torch.FloatTensor(spect)
if normalize_audio:
if self.normalize_audio:
mean = spect.mean()
std = spect.std()
spect.add_(-mean)
spect.div_(std)
return spect

@classmethod
def make_examples(
cls,
data,
src_dir,
side,
sample_rate,
window_size,
window_stride,
window,
normalize_audio,
truncate=None
):
def read(self, data, side, src_dir=None):
"""
Args:
data: sequence of audio paths or path containing these sequences
src_dir (str): location of source audio files.
side (str): 'src' or 'tgt'.
sample_rate (int): sample_rate.
window_size (float) : window size for spectrogram in seconds.
window_stride (float): window stride for spectrogram in seconds.
window (str): window type for spectrogram generation.
normalize_audio (bool): subtract spectrogram by mean and divide
by std or not.
truncate (int): maximum audio length (0 or None for unlimited).
Yields:
a dictionary containing audio data for each line.
@@ -101,7 +102,7 @@ def make_examples(
"src_dir must be a valid directory if data_type is audio"

if isinstance(data, str):
data = cls._read_file(data)
data = DataReaderBase._read_file(data)

for i, line in enumerate(tqdm(data)):
line = line.decode("utf-8").strip()
@@ -112,14 +113,17 @@ def make_examples(
assert os.path.exists(audio_path), \
'audio path %s not found' % line

spect = AudioDataset.extract_features(
audio_path, sample_rate, truncate, window_size,
window_stride, window, normalize_audio
)

spect = self.extract_features(audio_path)
yield {side: spect, side + '_path': line, 'indices': i}


class AudioDataset(DatasetBase):
@staticmethod
def sort_key(ex):
""" Sort using duration time of the sound spectrogram. """
return ex.src.size(1)


class AudioSeqField(Field):
"""Defines an audio datatype and instructions for converting to Tensor.
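The spectrogram options now live on the reader instance, so extract_features only needs a path. A minimal usage sketch, assuming the optional torchaudio/librosa/numpy dependencies are installed (paths are hypothetical):

# Usage sketch for the refactored audio reader; paths are hypothetical,
# and torchaudio, librosa, and numpy must be installed.
from onmt.inputters import AudioDataReader

reader = AudioDataReader(sample_rate=16000, window_size=0.02,
                         window_stride=0.01, window="hamming")
# read() takes a file listing one audio path per line and yields
# {"src": spectrogram, "src_path": line, "indices": i} per entry.
for ex in reader.read("src-speech.txt", side="src", src_dir="data/speech"):
    freq_bins, time_steps = ex["src"].shape  # the sort key is the time dimension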
30 changes: 30 additions & 0 deletions onmt/inputters/datareader_base.py
@@ -0,0 +1,30 @@
# coding: utf-8


# several data readers need optional dependencies. There's no
# appropriate builtin exception
class MissingDependencyException(Exception):
pass


class DataReaderBase(object):
"""Read data from file system and yield as dicts."""
@classmethod
def from_opt(cls, opt):
return cls()

@classmethod
def _read_file(cls, path):
with open(path, "rb") as f:
for line in f:
yield line

@staticmethod
def _raise_missing_dep(*missing_deps):
"""Raise missing dep exception with standard error message."""
raise MissingDependencyException(
"Could not create reader. Be sure to install "
"the following dependencies: " + ", ".join(missing_deps))

def read(self, data, side, src_dir):
raise NotImplementedError()
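This base class pins down the contract for any new data type: build with from_opt, read from a path or an iterable of byte lines, and yield one dict per example. A hypothetical minimal subclass, purely for illustration:

# Hypothetical reader, not part of this commit, showing the contract a
# DataReaderBase subclass satisfies (dict keys mirror the audio/image readers).
from onmt.inputters.datareader_base import DataReaderBase


class TSVDataReader(DataReaderBase):
    """Toy reader: one tab-separated record per line."""

    def read(self, data, side, src_dir=None):
        if isinstance(data, str):
            data = self._read_file(data)  # yields raw bytes, one line at a time
        for i, line in enumerate(data):
            fields = line.decode("utf-8").rstrip("\n").split("\t")
            yield {side: fields, side + "_path": "", "indices": i}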
34 changes: 5 additions & 29 deletions onmt/inputters/dataset_base.py
@@ -1,19 +1,13 @@
# coding: utf-8

from itertools import chain
from itertools import chain, starmap
from collections import Counter

import torch
from torchtext.data import Example, Dataset
from torchtext.vocab import Vocab


# several data readers need optional dependencies. There's no
# appropriate builtin exception
class MissingDependencyException(Exception):
pass


class DatasetBase(Dataset):
"""
A dataset is an object that accepts sequences of raw data (sentence pairs
@@ -55,21 +49,16 @@ class DatasetBase(Dataset):
the same structure as in the fields argument passed to the constructor.
"""

def __init__(self, fields, src_examples_iter, tgt_examples_iter,
filter_pred=None):

def __init__(self, fields, readers, data, dirs, filter_pred=None):
dynamic_dict = 'src_map' in fields and 'alignment' in fields

if tgt_examples_iter is not None:
examples_iter = (self._join_dicts(src, tgt) for src, tgt in
zip(src_examples_iter, tgt_examples_iter))
else:
examples_iter = src_examples_iter
read_iters = [r.read(dat[1], dat[0], dir_) for r, dat, dir_
in zip(readers, data, dirs)]

# self.src_vocabs is used in collapse_copy_scores and Translator.py
self.src_vocabs = []
examples = []
for ex_dict in examples_iter:
for ex_dict in starmap(self._join_dicts, zip(*read_iters)):
if dynamic_dict:
src_field = fields['src'][0][1]
tgt_field = fields['tgt'][0][1]
@@ -86,13 +75,6 @@ def __init__(self, fields, src_examples_iter, tgt_examples_iter,

super(DatasetBase, self).__init__(examples, fields, filter_pred)

@staticmethod
def _raise_missing_dep(*missing_deps):
"""Raise missing dep exception with standard error message."""
raise MissingDependencyException(
"Could not create reader. Be sure to install "
"the following dependencies: " + ", ".join(missing_deps))

def __getattr__(self, attr):
# avoid infinite recursion when fields isn't defined
if 'fields' not in vars(self):
@@ -133,9 +115,3 @@ def _dynamic_dict(self, example, src_field, tgt_field):
[0] + [src_vocab.stoi[w] for w in tgt] + [0])
example["alignment"] = mask
return src_vocab, example

@classmethod
def _read_file(cls, path):
with open(path, "rb") as f:
for line in f:
yield line
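Per the new signature, callers pass parallel lists: one reader, one (side, data) pair, and one directory per input, and _join_dicts merges the per-side dicts for each example. A sketch of the implied calling convention (fields and opt are assumed to come from the usual onmt setup; paths are made up):

# Illustrative call matching the new __init__ signature; `fields` and `opt`
# are assumed to exist from the usual onmt setup, and the paths are hypothetical.
from onmt.inputters import AudioDataset, str2reader

src_reader = str2reader["audio"].from_opt(opt)
tgt_reader = str2reader["text"].from_opt(opt)

dataset = AudioDataset(
    fields,                                                # dict of torchtext fields
    readers=[src_reader, tgt_reader],                      # one reader per side
    data=[("src", "src-speech.txt"), ("tgt", "tgt-train.txt")],
    dirs=["data/speech", None],                            # src_dir passed to each read()
)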
58 changes: 35 additions & 23 deletions onmt/inputters/image_dataset.py
@@ -6,6 +6,7 @@
from torchtext.data import Field

from onmt.inputters.dataset_base import DatasetBase
from onmt.inputters.datareader_base import DataReaderBase

# domain specific dependencies
try:
@@ -16,57 +17,68 @@
Image, transforms, cv2 = None, None, None


class ImageDataset(DatasetBase):
@staticmethod
def _check_deps():
if any([Image is None, transforms is None, cv2 is None]):
ImageDataset._raise_missing_dep(
"PIL", "torchvision", "cv2")
class ImageDataReader(DataReaderBase):
"""
Args:
truncate: maximum img size ((0,0) or None for unlimited)
channel_size: Number of channels per image.
"""

@staticmethod
def sort_key(ex):
""" Sort using the size of the image: (width, height)."""
return ex.src.size(2), ex.src.size(1)
def __init__(self, truncate=None, channel_size=3):
self._check_deps()
self.truncate = truncate
self.channel_size = channel_size

@classmethod
def make_examples(
cls, images, src_dir, side, truncate=None, channel_size=3
):
def from_opt(cls, opt):
return cls(channel_size=opt.image_channel_size)

@classmethod
def _check_deps(cls):
if any([Image is None, transforms is None, cv2 is None]):
cls._raise_missing_dep(
"PIL", "torchvision", "cv2")

def read(self, images, side, img_dir=None):
"""
Args:
path (str): location of a src file containing image paths
images (str): location of a src file containing image paths
src_dir (str): location of source images
side (str): 'src' or 'tgt'
truncate: maximum img size ((0,0) or None for unlimited)
Yields:
a dictionary containing image data, path and index for each line.
"""
ImageDataset._check_deps()

if isinstance(images, str):
images = cls._read_file(images)
images = DataReaderBase._read_file(images)

for i, filename in enumerate(images):
filename = filename.decode("utf-8").strip()
img_path = os.path.join(src_dir, filename)
img_path = os.path.join(img_dir, filename)
if not os.path.exists(img_path):
img_path = filename

assert os.path.exists(img_path), \
'img path %s not found' % filename

if channel_size == 1:
if self.channel_size == 1:
img = transforms.ToTensor()(
Image.fromarray(cv2.imread(img_path, 0)))
else:
img = transforms.ToTensor()(Image.open(img_path))
if truncate and truncate != (0, 0):
if not (img.size(1) <= truncate[0]
and img.size(2) <= truncate[1]):
if self.truncate and self.truncate != (0, 0):
if not (img.size(1) <= self.truncate[0]
and img.size(2) <= self.truncate[1]):
continue
yield {side: img, side + '_path': filename, 'indices': i}


class ImageDataset(DatasetBase):
@staticmethod
def sort_key(ex):
""" Sort using the size of the image: (width, height)."""
return ex.src.size(2), ex.src.size(1)


def batch_img(data, vocab):
c = data[0].size(0)
h = max([t.size(1) for t in data])
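Likewise for images: truncate is a (height, width) bound and images exceeding it are skipped rather than raising, while channel_size=1 loads grayscale via cv2. A usage sketch (paths are hypothetical; requires the optional PIL, torchvision, and cv2 dependencies):

# Usage sketch for the refactored image reader; paths are hypothetical,
# and PIL, torchvision, and cv2 must be installed.
from onmt.inputters import ImageDataReader

reader = ImageDataReader(truncate=(80, 400), channel_size=1)  # grayscale, skip large images
for ex in reader.read("src-images.txt", side="src", img_dir="data/im2text"):
    c, h, w = ex["src"].shape  # (channels, height, width) tensor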