Skip to content

Move import _torch_sox inside function calls #361

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 38 additions & 0 deletions test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,32 @@
import os


class AudioBackendScope:
    """Context manager that temporarily switches the torchaudio audio backend.

    On entry the requested backend is activated with
    ``torchaudio.set_audio_backend``; on exit the backend that was active
    when the scope was entered is restored.

    Args:
        backend: Name of the backend to activate inside the ``with`` body.
    """

    def __init__(self, backend):
        self.new_backend = backend
        # Initial snapshot so the attribute exists before the scope is
        # entered; it is refreshed in ``__enter__``.
        self.previous_backend = torchaudio.get_audio_backend()

    def __enter__(self):
        # Re-read the active backend at entry time: the scope object may have
        # been created earlier, and the backend may have changed since then.
        self.previous_backend = torchaudio.get_audio_backend()
        torchaudio.set_audio_backend(self.new_backend)
        return self.new_backend

    def __exit__(self, exc_type, exc_value, traceback):
        # Returns None, so an exception raised in the body still propagates.
        torchaudio.set_audio_backend(self.previous_backend)


class Test_LoadSave(unittest.TestCase):
test_dirpath, test_dir = common_utils.create_temp_assets_dir()
test_filepath = os.path.join(test_dirpath, "assets",
"steam-train-whistle-daniel_simon.mp3")

def test_1_save(self):
    """Run the save round-trip check once per supported backend."""
    for backend in ["sox"]:
        # Label the sub-test with the backend so a failure report says
        # which backend was active.
        with self.subTest(backend=backend):
            with AudioBackendScope(backend):
                self._test_1_save()

def _test_1_save(self):
# load signal
x, sr = torchaudio.load(self.test_filepath, normalization=False)

Expand Down Expand Up @@ -78,6 +98,12 @@ def test_1_save(self):
os.unlink(new_filepath)

def test_2_load(self):
    """Run the load checks once per supported backend."""
    for backend in ["sox"]:
        # Label the sub-test with the backend so a failure report says
        # which backend was active.
        with self.subTest(backend=backend):
            with AudioBackendScope(backend):
                self._test_2_load()

def _test_2_load(self):
# check normal loading
x, sr = torchaudio.load(self.test_filepath)
self.assertEqual(sr, 44100)
Expand Down Expand Up @@ -117,6 +143,12 @@ def test_2_load(self):
torchaudio.load(tdir)

def test_3_load_and_save_is_identity(self):
    """Run the load/save identity check once per supported backend."""
    for backend in ["sox", "soundfile"]:
        # Two backends run here; without a label a failure would not say
        # which of them broke.
        with self.subTest(backend=backend):
            with AudioBackendScope(backend):
                self._test_3_load_and_save_is_identity()

def _test_3_load_and_save_is_identity(self):
input_path = os.path.join(self.test_dirpath, 'assets', 'sinewave.wav')
tensor, sample_rate = torchaudio.load(input_path)
output_path = os.path.join(self.test_dirpath, 'test.wav')
Expand All @@ -127,6 +159,12 @@ def test_3_load_and_save_is_identity(self):
os.unlink(output_path)

def test_4_load_partial(self):
    """Run the partial-load check once per supported backend."""
    for backend in ["sox"]:
        # Label the sub-test with the backend so a failure report says
        # which backend was active.
        with self.subTest(backend=backend):
            with AudioBackendScope(backend):
                self._test_4_load_partial()

def _test_4_load_partial(self):
num_frames = 101
offset = 201
# load entire mono sinewave wav file, load a partial copy and then compare
Expand Down
151 changes: 112 additions & 39 deletions torchaudio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,48 @@
import os.path

import torch
import _torch_sox

from torchaudio import transforms, datasets, kaldi_io, sox_effects, compliance
from torchaudio import (
_soundfile_backend,
_sox_backend,
compliance,
datasets,
kaldi_io,
sox_effects,
transforms,
)

try:
from .version import __version__, git_version # noqa: F401
except ImportError:
pass


# Name of the backend currently used by the dispatching functions.
_audio_backend = "sox"
# Every backend name accepted by ``set_audio_backend``.
_audio_backends = ["sox", "soundfile"]


def set_audio_backend(backend):
    """Select the package used to load audio.

    Args:
        backend (str): Name of the backend. Must be one of the names in
            ``_audio_backends`` ("sox" or "soundfile").

    Raises:
        ValueError: If ``backend`` is not one of the supported names.
    """
    # NOTE: the docstring above is a plain literal on purpose. A string
    # followed by ``.format(...)`` is an ordinary expression, not a
    # docstring, and would leave ``__doc__`` unset.
    global _audio_backend
    if backend not in _audio_backends:
        raise ValueError(
            "Invalid backend '{}'. Options are {}.".format(backend, _audio_backends)
        )
    _audio_backend = backend


