Skip to content

Commit

Permalink
Smarter optional dependency imports in the datasets. (OpenNMT#1195)
Browse files Browse the repository at this point in the history
  • Loading branch information
flauted authored and vince62s committed Jan 22, 2019
1 parent d94e6af commit a9b15f7
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 6 deletions.
17 changes: 14 additions & 3 deletions onmt/inputters/audio_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,21 @@

from onmt.inputters.dataset_base import DatasetBase

# imports of datatype-specific dependencies
try:
import torchaudio
import librosa
import numpy as np
except ImportError:
torchaudio, librosa, np = None, None, None


class AudioDataset(DatasetBase):
@staticmethod
def _check_deps():
if any([torchaudio is None, librosa is None, np is None]):
AudioDataset._raise_missing_dep(
"torchaudio", "librosa", "numpy")

@staticmethod
def sort_key(ex):
Expand All @@ -17,13 +30,11 @@ def sort_key(ex):
@staticmethod
def extract_features(audio_path, sample_rate, truncate, window_size,
window_stride, window, normalize_audio):
import torchaudio
import librosa
import numpy as np
# torchaudio loading options recently changed. It's probably
# straightforward to rewrite the audio handling to make use of
# up-to-date torchaudio, but in the meantime there is a legacy
# method which uses the old defaults
AudioDataset._check_deps()
sound, sample_rate_ = torchaudio.legacy.load(audio_path)
if truncate and truncate > 0:
if sound.size(0) > truncate:
Expand Down
13 changes: 13 additions & 0 deletions onmt/inputters/dataset_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,12 @@
from torchtext.vocab import Vocab


# several data readers need optional dependencies. There's no
# appropriate builtin exception
class MissingDependencyException(Exception):
pass


class DatasetBase(Dataset):
"""
A dataset is an object that accepts sequences of raw data (sentence pairs
Expand Down Expand Up @@ -80,6 +86,13 @@ def __init__(self, fields, src_examples_iter, tgt_examples_iter,

super(DatasetBase, self).__init__(examples, fields, filter_pred)

@staticmethod
def _raise_missing_dep(*missing_deps):
"""Raise missing dep exception with standard error message."""
raise MissingDependencyException(
"Could not create reader. Be sure to install "
"the following dependencies: " + ", ".join(missing_deps))

def __getattr__(self, attr):
# avoid infinite recursion when fields isn't defined
if 'fields' not in vars(self):
Expand Down
17 changes: 14 additions & 3 deletions onmt/inputters/image_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,21 @@

from onmt.inputters.dataset_base import DatasetBase

# domain specific dependencies
try:
from PIL import Image
from torchvision import transforms
import cv2
except ImportError:
Image, transforms, cv2 = None, None, None


class ImageDataset(DatasetBase):
@staticmethod
def _check_deps():
if any([Image is None, transforms is None, cv2 is None]):
ImageDataset._raise_missing_dep(
"PIL", "torchvision", "cv2")

@staticmethod
def sort_key(ex):
Expand All @@ -25,9 +38,7 @@ def make_examples(
Yields:
a dictionary containing image data, path and index for each line.
"""
from PIL import Image
from torchvision import transforms
import cv2
ImageDataset._check_deps()

if isinstance(images, str):
images = cls._read_file(images)
Expand Down

0 comments on commit a9b15f7

Please sign in to comment.