Skip to content

Commit 2d601c2

Browse files
committed
Initial commit for SoX logic in VCTK
1 parent bdf9255 commit 2d601c2

File tree

3 files changed

+245
-12
lines changed

3 files changed

+245
-12
lines changed

test/test_functional.py

Lines changed: 80 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import torch
66
import torchaudio
77
import torchaudio.functional as F
8+
import torchaudio.transforms as T
89
import pytest
910
import unittest
1011
import common_utils
@@ -31,6 +32,10 @@ class TestFunctional(unittest.TestCase):
3132
specgram = torch.tensor([1., 2., 3., 4.])
3233

3334
test_dirpath, test_dir = common_utils.create_temp_assets_dir()
35+
36+
test_filepath_sinewave = os.path.join(test_dirpath, "assets", "sinewave.wav")
37+
waveform_sinewave, sr_sinewave = torchaudio.load(test_filepath_sinewave)
38+
3439
test_filepath = os.path.join(test_dirpath, 'assets',
3540
'steam-train-whistle-daniel_simon.mp3')
3641

@@ -365,8 +370,63 @@ def test_create_fb(self):
365370
self._test_create_fb(n_mels=56, fmin=1900.0, fmax=900.0)
366371
self._test_create_fb(n_mels=10, fmin=1900.0, fmax=900.0)
367372

368-
def test_pitch(self):
373+
def test_gain(self):
    # Compare F.gain on the sine-wave asset against SoX's "gain" effect.
    waveform_gain = F.gain(self.waveform_sinewave, 5)
    # assertTrue(value, 1.) would treat 1. as the failure *message* and pass
    # for any non-zero peak, so the bound is checked explicitly instead.
    # NOTE(review): assumes the intent was "the amplified peak stays within
    # full scale" -- confirm against the sinewave asset.
    self.assertLessEqual(waveform_gain.abs().max().item(), 1.)

    E = torchaudio.sox_effects.SoxEffectsChain()
    E.set_input_file(self.test_filepath_sinewave)
    E.append_effect_to_chain("gain", [5])
    sox_gain_waveform = E.sox_build_flow_effects()[0]

    self.assertTrue(torch.allclose(waveform_gain, sox_gain_waveform))
383+
384+
def test_scale_to_interval(self):
    # Every sample of the scaled sine wave must lie inside [-bound, bound].
    bound = 5.5
    result = F.scale_to_interval(self.waveform_sinewave, bound)

    self.assertTrue(result.max() <= bound)
    self.assertTrue(result.min() >= -bound)
390+
391+
def test_dither(self):
    # Compare F.dither, with and without noise shaping, against the output
    # of SoX's "dither" effect on the same sine-wave asset.
    tolerance = dict(rtol=1e-03, atol=1e-03)

    plain = F.dither(self.waveform_sinewave)
    shaped = F.dither(self.waveform_sinewave, noise_shaping=True)

    chain = torchaudio.sox_effects.SoxEffectsChain()
    chain.set_input_file(self.test_filepath_sinewave)

    # Plain dither.
    chain.append_effect_to_chain("dither", [])
    sox_plain = chain.sox_build_flow_effects()[0]
    self.assertTrue(torch.allclose(plain, sox_plain, **tolerance))

    # Same chain object, reset and re-run with SoX's "-s" option.
    chain.clear_chain()
    chain.append_effect_to_chain("dither", ["-s"])
    sox_shaped = chain.sox_build_flow_effects()[0]
    self.assertTrue(torch.allclose(shaped, sox_shaped, **tolerance))