def get_audio_backend():
    """Return the name of the currently selected audio backend.

    Returns:
        The current value of the module-level ``_audio_backend`` setting.
    """
    active_backend = _audio_backend
    return active_backend


def check_input(src):
if not torch.is_tensor(src):
raise TypeError('Expected a tensor, got %s' % type(src))
Expand Down Expand Up @@ -67,36 +99,25 @@ def load(filepath,
1.

"""
# stringify if `pathlib.Path` (noop if already `str`)
filepath = str(filepath)
# check if valid file
if not os.path.isfile(filepath):
raise OSError("{} not found or is a directory".format(filepath))

# initialize output tensor
if out is not None:
check_input(out)
else:
out = torch.FloatTensor()

if num_frames < -1:
raise ValueError("Expected value for num_samples -1 (entire file) or >=0")
if offset < 0:
raise ValueError("Expected positive offset value")

sample_rate = _torch_sox.read_audio_file(filepath,
out,
channels_first,
num_frames,
offset,
signalinfo,
encodinginfo,
filetype)

# normalize if needed
_audio_normalization(out, normalization)
if get_audio_backend() == "sox":
func = _sox_backend.load
elif get_audio_backend() == "soundfile":
func = _soundfile_backend.load
else:
raise ImportError

return out, sample_rate
return func(
filepath,
out=out,
normalization=normalization,
channels_first=channels_first,
num_frames=num_frames,
offset=offset,
signalinfo=signalinfo,
encodinginfo=encodinginfo,
filetype=filetype,
)


def load_wav(filepath, **kwargs):
Expand Down Expand Up @@ -128,13 +149,17 @@ def save(filepath, src, sample_rate, precision=16, channels_first=True):
channels_first (bool): Set channels first or length first in result. (
Default: ``True``)
"""
si = sox_signalinfo_t()
ch_idx = 0 if channels_first else 1
si.rate = sample_rate
si.channels = 1 if src.dim() == 1 else src.size(ch_idx)
si.length = src.numel()
si.precision = precision
return save_encinfo(filepath, src, channels_first, si)

if get_audio_backend() == "sox":
func = _sox_backend.save
elif get_audio_backend() == "soundfile":
func = _soundfile_backend.save
else:
raise ImportError

return func(
filepath, src, sample_rate, precision=precision, channels_first=channels_first
)


def save_encinfo(filepath,
Expand Down Expand Up @@ -203,7 +228,12 @@ def save_encinfo(filepath,
src = src.transpose(1, 0)
# save data to file
src = src.contiguous()
_torch_sox.write_audio_file(filepath, src, signalinfo, encodinginfo, filetype)

if get_audio_backend() == "sox":
import _torch_sox
_torch_sox.write_audio_file(filepath, src, signalinfo, encodinginfo, filetype)
else:
raise ImportError


def info(filepath):
Expand All @@ -220,7 +250,15 @@ def info(filepath):
>>> si, ei = torchaudio.info('foo.wav')
>>> rate, channels, encoding = si.rate, si.channels, ei.encoding
"""
return _torch_sox.get_info(filepath)

if get_audio_backend() == "sox":
func = _sox_backend.info
elif get_audio_backend() == "soundfile":
func = _soundfile_backend.info
else:
raise ImportError

return func(filepath)


def sox_signalinfo_t():
Expand All @@ -242,6 +280,11 @@ def sox_signalinfo_t():
>>> si.precision = 16
>>> si.length = 0
"""

if get_audio_backend() != "sox":
raise ImportError

import _torch_sox
return _torch_sox.sox_signalinfo_t()


Expand Down Expand Up @@ -274,6 +317,11 @@ def sox_encodinginfo_t():
>>> ei.opposite_endian = torchaudio.get_sox_bool(0)

"""

if get_audio_backend() != "sox":
raise ImportError

import _torch_sox
ei = _torch_sox.sox_encodinginfo_t()
sdo = get_sox_option_t(2) # sox_default_option
ei.reverse_bytes = sdo
Expand All @@ -292,6 +340,11 @@ def get_sox_encoding_t(i=None):
Returns:
sox_encoding_t: A sox_encoding_t type for output encoding
"""

if get_audio_backend() != "sox":
raise ImportError

import _torch_sox
if i is None:
# one can see all possible values using the .__members__ attribute
return _torch_sox.sox_encoding_t
Expand All @@ -309,6 +362,11 @@ def get_sox_option_t(i=2):
Returns:
sox_option_t: A sox_option_t type
"""

if get_audio_backend() != "sox":
raise ImportError

import _torch_sox
if i is None:
return _torch_sox.sox_option_t
else:
Expand All @@ -326,6 +384,11 @@ def get_sox_bool(i=0):
Returns:
sox_bool: A sox_bool type
"""

if get_audio_backend() != "sox":
raise ImportError

import _torch_sox
if i is None:
return _torch_sox.sox_bool
else:
Expand All @@ -337,13 +400,23 @@ def initialize_sox():
loading. Importantly, only run `initialize_sox` once and do not shutdown
after each effect chain, but rather once you are finished with all effects chains.
"""

if get_audio_backend() != "sox":
raise ImportError

import _torch_sox
return _torch_sox.initialize_sox()


def shutdown_sox():
    """Shut down sox for effects chains.

    Not required for simple loading. Importantly, only call once.
    Attempting to re-initialize sox will result in seg faults.

    Returns:
        Whatever ``_torch_sox.shutdown_sox()`` returns.

    Raises:
        ImportError: If the active audio backend is not "sox".
    """
    if get_audio_backend() != "sox":
        raise ImportError("shutdown_sox() requires the 'sox' audio backend")

    # Imported lazily so the extension is only needed when sox is in use.
    import _torch_sox
    return _torch_sox.shutdown_sox()


Expand Down
74 changes: 74 additions & 0 deletions torchaudio/_soundfile_backend.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import os

import torch


def check_input(src):
    """Validate that ``src`` is a tensor that is not on a CUDA device.

    Args:
        src: Object to validate.

    Raises:
        TypeError: If ``src`` is not a tensor, or if it is a CUDA tensor.
    """
    is_tensor = torch.is_tensor(src)
    if is_tensor and not src.is_cuda:
        return

    # Pick the message for whichever of the two checks failed.
    if is_tensor:
        template = "Expected a CPU based tensor, got %s"
    else:
        template = "Expected a tensor, got %s"
    raise TypeError(template % type(src))


def load(
    filepath,
    out=None,
    normalization=True,
    channels_first=True,
    num_frames=0,
    offset=0,
    filetype=None,
    **_,
):
    r"""See torchaudio.load

    Args:
        filepath: Path to the audio file (``str`` or ``pathlib.Path``).
        out: Optional CPU tensor that receives the loaded samples. It is
            resized to the shape of the loaded data.
        normalization: Accepted for signature compatibility; this backend
            currently applies no normalization.
        channels_first: If ``True`` the result is transposed relative to the
            array returned by ``soundfile.read``; otherwise it is returned in
            the layout ``soundfile.read`` produced.
        num_frames: Number of frames to read; ``0`` and ``-1`` both mean the
            entire file.
        offset: Frame index at which reading starts.
        filetype: Accepted for signature compatibility; currently ignored.

    Returns:
        Tuple of ``(tensor, sample_rate)``.

    Raises:
        OSError: If ``filepath`` does not point to an existing file.
        ValueError: If ``num_frames`` is below ``-1`` or ``offset`` is negative.
        TypeError: If ``out`` is given but is not a CPU tensor.
    """

    # stringify if `pathlib.Path` (noop if already `str`)
    filepath = str(filepath)

    # check if valid file
    if not os.path.isfile(filepath):
        raise OSError("{} not found or is a directory".format(filepath))

    if num_frames < -1:
        raise ValueError("Expected value for num_samples -1 (entire file) or >=0")
    if num_frames == 0:
        # 0 is the "read everything" default here; soundfile is given -1.
        num_frames = -1
    if offset < 0:
        raise ValueError("Expected positive offset value")

    import soundfile

    if out is not None:
        check_input(out)

    # TODO remove pysoundfile and call directly soundfile to avoid going through numpy
    # NOTE(review): soundfile.read's own ``out`` argument is documented as a
    # NumPy array, so the tensor is not passed to it; the data is read into a
    # fresh array and copied into ``out`` below.
    data, sample_rate = soundfile.read(
        filepath, frames=num_frames, start=offset, always_2d=True
    )
    signal = torch.tensor(data)
    if channels_first:
        # presumably soundfile returns (frames, channels) with always_2d=True
        # -- see the soundfile documentation.
        signal = signal.t()

    if out is not None:
        # Fill the caller's tensor in place so the object they passed holds
        # the result; copy_ converts to the dtype of ``out``.
        out.resize_(signal.shape).copy_(signal)
    else:
        out = signal

    # NOTE: normalization is not applied by this backend.

    return out, sample_rate


def save(filepath, src, sample_rate, channels_first=True, **_):
    r"""See torchaudio.save

    Writes ``src`` to ``filepath`` through ``soundfile.write``. When
    ``channels_first`` is true the tensor is transposed before writing.
    Extra keyword arguments are accepted and ignored.
    """

    # Transpose only when the caller supplied channel-major data.
    data = src.t() if channels_first else src

    import soundfile
    result = soundfile.write(filepath, data, sample_rate)
    return result


def info(filepath, **_):
    r"""See torchaudio.info

    Returns whatever ``soundfile.info`` reports for ``filepath``. Extra
    keyword arguments are accepted and ignored.
    """

    import soundfile
    file_info = soundfile.info(filepath)
    return file_info
Loading