Better documentation for data pipeline (OpenNMT#1269)
* Partial update to docs.

* Improve inputter documentation.

* Raise FileNotFoundError instead of RuntimeError.
flauted authored and vince62s committed Feb 7, 2019
1 parent bf38b6f commit e73991a
Showing 6 changed files with 274 additions and 89 deletions.
38 changes: 26 additions & 12 deletions onmt/inputters/audio_dataset.py
@@ -18,16 +18,22 @@


class AudioDataReader(DataReaderBase):
"""
"""Read audio data from disk.
Args:
sample_rate (int): sample_rate.
window_size (float) : window size for spectrogram in seconds.
window_stride (float): window stride for spectrogram in seconds.
window (str): window type for spectrogram generation.
window (str): window type for spectrogram generation. See
:func:`librosa.stft()` ``window`` for more details.
normalize_audio (bool): subtract spectrogram by mean and divide
by std or not.
truncate (int or NoneType): maximum audio length
(0 or None for unlimited).
Raises:
onmt.inputters.datareader_base.MissingDependencyException: If
importing any of ``torchaudio``, ``librosa``, or ``numpy`` fails.
"""

def __init__(self, sample_rate=0, window_size=0, window_stride=0,
@@ -89,15 +95,21 @@ def extract_features(self, audio_path):
return spect

def read(self, data, side, src_dir=None):
"""
"""Read data into dicts.
Args:
data: sequence of audio paths or path containing these sequences
src_dir (str): location of source audio files.
side (str): 'src' or 'tgt'.
data (str or Iterable[str]): Sequence of audio paths or
path to file containing audio paths.
In either case, the filenames may be relative to ``src_dir``
(default behavior) or absolute.
side (str): Prefix used in return dict. Usually
``"src"`` or ``"tgt"``.
src_dir (str): Location of source audio files. See ``data``.
Yields:
a dictionary containing audio data for each line.
A dictionary containing audio data for each line.
"""

assert src_dir is not None and os.path.exists(src_dir),\
"src_dir must be a valid directory if data_type is audio"

@@ -120,7 +132,7 @@ def read(self, data, side, src_dir=None):
class AudioDataset(DatasetBase):
@staticmethod
def sort_key(ex):
""" Sort using duration time of the sound spectrogram. """
"""Sort using duration time of the sound spectrogram."""
return ex.src.size(1)


@@ -129,6 +141,7 @@ class AudioSeqField(Field):
See :class:`Fields` for attribute descriptions.
"""

def __init__(self, preprocessing=None, postprocessing=None,
include_lengths=False, batch_first=False, pad_index=0,
is_target=False):
@@ -146,7 +159,7 @@ def pad(self, minibatch):
"""Pad a batch of examples to the length of the longest example.
Args:
minibatch (list[torch.FloatTensor]): A list of audio data,
minibatch (List[torch.FloatTensor]): A list of audio data,
each having shape 1 x n_feats x len where len is variable.
Returns:
@@ -155,6 +168,7 @@ def pad(self, minibatch):
and a list of the lengths if `self.include_lengths` is `True`
else just returns the padded tensor.
"""

assert not self.pad_first and not self.truncate_first \
and not self.fix_length and self.sequential
minibatch = list(minibatch)
@@ -171,18 +185,19 @@ def pad(self, minibatch):
def numericalize(self, arr, device=None):
"""Turn a batch of examples that use this field into a Variable.
If the field has include_lengths=True, a tensor of lengths will be
If the field has ``include_lengths=True``, a tensor of lengths will be
included in the return value.
Args:
arr (torch.FloatTensor, or Tuple(torch.FloatTensor, List[int])):
arr (torch.FloatTensor or Tuple(torch.FloatTensor, List[int])):
List of tokenized and padded examples, or tuple of List of
tokenized and padded examples and List of lengths of each
example if self.include_lengths is True. Examples have shape
batch_size x 1 x n_feats x max_len if `self.batch_first`
else max_len x batch_size x 1 x n_feats.
device (str or torch.device): See `Field.numericalize`.
"""

assert self.use_vocab is False
if self.include_lengths and not isinstance(arr, tuple):
raise ValueError("Field has include_lengths set to True, but "
@@ -207,5 +222,4 @@ def numericalize(self, arr, device=None):

def audio_fields(base_name, **kwargs):
audio = AudioSeqField(pad_index=0, batch_first=True, include_lengths=True)

return [(base_name, audio)]
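
To make the documented interface concrete, here is a minimal usage sketch (not part of this commit). The option values and file paths are placeholders, and the ``"src"`` dict key is assumed from the ``side`` prefix described in the docstring; only the ``AudioDataReader``, ``read(data, side, src_dir)``, and ``audio_fields(base_name)`` signatures come from the code above.

# Hypothetical usage sketch -- option values and paths are illustrative only.
from onmt.inputters.audio_dataset import AudioDataReader, audio_fields

reader = AudioDataReader(sample_rate=16000, window_size=0.02,
                         window_stride=0.01, window="hamming",
                         normalize_audio=True)

# read() yields one dict per line of the index file, keyed by ``side``.
examples = list(reader.read("data/src-speech.txt", side="src",
                            src_dir="data/speech"))

# audio_fields() returns [(base_name, AudioSeqField)]; with
# include_lengths=True, pad() also returns the example lengths.
_, field = audio_fields("src")[0]
padded, lengths = field.pad([ex["src"] for ex in examples])
batch, lengths = field.numericalize((padded, lengths))
print(batch.shape)  # batch_size x 1 x n_feats x max_len (batch_first=True)
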
16 changes: 15 additions & 1 deletion onmt/inputters/datareader_base.py
@@ -8,13 +8,26 @@ class MissingDependencyException(Exception):


class DataReaderBase(object):
"""Read data from file system and yield as dicts."""
"""Read data from file system and yield as dicts.
Raises:
MissingDependencyException: A number of DataReaders need specific
additional packages. If any are missing, this will be raised.
"""

@classmethod
def from_opt(cls, opt):
"""Alternative constructor.
Args:
opt (argparse.Namespace): The parsed arguments.
"""

return cls()

@classmethod
def _read_file(cls, path):
"""Line-by-line read a file as bytes."""
with open(path, "rb") as f:
for line in f:
yield line
@@ -27,4 +40,5 @@ def _raise_missing_dep(*missing_deps):
"the following dependencies: " + ", ".join(missing_deps))

def read(self, data, side, src_dir):
"""Read data from file system and yield as dicts."""
raise NotImplementedError()
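
For orientation, a hypothetical subclass showing the contract this base class defines. The class name and dict keys are invented for the example; only the ``read(data, side, src_dir)`` signature and the ``_read_file`` helper come from the code above.

# Hypothetical reader -- illustrates the DataReaderBase contract,
# not a real OpenNMT reader.
from onmt.inputters.datareader_base import DataReaderBase


class PlainTextDataReader(DataReaderBase):
    """Yield one dict per line of a plain-text file."""

    def read(self, data, side, src_dir=None):
        if isinstance(data, str):
            data = self._read_file(data)  # yields raw bytes, line by line
        for i, line in enumerate(data):
            if isinstance(line, bytes):
                line = line.decode("utf-8")
            yield {side: line.strip(), "indices": i}
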
1 change: 1 addition & 0 deletions onmt/inputters/dataset_base.py
@@ -20,6 +20,7 @@ class DatasetBase(Dataset):
Datasets in OpenNMT take three positional arguments:
Args:
`fields`: a dict with the structure returned by inputters.get_fields().
keys match the keys of items yielded by the src_examples_iter or
tgt_examples_iter, while values are lists of (name, Field) pairs.
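
As a rough illustration of the structure described above, the ``fields`` dict maps side names to lists of ``(name, Field)`` pairs. This is hand-written for the example, not the actual output of ``inputters.get_fields()``, which may contain additional entries.

# Hand-written illustration of the ``fields`` structure -- the real dict is
# built by inputters.get_fields() and may differ in detail.
from torchtext.data import Field

fields = {
    "src": [("src", Field(pad_token="<blank>", include_lengths=True))],
    "tgt": [("tgt", Field(init_token="<s>", eos_token="</s>",
                          pad_token="<blank>"))],
    "indices": [("indices", Field(use_vocab=False, sequential=False))],
}
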
29 changes: 21 additions & 8 deletions onmt/inputters/image_dataset.py
@@ -18,10 +18,16 @@


class ImageDataReader(DataReaderBase):
"""
"""Read image data from disk.
Args:
truncate: maximum img size ((0,0) or None for unlimited)
channel_size: Number of channels per image.
truncate (tuple[int] or NoneType): maximum img size. Use
``(0,0)`` or ``None`` for unlimited.
channel_size (int): Number of channels per image.
Raises:
onmt.inputters.datareader_base.MissingDependencyException: If
importing any of ``PIL``, ``torchvision``, or ``cv2`` fails.
"""

def __init__(self, truncate=None, channel_size=3):
@@ -40,11 +46,17 @@ def _check_deps(cls):
"PIL", "torchvision", "cv2")

def read(self, images, side, img_dir=None):
"""
"""Read data into dicts.
Args:
images (str): location of a src file containing image paths
src_dir (str): location of source images
side (str): 'src' or 'tgt'
images (str or Iterable[str]): Sequence of image paths or
path to file containing image paths.
In either case, the filenames may be relative to ``img_dir``
(default behavior) or absolute.
side (str): Prefix used in return dict. Usually
``"src"`` or ``"tgt"``.
img_dir (str): Location of source image files. See ``images``.
Yields:
a dictionary containing image data, path and index for each line.
"""
@@ -75,11 +87,12 @@ def read(self, images, side, img_dir=None):
class ImageDataset(DatasetBase):
@staticmethod
def sort_key(ex):
""" Sort using the size of the image: (width, height)."""
"""Sort using the size of the image: (width, height)."""
return ex.src.size(2), ex.src.size(1)


def batch_img(data, vocab):
"""Pad and batch a sequence of images."""
c = data[0].size(0)
h = max([t.size(1) for t in data])
w = max([t.size(2) for t in data])
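
As with the audio reader, a hypothetical usage sketch that is not part of this commit; the paths are placeholders and the ``"src"`` dict key is assumed from the ``side`` prefix described above. Only the ``ImageDataReader(truncate, channel_size)`` and ``read(images, side, img_dir)`` signatures come from the diff.

# Hypothetical usage sketch -- paths are placeholders.
from onmt.inputters.image_dataset import ImageDataReader

reader = ImageDataReader(truncate=None, channel_size=3)

# Each yielded dict is keyed by ``side``; the image tensor has shape
# channel_size x height x width.
for ex in reader.read("data/src-images.txt", side="src",
                      img_dir="data/images"):
    print(ex["src"].size())
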