407+
408+
def test_vctk_transform_pipeline(self):
    # The resample + noise-shaped dither pipeline built from torchaudio ops
    # is compared with the SoX effects chain applied to the same VCTK file.
    vctk_path = os.path.join(self.test_dirpath, "assets/VCTK-Corpus/wav48/p224/", "p224_002.wav")
    waveform, source_rate = torchaudio.load(vctk_path)

    # rate
    resample = T.Resample(source_rate, 16000, resampling_method='sinc_interpolation')
    waveform = resample(waveform)
    # dither
    waveform = F.dither(waveform, noise_shaping=True)

    chain = torchaudio.sox_effects.SoxEffectsChain()
    chain.set_input_file(vctk_path)
    sox_effects = (
        ("gain", ["-h"]),
        ("channels", [1]),
        ("rate", [16000]),
        ("gain", ["-rh"]),
        ("dither", ["-s"]),
    )
    for effect_name, effect_args in sox_effects:
        chain.append_effect_to_chain(effect_name, effect_args)
    sox_waveform = chain.sox_build_flow_effects()[0]

    self.assertTrue(torch.allclose(waveform, sox_waveform, rtol=1e-03, atol=1e-03))
428+
429+
def test_pitch(self):
370430
test_dirpath, test_dir = common_utils.create_temp_assets_dir()
371431
test_filepath_100 = os.path.join(test_dirpath, 'assets', "100Hz_44100Hz_16bit_05sec.wav")
372432
test_filepath_440 = os.path.join(test_dirpath, 'assets', "440Hz_44100Hz_16bit_05sec.wav")
@@ -518,6 +578,25 @@ def test_mask_along_axis_iid(self):
518578

519579
_test_torchscript_functional(F.mask_along_axis_iid, specgrams, mask_param, mask_value, axis)
520580

581+
def test_torchscript_gain(self):
    # Run F.gain through the shared TorchScript helper on a random
    # single-channel waveform with a 2 dB gain.
    waveform = torch.rand((1, 1000))
    _test_torchscript_functional(F.gain, waveform, 2.0)
586+
587+
def test_torchscript_scale_to_interval(self):
    # Run F.scale_to_interval (not F.gain) through the shared TorchScript
    # helper, so this test covers the function it is named after.
    tensor = torch.rand((1, 1000))
    scaled = 3.5

    _test_torchscript_functional(F.scale_to_interval, tensor, scaled)
592+
593+
def test_torchscript_dither(self):
    # Run F.dither through the shared TorchScript helper: once with its
    # default density function, then once per explicitly named alternative.
    waveform = torch.rand((1, 1000))

    _test_torchscript_functional(F.dither, waveform)
    for density in ("RPDF", "GPDF"):
        _test_torchscript_functional(F.dither, waveform, density)
599+
521600

