Skip to content

Add functionals gain, dither, scale_to_interval #319

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Dec 2, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 78 additions & 1 deletion test/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import torch
import torchaudio
import torchaudio.functional as F
import torchaudio.transforms as T
import pytest
import unittest
import common_utils
Expand All @@ -31,8 +32,10 @@ class TestFunctional(unittest.TestCase):
specgram = torch.tensor([1., 2., 3., 4.])

test_dirpath, test_dir = common_utils.create_temp_assets_dir()

test_filepath = os.path.join(test_dirpath, 'assets',
'steam-train-whistle-daniel_simon.mp3')
waveform_train, sr_train = torchaudio.load(test_filepath)

def test_torchscript_spectrogram(self):

Expand Down Expand Up @@ -365,8 +368,63 @@ def test_create_fb(self):
self._test_create_fb(n_mels=56, fmin=1900.0, fmax=900.0)
self._test_create_fb(n_mels=10, fmin=1900.0, fmax=900.0)

def test_pitch(self):
def test_gain(self):
waveform_gain = F.gain(self.waveform_train, 3)
self.assertTrue(waveform_gain.abs().max().item(), 1.)

E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(self.test_filepath)
E.append_effect_to_chain("gain", [3])
sox_gain_waveform = E.sox_build_flow_effects()[0]

self.assertTrue(torch.allclose(waveform_gain, sox_gain_waveform, atol=1e-04))

def test_scale_to_interval(self):
scaled = 5.5 # [-5.5, 5.5]
waveform_scaled = F._scale_to_interval(self.waveform_train, scaled)

self.assertTrue(torch.max(waveform_scaled) <= scaled)
self.assertTrue(torch.min(waveform_scaled) >= -scaled)

def test_dither(self):
waveform_dithered = F.dither(self.waveform_train)
waveform_dithered_noiseshaped = F.dither(self.waveform_train, noise_shaping=True)

E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(self.test_filepath)
E.append_effect_to_chain("dither", [])
sox_dither_waveform = E.sox_build_flow_effects()[0]

self.assertTrue(torch.allclose(waveform_dithered, sox_dither_waveform, atol=1e-04))
E.clear_chain()

E.append_effect_to_chain("dither", ["-s"])
sox_dither_waveform_ns = E.sox_build_flow_effects()[0]

self.assertTrue(torch.allclose(waveform_dithered_noiseshaped, sox_dither_waveform_ns, atol=1e-02))

def test_vctk_transform_pipeline(self):
test_filepath_vctk = os.path.join(self.test_dirpath, "assets/VCTK-Corpus/wav48/p224/", "p224_002.wav")
wf_vctk, sr_vctk = torchaudio.load(test_filepath_vctk)

# rate
sample = T.Resample(sr_vctk, 16000, resampling_method='sinc_interpolation')
wf_vctk = sample(wf_vctk)
# dither
wf_vctk = F.dither(wf_vctk, noise_shaping=True)

E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(test_filepath_vctk)
E.append_effect_to_chain("gain", ["-h"])
E.append_effect_to_chain("channels", [1])
E.append_effect_to_chain("rate", [16000])
E.append_effect_to_chain("gain", ["-rh"])
E.append_effect_to_chain("dither", ["-s"])
wf_vctk_sox = E.sox_build_flow_effects()[0]

self.assertTrue(torch.allclose(wf_vctk, wf_vctk_sox, rtol=1e-03, atol=1e-03))

def test_pitch(self):
test_dirpath, test_dir = common_utils.create_temp_assets_dir()
test_filepath_100 = os.path.join(test_dirpath, 'assets', "100Hz_44100Hz_16bit_05sec.wav")
test_filepath_440 = os.path.join(test_dirpath, 'assets', "440Hz_44100Hz_16bit_05sec.wav")
Expand Down Expand Up @@ -518,6 +576,25 @@ def test_mask_along_axis_iid(self):

_test_torchscript_functional(F.mask_along_axis_iid, specgrams, mask_param, mask_value, axis)

def test_torchscript_gain(self):
tensor = torch.rand((1, 1000))
gainDB = 2.0

