Merge master back into our fork. #1

Merged 6 commits on Dec 19, 2018.
onmt/inputters/__init__.py (10 additions, 15 deletions)
@@ -3,24 +3,19 @@
 Inputters implement the logic of transforming raw data to vectorized inputs,
 e.g., from a line of text to a sequence of embeddings.
 """
-from onmt.inputters.inputter import collect_feature_vocabs, make_features, \
-    collect_features, get_num_features, \
-    load_fields_from_vocab, get_fields, \
-    save_fields_to_vocab, build_dataset, \
-    build_vocab, merge_vocabs, OrderedIterator
+from onmt.inputters.inputter import make_features, collect_features, \
+    load_fields_from_vocab, get_fields, OrderedIterator, \
+    save_fields_to_vocab, build_dataset, build_vocab
 from onmt.inputters.dataset_base import DatasetBase, PAD_WORD, BOS_WORD, \
-    EOS_WORD, UNK
-from onmt.inputters.text_dataset import TextDataset, ShardedTextCorpusIterator
+    EOS_WORD
+from onmt.inputters.text_dataset import TextDataset
 from onmt.inputters.image_dataset import ImageDataset
-from onmt.inputters.audio_dataset import AudioDataset, \
-    ShardedAudioCorpusIterator
+from onmt.inputters.audio_dataset import AudioDataset


-__all__ = ['PAD_WORD', 'BOS_WORD', 'EOS_WORD', 'UNK', 'DatasetBase',
-           'collect_feature_vocabs', 'make_features',
-           'collect_features', 'get_num_features',
+__all__ = ['PAD_WORD', 'BOS_WORD', 'EOS_WORD', 'DatasetBase',
+           'make_features', 'collect_features',
            'load_fields_from_vocab', 'get_fields',
            'save_fields_to_vocab', 'build_dataset',
-           'build_vocab', 'merge_vocabs', 'OrderedIterator',
-           'TextDataset', 'ImageDataset', 'AudioDataset',
-           'ShardedTextCorpusIterator', 'ShardedAudioCorpusIterator']
+           'build_vocab', 'OrderedIterator',
+           'TextDataset', 'ImageDataset', 'AudioDataset']
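
Reviewer note: the re-exported surface of `onmt.inputters` shrinks with this change. Below is a minimal sketch of what downstream code can still import after the merge; it simply mirrors the new `__all__`. Anything dropped there (`collect_feature_vocabs`, `get_num_features`, `merge_vocabs`, `UNK`, and the `Sharded*CorpusIterator` classes) is no longer part of the package's public surface.

# Sketch only: the import surface after this PR, mirroring the new __all__.
from onmt.inputters import (
    PAD_WORD, BOS_WORD, EOS_WORD,          # special-token constants
    DatasetBase, TextDataset, ImageDataset, AudioDataset,
    make_features, collect_features,
    load_fields_from_vocab, save_fields_to_vocab,
    get_fields, build_dataset, build_vocab,
    OrderedIterator,
)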
onmt/inputters/audio_dataset.py (17 additions, 148 deletions)
@@ -1,49 +1,21 @@
 # -*- coding: utf-8 -*-
 import codecs
 import os
-import sys
-import io
-from tqdm import tqdm

 import torch

-from onmt.inputters.dataset_base import NonTextDatasetBase
+from onmt.inputters.dataset_base import DatasetBase


-class AudioDataset(NonTextDatasetBase):
+class AudioDataset(DatasetBase):
     data_type = 'audio'  # get rid of this class attribute asap

     @staticmethod
     def sort_key(ex):
         """ Sort using duration time of the sound spectrogram. """
         return ex.src.size(1)

-    @staticmethod
-    def make_audio_examples(path, audio_dir, sample_rate, window_size,
-                            window_stride, window, normalize_audio,
-                            truncate=None):
-        """
-        Args:
-            path (str): location of a src file containing audio paths.
-            audio_dir (str): location of source audio files.
-            sample_rate (int): sample_rate.
-            window_size (float): window size for spectrogram in seconds.
-            window_stride (float): window stride for spectrogram in seconds.
-            window (str): window type for spectrogram generation.
-            normalize_audio (bool): subtract spectrogram by mean and divide
-                by std or not.
-            truncate (int): maximum audio length (0 or None for unlimited).
-
-        Returns:
-            example_dict iterator
-        """
-        examples_iter = AudioDataset.read_audio_file(
-            path, audio_dir, "src", sample_rate,
-            window_size, window_stride, window,
-            normalize_audio, truncate)
-
-        return examples_iter
-
     @staticmethod
     def extract_features(audio_path, sample_rate, truncate, window_size,
                          window_stride, window, normalize_audio):
@@ -84,10 +56,19 @@ def extract_features(audio_path, sample_rate, truncate, window_size,
             spect.div_(std)
         return spect

-    @staticmethod
-    def read_audio_file(path, src_dir, side, sample_rate, window_size,
-                        window_stride, window, normalize_audio,
-                        truncate=None):
+    @classmethod
+    def make_examples(
+            cls,
+            path,
+            src_dir,
+            side,
+            sample_rate,
+            window_size,
+            window_stride,
+            window,
+            normalize_audio,
+            truncate=None
+    ):
         """
         Args:
             path (str): location of a src file containing audio paths.
@@ -104,7 +85,8 @@ def read_audio_file(path, src_dir, side, sample_rate, window_size,
         Yields:
             a dictionary containing audio data for each line.
         """
-        assert (src_dir is not None) and os.path.exists(src_dir),\
+        assert isinstance(path, str), "Iterators not supported for audio"
+        assert src_dir is not None and os.path.exists(src_dir),\
             "src_dir must be a valid directory if data_type is audio"

         with codecs.open(path, "r", "utf-8") as corpus_file:
@@ -123,116 +105,3 @@ def read_audio_file(path, src_dir, side, sample_rate, window_size,

                 yield {side: spect, side + '_path': line.strip(),
                        side + '_lengths': spect.size(1), 'indices': i}
-
-
-class ShardedAudioCorpusIterator(object):
-    """
-    This is the iterator for an audio corpus, used for sharding a large
-    audio corpus into small shards to avoid hogging memory.
-
-    Inside this iterator, it automatically divides the audio files into
-    shards of size `shard_size`. Then it processes each shard
-    into (example_dict, n_features) tuples when iterated.
-    """
-    def __init__(self, src_dir, corpus_path, truncate, side, shard_size,
-                 sample_rate, window_size, window_stride,
-                 window, normalize_audio=True, assoc_iter=None):
-        """
-        Args:
-            src_dir: the directory containing audio files
-            corpus_path: the path containing audio file names
-            truncate: maximum audio length (0 or None for unlimited).
-            side: "src" or "tgt".
-            shard_size: the shard size, 0 means not sharding the file.
-            sample_rate (int): sample_rate.
-            window_size (float): window size for spectrogram in seconds.
-            window_stride (float): window stride for spectrogram in seconds.
-            window (str): window type for spectrogram generation.
-            normalize_audio (bool): subtract spectrogram by mean and divide
-                by std or not.
-            assoc_iter: if not None, it is the associate iterator that
-                this iterator should align its step with.
-        """
-        try:
-            # The codecs module seems to have bugs with seek()/tell(),
-            # so we use io.open().
-            self.corpus = io.open(corpus_path, "r", encoding="utf-8")
-        except IOError:
-            sys.stderr.write("Failed to open corpus file: %s" % corpus_path)
-            sys.exit(1)
-
-        self.side = side
-        self.src_dir = src_dir
-        self.shard_size = shard_size
-        self.sample_rate = sample_rate
-        self.truncate = truncate
-        self.window_size = window_size
-        self.window_stride = window_stride
-        self.window = window
-        self.normalize_audio = normalize_audio
-        self.assoc_iter = assoc_iter
-        self.last_pos = 0
-        self.last_line_index = -1
-        self.line_index = -1
-        self.eof = False
-
-    def __iter__(self):
-        """
-        Iterator of (example_dict, nfeats).
-        On each call, it iterates over as many (example_dict, nfeats) tuples
-        as needed until this shard's size equals or approximates `self.shard_size`.
-        """
-        iteration_index = -1
-        if self.assoc_iter is not None:
-            # We have an associate iterator; just yield tuples
-            # until we run parallel with it.
-            while self.line_index < self.assoc_iter.line_index:
-                line = self.corpus.readline()
-                assert line != '', "The corpora must have same number of lines"
-
-                self.line_index += 1
-                iteration_index += 1
-                yield self._example_dict_iter(line, iteration_index)
-
-            if self.assoc_iter.eof:
-                self.eof = True
-                self.corpus.close()
-        else:
-            # Yield tuples until this shard's size reaches the threshold.
-            self.corpus.seek(self.last_pos)
-            while True:
-                if self.shard_size != 0 and self.line_index % 64 == 0:
-                    cur_pos = self.corpus.tell()
-                    if self.line_index \
-                            >= self.last_line_index + self.shard_size:
-                        self.last_pos = cur_pos
-                        self.last_line_index = self.line_index
-                        raise StopIteration
-
-                line = self.corpus.readline()
-                if line == '':
-                    self.eof = True
-                    self.corpus.close()
-                    raise StopIteration
-
-                self.line_index += 1
-                iteration_index += 1
-                yield self._example_dict_iter(line, iteration_index)
-
-    def hit_end(self):
-        return self.eof
-
-    def _example_dict_iter(self, line, index):
-        line = line.strip()
-        audio_path = os.path.join(self.src_dir, line)
-        if not os.path.exists(audio_path):
-            audio_path = line
-
-        assert os.path.exists(audio_path), 'audio path %s not found' % line
-
-        spect = AudioDataset.extract_features(
-            audio_path, self.sample_rate, self.truncate, self.window_size,
-            self.window_stride, self.window, self.normalize_audio
-        )
-        return {self.side: spect, self.side + '_path': line,
-                self.side + '_lengths': spect.size(1), 'indices': index}
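
Reviewer note: callers move from the old `read_audio_file` static method (and the deleted `ShardedAudioCorpusIterator`) to the new `AudioDataset.make_examples` classmethod. A hedged sketch of driving it is below; the file paths and spectrogram parameters are hypothetical, but the signature and the yielded dict keys come straight from the diff above.

from onmt.inputters import AudioDataset

# Hypothetical inputs: a listing file with one audio file name per line,
# and the directory that holds those audio files.
examples = AudioDataset.make_examples(
    path="data/src-train.txt",
    src_dir="data/audio",
    side="src",
    sample_rate=16000,      # Hz
    window_size=0.02,       # spectrogram window, in seconds
    window_stride=0.01,     # spectrogram stride, in seconds
    window="hamming",       # window type for spectrogram generation
    normalize_audio=True,   # subtract mean, divide by std
    truncate=None,          # 0 or None: no length cap
)
for ex in examples:
    # Each dict maps 'src' to a spectrogram tensor plus bookkeeping keys.
    print(ex["src_path"], ex["src_lengths"], ex["indices"])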
onmt/inputters/dataset_base.py (50 additions, 31 deletions)
@@ -1,13 +1,14 @@
 # coding: utf-8

 from itertools import chain
+from collections import Counter

 import torch
 import torchtext
+from torchtext.vocab import Vocab

 PAD_WORD = '<blank>'
 UNK_WORD = '<unk>'
-UNK = 0
 BOS_WORD = '<s>'
 EOS_WORD = '</s>'

@@ -36,6 +37,37 @@ def __reduce_ex__(self, proto):
         # This is a hack. Something is broken with torch pickle.
         return super(DatasetBase, self).__reduce_ex__()

+    def __init__(self, fields, src_examples_iter, tgt_examples_iter,
+                 dynamic_dict=False, filter_pred=None):
+
+        # Each element of an example is a dictionary whose keys represent
+        # at minimum the src tokens and their indices and potentially also
+        # the src and tgt features and alignment information.
+        if tgt_examples_iter is not None:
+            examples_iter = (self._join_dicts(src, tgt) for src, tgt in
+                             zip(src_examples_iter, tgt_examples_iter))
+        else:
+            examples_iter = src_examples_iter
+
+        # self.src_vocabs is used in collapse_copy_scores and in Translator.py
+        self.src_vocabs = []
+        if dynamic_dict:
+            unk, pad = fields['src'].unk_token, fields['src'].pad_token
+            examples_iter = (self._dynamic_dict(ex, unk, pad)
+                             for ex in examples_iter)
+
+        # Peek at the first to see which fields are used.
+        ex, examples_iter = self._peek(examples_iter)
+        keys = ex.keys()
+
+        # why do we need to use different keys from the ones passed in?
+        fields = [(k, fields[k]) if k in fields else (k, None) for k in keys]
+        example_values = ([ex[k] for k in keys] for ex in examples_iter)
+        examples = [self._construct_example_fromlist(ex_values, fields)
+                    for ex_values in example_values]
+
+        super(DatasetBase, self).__init__(examples, fields, filter_pred)
+
     def save(self, path, remove_fields=True):
         if remove_fields:
             self.fields = []
@@ -108,6 +140,7 @@ def _construct_example_fromlist(self, data, fields):
         Returns:
             the created `Example` object.
         """
+        # why does this exist?
         ex = torchtext.data.Example()
         for (name, field), val in zip(fields, data):
             if field is not None:
@@ -116,33 +149,19 @@ def _construct_example_fromlist(self, data, fields):
                 setattr(ex, name, val)
         return ex

-
-# this is just temporary until the TextDatabase can be unified with the others
-class NonTextDatasetBase(DatasetBase):
-    """
-    Args:
-        fields (dict): a dictionary of `torchtext.data.Field`.
-        src_examples_iter (dict iter): preprocessed source example
-            dictionary iterator.
-        tgt_examples_iter (dict iter): preprocessed target example
-            dictionary iterator.
-        tgt_seq_length (int): maximum target sequence length.
-    """
-    def __init__(self, fields, src_examples_iter, tgt_examples_iter,
-                 filter_pred=None):
-        if tgt_examples_iter is not None:
-            examples_iter = (self._join_dicts(src, tgt) for src, tgt in
-                             zip(src_examples_iter, tgt_examples_iter))
-        else:
-            examples_iter = src_examples_iter
-
-        # Peek at the first to see which fields are used.
-        ex, examples_iter = self._peek(examples_iter)
-        keys = ex.keys()
-
-        fields = [(k, fields[k]) if k in fields else (k, None) for k in keys]
-        example_values = ([ex[k] for k in keys] for ex in examples_iter)
-        examples = [self._construct_example_fromlist(ex_values, fields)
-                    for ex_values in example_values]
-
-        super(DatasetBase, self).__init__(examples, fields, filter_pred)
+    def _dynamic_dict(self, example, unk, pad):
+        # it would not be necessary to pass unk and pad if the method were
+        # called after fields becomes an attribute of self
+        src = example["src"]
+        src_vocab = Vocab(Counter(src), specials=[unk, pad])
+        self.src_vocabs.append(src_vocab)
+        # Map source tokens to indices in the dynamic dict.
+        src_map = torch.LongTensor([src_vocab.stoi[w] for w in src])
+        example["src_map"] = src_map
+
+        if "tgt" in example:
+            tgt = example["tgt"]
+            mask = torch.LongTensor(
+                [0] + [src_vocab.stoi[w] for w in tgt] + [0])
+            example["alignment"] = mask
+        return example
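
Reviewer note: `_dynamic_dict` is what feeds the copy mechanism; the in-code comment above ties `self.src_vocabs` to `collapse_copy_scores` and `Translator.py`. Below is a standalone illustration of the tensors it produces, using made-up token lists; this is a sketch for review purposes, not part of the PR, and the BOS/EOS reading of the zero padding is my interpretation.

from collections import Counter

import torch
from torchtext.vocab import Vocab

src = ["the", "cat", "sat", "on", "the", "mat"]
tgt = ["le", "chat", "mat"]
unk, pad = "<unk>", "<blank>"

# A per-example vocabulary over the source tokens only.
src_vocab = Vocab(Counter(src), specials=[unk, pad])
# src_map: every source position mapped into that per-example vocab.
src_map = torch.LongTensor([src_vocab.stoi[w] for w in src])
# alignment: target tokens looked up in the same vocab (0 when absent),
# padded with a zero at each end, which lines up with BOS/EOS positions.
alignment = torch.LongTensor([0] + [src_vocab.stoi[w] for w in tgt] + [0])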