522601
@pytest.mark.parametrize('complex_tensor', [
523602
torch.randn(1, 2, 1025, 400, 2),

torchaudio/datasets/vctk.py

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,18 +21,16 @@ def load_vctk_item(
2121

2222
# Read wav
2323
file_audio = os.path.join(path, folder_audio, speaker_id, fileid + ext_audio)
24+
waveform, sample_rate = torchaudio.load(file_audio)
2425
if downsample:
25-
# Legacy
26-
E = torchaudio.sox_effects.SoxEffectsChain()
27-
E.set_input_file(file_audio)
28-
E.append_effect_to_chain("gain", ["-h"])
29-
E.append_effect_to_chain("channels", [1])
30-
E.append_effect_to_chain("rate", [16000])
31-
E.append_effect_to_chain("gain", ["-rh"])
32-
E.append_effect_to_chain("dither", ["-s"])
33-
waveform, sample_rate = E.sox_build_flow_effects()
34-
else:
35-
waveform, sample_rate = torchaudio.load(file_audio)
26+
# TODO Remove this parameter after deprecation
27+
F = torchaudio.functional
28+
T = torchaudio.transforms
29+
# rate
30+
sample = T.Resample(sample_rate, 16000, resampling_method='sinc_interpolation')
31+
waveform = sample(waveform)
32+
# dither
33+
waveform = F.dither(waveform, noise_shaping=True)
3634

3735
return waveform, sample_rate, utterance, speaker_id, utterance_id
3836

torchaudio/functional.py

Lines changed: 156 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -858,6 +858,162 @@ def compute_deltas(specgram, win_length=5, mode="replicate"):
858858
return output
859859

860860

861+
def gain(waveform, gain_db=1.0):
    # type: (Tensor, float) -> Tensor
    r"""Amplify or attenuate an entire waveform by a gain given in decibels.

    Args:
        waveform (torch.Tensor): Tensor of audio of dimension (channel, time).
        gain_db (float): Gain adjustment in decibels (dB) (Default: `1.0`).

    Returns:
        torch.Tensor: the waveform multiplied by ``10 ** (gain_db / 20)``.
        When ``gain_db`` is zero the input tensor itself is returned,
        not a copy.
    """
    if gain_db != 0:
        # Convert decibels to a linear amplitude factor and apply it.
        return waveform * (10 ** (gain_db / 20))

    return waveform
878+
879+
880+
def scale_to_interval(waveform, interval_max=1.0):
    # type: (Tensor, float) -> Tensor
    r"""Scale the waveform to the interval [-interval_max, interval_max] across all dimensions.

    The largest absolute sample over the whole tensor is mapped onto
    ``interval_max`` and every other sample is scaled by the same factor.
    The input tensor is never modified; a new tensor is returned.

    Args:
        waveform (torch.Tensor): Tensor of audio of dimension (channel, time).
        interval_max (float): The bounds of the interval, where the float indicates
            the upper bound and the negative of the float indicates the lower
            bound (Default: `1.0`).
            Example: interval=1.0 -> [-1.0, 1.0]

    Returns:
        torch.Tensor: the whole waveform scaled to interval. An empty or
        all-zero waveform is returned as an unscaled copy.
    """
    if waveform.numel() == 0:
        # torch.max has nothing to reduce over; there is nothing to scale.
        return waveform.clone()

    abs_max = torch.max(torch.abs(waveform))
    if abs_max == 0:
        # A silent signal has no finite scale factor; dividing by a zero
        # ratio would turn every sample into NaN.
        return waveform.clone()

    ratio = abs_max / interval_max

    # Out-of-place division, so the caller's tensor is left untouched.
    return waveform / ratio
899+
900+
901+
def _add_noise_shaping(dithered_waveform, waveform):
902+
r"""Noise shaping is calculated by error:
903+
error[n] = dithered[n] - original[n]
904+
noise_shaped_waveform[n] = dithered[n] + error[n-1]
905+
"""
906+
wf_shape = waveform.size()
907+
waveform = waveform.reshape(-1, wf_shape[-1])
908+
909+
dithered_shape = dithered_waveform.size()
910+
dithered_waveform = dithered_waveform.reshape(-1, dithered_shape[-1])
911+
912+
error = dithered_waveform - waveform
913+
914+
# add error[n-1] to dithered_waveform[n], so offset the error by 1 index
915+
for index in range(error.size()[0]):
916+
err = error[index]
917+
error_offset = torch.cat((torch.zeros(1), err))
918+
error[index] = error_offset[:waveform.size()[1]]
919+
920+
noise_shaped = dithered_waveform + error
921+
return noise_shaped.reshape(dithered_shape[:-1] + noise_shaped.shape[-1:])
922+
923+
924+
def probability_distribution(waveform, density_function="TPDF"):
    # type: (Tensor, str) -> Tensor
    r"""Apply a probability distribution function on a waveform.

    Triangular probability density function (TPDF) dither noise has a
    triangular distribution; values in the center of the range have a higher
    probability of occurring.

    Rectangular probability density function (RPDF) dither noise has a
    uniform distribution; any value in the specified range has the same
    probability of occurring.

    Gaussian probability density function (GPDF) has a normal distribution.
    The relationship of probabilities of results follows a bell-shaped,
    or Gaussian curve, typical of dither generated by analog sources.

    Implementation notes: the waveform is scaled by ``2 ** 15 - 2``, an
    offset is added, the sum is rounded and divided by ``2 ** 15``. The
    offset is a Bartlett window over the time axis for `TPDF`, one randomly
    picked sample minus 0.5 for `RPDF`, and the sum of seven randomly picked
    samples for `GPDF`. The input tensor is never modified.

    Args:
        waveform (torch.Tensor): Tensor of audio of dimension (channel, time)
        density_function (string): The density function of a
            continuous random variable (Default: `TPDF`)
            Options: Triangular Probability Density Function - `TPDF`
                     Rectangular Probability Density Function - `RPDF`
                     Gaussian Probability Density Function - `GPDF`
            Any other value is treated as `TPDF`.

    Returns:
        torch.Tensor: waveform dithered with the selected density function,
        in the same shape as the input. An empty waveform is returned as a copy.
    """
    if waveform.numel() == 0:
        # Nothing to quantise, and no sample could be drawn below.
        return waveform.clone()

    shape = waveform.size()
    waveform = waveform.reshape(-1, shape[-1])

    # Full sizes: torch.randint's upper bound is exclusive, so these cover
    # every valid index and are at least 1 for a non-empty waveform.
    channel_size = waveform.size()[0]
    time_size = waveform.size()[-1]

    random_channel = int(torch.randint(channel_size, [1, ]).item())
    random_time = int(torch.randint(time_size, [1, ]).item())

    number_of_bits = 16
    up_scaling = 2 ** (number_of_bits - 1) - 2
    signal_scaled = waveform * up_scaling
    down_scaling = 2 ** (number_of_bits - 1)

    if density_function == "RPDF":
        RPDF = waveform[random_channel][random_time] - 0.5

        signal_scaled_dis = signal_scaled + RPDF
    elif density_function == "GPDF":
        # TODO Replace by distribution code once
        # https://github.com/pytorch/pytorch/issues/29843 is resolved
        # gaussian = torch.distributions.normal.Normal(torch.mean(waveform, -1), 1).sample()

        num_rand_variables = 6

        gaussian = waveform[random_channel][random_time]
        for _ in range(num_rand_variables):
            rand_chan = int(torch.randint(channel_size, [1, ]).item())
            rand_time = int(torch.randint(time_size, [1, ]).item())
            # Out-of-place add: `gaussian` starts as a view into `waveform`,
            # so `+=` would write into the caller's tensor.
            gaussian = gaussian + waveform[rand_chan][rand_time]

        signal_scaled_dis = signal_scaled + gaussian
    else:
        # Build the window on the waveform's device and, for floating point
        # input, in its dtype, so the addition needs no cross-device or
        # mixed-precision handling.
        TPDF = torch.bartlett_window(time_size, device=waveform.device)
        if waveform.is_floating_point():
            TPDF = TPDF.to(waveform.dtype)
        TPDF = TPDF.repeat(channel_size, 1)
        signal_scaled_dis = signal_scaled + TPDF

    quantised_signal_scaled = torch.round(signal_scaled_dis)
    quantised_signal = quantised_signal_scaled / down_scaling
    return quantised_signal.reshape(shape[:-1] + quantised_signal.shape[-1:])
989+
990+
991+
def dither(waveform, density_function="TPDF", noise_shaping=False):
    # type: (Tensor, str, bool) -> Tensor
    r"""Dither a waveform, optionally followed by noise shaping.

    Dither increases the perceived dynamic range of audio stored at a
    particular bit-depth by eliminating nonlinear truncation distortion
    (i.e. adding minimally perceived noise to mask distortion caused by quantization).

    Args:
        waveform (torch.Tensor): Tensor of audio of dimension (channel, time)
        density_function (string): The density function of a
            continuous random variable (Default: `TPDF`)
            Options: Triangular Probability Density Function - `TPDF`
                     Rectangular Probability Density Function - `RPDF`
                     Gaussian Probability Density Function - `GPDF`
        noise_shaping (boolean): a filtering process that shapes the spectral
            energy of quantisation error (Default: `False`)

    Returns:
        torch.Tensor: waveform dithered
    """
    quantised = probability_distribution(waveform, density_function=density_function)

    if not noise_shaping:
        return quantised

    # Noise shaping needs the undithered signal to compute the error.
    return _add_noise_shaping(quantised, waveform)
1015+
1016+
8611017
def _compute_nccf(waveform, sample_rate, frame_time, freq_low):
8621018
# type: (Tensor, int, float, int) -> Tensor
8631019
r"""

0 commit comments

Comments
 (0)