Skip to content

Commit 401e7ae

Browse files
authored
Compute deltas (#268)
* compute deltas. * multichannel, and random test. * documentation. * feedback. changing name of window to win_length. * passing padding mode.
1 parent 8273c3f commit 401e7ae

File tree

5 files changed

+136
-1
lines changed

5 files changed

+136
-1
lines changed

test/test_compliance_kaldi.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,5 +319,6 @@ def test_resample_waveform_multi_channel(self):
319319
single_channel_sampled = kaldi.resample_waveform(single_channel, sample_rate, sample_rate // 2)
320320
self.assertTrue(torch.allclose(multi_sound_sampled[i, :], single_channel_sampled, rtol=1e-4))
321321

322+
322323
if __name__ == '__main__':
323324
unittest.main()

test/test_functional.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,32 @@
1818
class TestFunctional(unittest.TestCase):
1919
data_sizes = [(2, 20), (3, 15), (4, 10)]
2020
number_of_trials = 100
21+
specgram = torch.tensor([1., 2., 3., 4.])
22+
23+
def _test_compute_deltas(self, specgram, expected, win_length=3, atol=1e-6, rtol=1e-8):
24+
computed = F.compute_deltas(specgram, win_length=win_length)
25+
self.assertTrue(computed.shape == expected.shape, (computed.shape, expected.shape))
26+
torch.testing.assert_allclose(computed, expected, atol=atol, rtol=rtol)
27+
28+
def test_compute_deltas_onechannel(self):
29+
specgram = self.specgram.unsqueeze(0).unsqueeze(0)
30+
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5]]])
31+
self._test_compute_deltas(specgram, expected)
32+
33+
def test_compute_deltas_twochannel(self):
34+
specgram = self.specgram.repeat(1, 2, 1)
35+
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5],
36+
[0.5, 1.0, 1.0, 0.5]]])
37+
self._test_compute_deltas(specgram, expected)
38+
39+
def test_compute_deltas_randn(self):
40+
channel = 13
41+
n_mfcc = channel * 3
42+
time = 1021
43+
win_length = 2 * 7 + 1
44+
specgram = torch.randn(channel, n_mfcc, time)
45+
computed = F.compute_deltas(specgram, win_length=win_length)
46+
self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))
2147

2248
def _compare_estimate(self, sound, estimate, atol=1e-6, rtol=1e-8):
2349
# trim sound for case when constructed signal is shorter than original

test/test_transforms.py

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,9 @@
44

55
import torch
66
import torchaudio
7-
from torchaudio.common_utils import IMPORT_LIBROSA, IMPORT_SCIPY
87
import torchaudio.transforms as transforms
8+
import torchaudio.functional as F
9+
from torchaudio.common_utils import IMPORT_LIBROSA, IMPORT_SCIPY
910
import unittest
1011
import common_utils
1112

@@ -281,5 +282,37 @@ def test_resample_size(self):
281282
# we expect the downsampled signal to have half as many samples
282283
self.assertTrue(down_sampled.size(-1) == waveform.size(-1) // 2)
283284

285+
def test_compute_deltas(self):
286+
channel = 13
287+
n_mfcc = channel * 3
288+
time = 1021
289+
win_length = 2 * 7 + 1
290+
specgram = torch.randn(channel, n_mfcc, time)
291+
transform = transforms.ComputeDeltas(win_length=win_length)
292+
computed = transform(specgram)
293+
self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))
294+
295+
def test_compute_deltas_transform_same_as_functional(self, atol=1e-6, rtol=1e-8):
296+
channel = 13
297+
n_mfcc = channel * 3
298+
time = 1021
299+
win_length = 2 * 7 + 1
300+
specgram = torch.randn(channel, n_mfcc, time)
301+
302+
transform = transforms.ComputeDeltas(win_length=win_length)
303+
computed_transform = transform(specgram)
304+
305+
computed_functional = F.compute_deltas(specgram, win_length=win_length)
306+
torch.testing.assert_allclose(computed_functional, computed_transform, atol=atol, rtol=rtol)
307+
308+
def test_compute_deltas_twochannel(self):
309+
specgram = torch.tensor([1., 2., 3., 4.]).repeat(1, 2, 1)
310+
expected = torch.tensor([[[0.5, 1.0, 1.0, 0.5],
311+
[0.5, 1.0, 1.0, 0.5]]])
312+
transform = transforms.ComputeDeltas()
313+
computed = transform(specgram)
314+
self.assertTrue(computed.shape == specgram.shape, (computed.shape, specgram.shape))
315+
316+
284317
if __name__ == '__main__':
285318
unittest.main()