_test_torchscript_functional(F.gain, tensor, gainDB)

def test_torchscript_scale_to_interval(self):
tensor = torch.rand((1, 1000))
scaled = 3.5

_test_torchscript_functional(F._scale_to_interval, tensor, scaled)

def test_torchscript_dither(self):
tensor = torch.rand((1, 1000))

_test_torchscript_functional(F.dither, tensor)
_test_torchscript_functional(F.dither, tensor, "RPDF")
_test_torchscript_functional(F.dither, tensor, "GPDF")


@pytest.mark.parametrize('complex_tensor', [
torch.randn(1, 2, 1025, 400, 2),
Expand Down
20 changes: 9 additions & 11 deletions torchaudio/datasets/vctk.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,18 +21,16 @@ def load_vctk_item(

# Read wav
file_audio = os.path.join(path, folder_audio, speaker_id, fileid + ext_audio)
waveform, sample_rate = torchaudio.load(file_audio)
if downsample:
# Legacy
E = torchaudio.sox_effects.SoxEffectsChain()
E.set_input_file(file_audio)
E.append_effect_to_chain("gain", ["-h"])
E.append_effect_to_chain("channels", [1])
E.append_effect_to_chain("rate", [16000])
E.append_effect_to_chain("gain", ["-rh"])
E.append_effect_to_chain("dither", ["-s"])
waveform, sample_rate = E.sox_build_flow_effects()
else:
waveform, sample_rate = torchaudio.load(file_audio)
# TODO Remove this parameter after deprecation
F = torchaudio.functional
T = torchaudio.transforms
# rate
sample = T.Resample(sample_rate, 16000, resampling_method='sinc_interpolation')
waveform = sample(waveform)
# dither
waveform = F.dither(waveform, noise_shaping=True)

return waveform, sample_rate, utterance, speaker_id, utterance_id

Expand Down
156 changes: 156 additions & 0 deletions torchaudio/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,6 +858,162 @@ def compute_deltas(specgram, win_length=5, mode="replicate"):
return output


def gain(waveform, gain_db=1.0):
# type: (Tensor, float) -> Tensor
r"""Apply amplification or attenuation to the whole waveform.

Args:
waveform (torch.Tensor): Tensor of audio of dimension (channel, time).
gain_db (float) Gain adjustment in decibels (dB) (Default: `1.0`).

Returns:
torch.Tensor: the whole waveform amplified by gain_db.
"""
if (gain_db == 0):
return waveform

ratio = 10 ** (gain_db / 20)

return waveform * ratio


def _scale_to_interval(waveform, interval_max=1.0):
# type: (Tensor, float) -> Tensor
r"""Scale the waveform to the interval [-interval_max, interval_max] across all dimensions.

Args:
waveform (torch.Tensor): Tensor of audio of dimension (channel, time).
interval_max (float): The bounds of the interval, where the float indicates
the upper bound and the negative of the float indicates the lower
bound (Default: `1.0`).
Example: interval=1.0 -> [-1.0, 1.0]

Returns:
torch.Tensor: the whole waveform scaled to interval.
"""
abs_max = torch.max(torch.abs(waveform))
ratio = abs_max / interval_max
waveform /= ratio

return waveform


def _add_noise_shaping(dithered_waveform, waveform):
r"""Noise shaping is calculated by error:
error[n] = dithered[n] - original[n]
noise_shaped_waveform[n] = dithered[n] + error[n-1]
"""
wf_shape = waveform.size()
waveform = waveform.reshape(-1, wf_shape[-1])

dithered_shape = dithered_waveform.size()
dithered_waveform = dithered_waveform.reshape(-1, dithered_shape[-1])

error = dithered_waveform - waveform

# add error[n-1] to dithered_waveform[n], so offset the error by 1 index
for index in range(error.size()[0]):
err = error[index]
error_offset = torch.cat((torch.zeros(1), err))
error[index] = error_offset[:waveform.size()[1]]

noise_shaped = dithered_waveform + error
return noise_shaped.reshape(dithered_shape[:-1] + noise_shaped.shape[-1:])


def _apply_probability_distribution(waveform, density_function="TPDF"):
# type: (Tensor, str) -> Tensor
r"""Apply a probability distribution function on a waveform.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should just be an internal function, and it in fact does the core of applying dither. Let's rename this to _apply_dither.

nit: "Apply dither to the waveform using the chosen density function."

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see this comment marked as resolved. What was the resolution of this?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just pushed the change, I renamed it to _apply_probability_distribution. How does that sound?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

And this comment... should I still rename to "apply dither"?


Triangular probability density function (TPDF) dither noise has a
triangular distribution; values in the center of the range have a higher
probability of occurring.

Rectangular probability density function (RPDF) dither noise has a
uniform distribution; any value in the specified range has the same
probability of occurring.

Gaussian probability density function (GPDF) has a normal distribution.
The relationship of probabilities of results follows a bell-shaped,
or Gaussian curve, typical of dither generated by analog sources.
Args:
waveform (torch.Tensor): Tensor of audio of dimension (channel, time)
probability_density_function (string): The density function of a
continuous random variable (Default: `TPDF`)
Options: Triangular Probability Density Function - `TPDF`
Rectangular Probability Density Function - `RPDF`
Gaussian Probability Density Function - `GPDF`
Returns:
torch.Tensor: waveform dithered with TPDF
"""
shape = waveform.size()
waveform = waveform.reshape(-1, shape[-1])

channel_size = waveform.size()[0] - 1
time_size = waveform.size()[-1] - 1

random_channel = int(torch.randint(channel_size, [1, ]).item()) if channel_size > 0 else 0
random_time = int(torch.randint(time_size, [1, ]).item()) if time_size > 0 else 0

number_of_bits = 16
up_scaling = 2 ** (number_of_bits - 1) - 2
signal_scaled = waveform * up_scaling
down_scaling = 2 ** (number_of_bits - 1)

signal_scaled_dis = waveform
if (density_function == "RPDF"):
RPDF = waveform[random_channel][random_time] - 0.5

signal_scaled_dis = signal_scaled + RPDF
elif (density_function == "GPDF"):
# TODO Replace by distribution code once
# https://github.com/pytorch/pytorch/issues/29843 is resolved
# gaussian = torch.distributions.normal.Normal(torch.mean(waveform, -1), 1).sample()

num_rand_variables = 6

gaussian = waveform[random_channel][random_time]
for ws in num_rand_variables * [time_size]:
rand_chan = int(torch.randint(channel_size, [1, ]).item())
gaussian += waveform[rand_chan][int(torch.randint(ws, [1, ]).item())]

signal_scaled_dis = signal_scaled + gaussian
else:
TPDF = torch.bartlett_window(time_size + 1)
TPDF = TPDF.repeat((channel_size + 1), 1)
signal_scaled_dis = signal_scaled + TPDF

quantised_signal_scaled = torch.round(signal_scaled_dis)
quantised_signal = quantised_signal_scaled / down_scaling
return quantised_signal.reshape(shape[:-1] + quantised_signal.shape[-1:])


def dither(waveform, density_function="TPDF", noise_shaping=False):
# type: (Tensor, str, bool) -> Tensor
r"""Dither increases the perceived dynamic range of audio stored at a
particular bit-depth by eliminating nonlinear truncation distortion
(i.e. adding minimally perceived noise to mask distortion caused by quantization).
Args:
waveform (torch.Tensor): Tensor of audio of dimension (channel, time)
density_function (string): The density function of a
continuous random variable (Default: `TPDF`)
Options: Triangular Probability Density Function - `TPDF`
Rectangular Probability Density Function - `RPDF`
Gaussian Probability Density Function - `GPDF`
noise_shaping (boolean): a filtering process that shapes the spectral
energy of quantisation error (Default: `False`)

Returns:
torch.Tensor: waveform dithered
"""
dithered = _apply_probability_distribution(waveform, density_function=density_function)

if noise_shaping:
return _add_noise_shaping(dithered, waveform)
else:
return dithered


def _compute_nccf(waveform, sample_rate, frame_time, freq_low):
# type: (Tensor, int, float, int) -> Tensor
r"""
Expand Down