torchaudio/functional.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
"biquad",
2121
]
2222

23+
2324
# TODO: remove this once https://github.com/pytorch/pytorch/issues/21478 gets solved
2425
@torch.jit.ignore
2526
def _stft(
@@ -652,3 +653,50 @@ def lowpass_biquad(waveform, sample_rate, cutoff_freq, Q=0.707):
652653
a1 = -2 * math.cos(w0)
653654
a2 = 1 - alpha
654655
return biquad(waveform, b0, b1, b2, a0, a1, a2)
656+
657+
658+
def compute_deltas(specgram, win_length=5, mode="replicate"):
659+
# type: (Tensor, int, str) -> Tensor
660+
r"""Compute delta coefficients of a tensor, usually a spectrogram:
661+
662+
.. math::
663+
d_t = \frac{\sum_{n=1}^{\text{N}} n (c_{t+n} - c_{t-n})}{2 \sum_{n=1}^{\text{N} n^2}
664+
665+
where :math:`d_t` is the deltas at time :math:`t`,
666+
:math:`c_t` is the spectrogram coeffcients at time :math:`t`,
667+
:math:`N` is (`win_length`-1)//2.
668+
669+
Args:
670+
specgram (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time)
671+
win_length (int): The window length used for computing delta
672+
mode (str): Mode parameter passed to padding
673+
674+
Returns:
675+
deltas (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time)
676+
677+
Example
678+
>>> specgram = torch.randn(1, 40, 1000)
679+
>>> delta = compute_deltas(specgram)
680+
>>> delta2 = compute_deltas(delta)
681+
"""
682+
683+
assert win_length >= 3
684+
assert specgram.dim() == 3
685+
assert not specgram.shape[1] % specgram.shape[0]
686+
687+
n = (win_length - 1) // 2
688+
689+
# twice sum of integer squared
690+
denom = n * (n + 1) * (2 * n + 1) / 3
691+
692+
specgram = torch.nn.functional.pad(specgram, (n, n), mode=mode)
693+
694+
kernel = (
695+
torch
696+
.arange(-n, n + 1, 1, device=specgram.device, dtype=specgram.dtype)
697+
.repeat(specgram.shape[1], specgram.shape[0], 1)
698+
)
699+
700+
return torch.nn.functional.conv1d(
701+
specgram, kernel, groups=specgram.shape[1] // specgram.shape[0]
702+
) / denom

torchaudio/transforms.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,3 +365,30 @@ def forward(self, waveform):
365365
return kaldi.resample_waveform(waveform, self.orig_freq, self.new_freq)
366366

367367
raise ValueError('Invalid resampling method: %s' % (self.resampling_method))
368+
369+
370+
class ComputeDeltas(torch.jit.ScriptModule):
371+
r"""Compute delta coefficients of a tensor, usually a spectrogram.
372+
373+
See `torchaudio.functional.compute_deltas` for more details.
374+
375+
Args:
376+
win_length (int): The window length used for computing delta.
377+
"""
378+
__constants__ = ['win_length']
379+
380+
def __init__(self, win_length=5, mode="replicate"):
381+
super(ComputeDeltas, self).__init__()
382+
self.win_length = win_length
383+
self.mode = torch.jit.Attribute(mode, str)
384+
385+
@torch.jit.script_method
386+
def forward(self, specgram):
387+
r"""
388+
Args:
389+
specgram (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time)
390+
391+
Returns:
392+
deltas (torch.Tensor): Tensor of audio of dimension (channel, n_mfcc, time)
393+
"""
394+
return F.compute_deltas(specgram, win_length=self.win_length, mode=self.mode)

0 commit comments

Comments
 (0)