Skip to content

Commit 96f348e

Browse files
committed
Add frequency_impulse_response
1 parent d8a5a11 commit 96f348e

File tree

7 files changed

+55
-2
lines changed

7 files changed

+55
-2
lines changed

docs/source/prototype.functional.rst

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -50,3 +50,4 @@ DSP
5050
extend_pitch
5151
oscillator_bank
5252
sinc_impulse_response
53+
frequency_impulse_response

test/torchaudio_unittest/prototype/functional/autograd_test_impl.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -87,3 +87,7 @@ def test_deemphasis(self):
8787
coeff = 0.9
8888
self.assertTrue(gradcheck(F.deemphasis, (waveform, coeff)))
8989
self.assertTrue(gradgradcheck(F.deemphasis, (waveform, coeff)))
90+
91+
def test_freq_ir(self):
    """frequency_impulse_response passes the first-order gradient check"""
    mags = torch.tensor([0, 0.5, 1.0], device=self.device, dtype=self.dtype, requires_grad=True)
    # Use the unittest assertion, like the neighbouring gradcheck tests: a bare
    # `assert` is stripped when Python runs with -O and would silently pass.
    self.assertTrue(gradcheck(F.frequency_impulse_response, (mags,)))

test/torchaudio_unittest/prototype/functional/dsp_utils.py

Lines changed: 6 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -35,3 +35,9 @@ def sinc_ir(cutoff: ArrayLike, window_size: int = 513, high_pass: bool = False):
3535
filt *= -1
3636
filt[..., half] = 1.0 + filt[..., half]
3737
return filt
38+
39+
40+
def freq_ir(magnitudes):
    """NumPy reference for building an impulse response from a frequency response.

    The magnitudes are interpreted as a one-sided (real FFT) spectrum. They are
    inverse-transformed, rotated with ``fftshift`` along the last axis, and
    multiplied by a Hann window of the same length.

    Args:
        magnitudes: Array-like of desired frequency responses.
            Shape: ``(..., num_fft_bins)``.

    Returns:
        numpy.ndarray: Impulse response whose last axis has the length produced
        by ``np.fft.irfft`` (``2 * (num_fft_bins - 1)`` by default). Floating
        point inputs keep their dtype; any other input dtype yields the
        floating dtype computed by the FFT.
    """
    # Accept lists/tuples as well as arrays; `.dtype` below needs an ndarray.
    magnitudes = np.asarray(magnitudes)
    ir = np.fft.fftshift(np.fft.irfft(magnitudes), axes=-1)
    window = np.hanning(ir.shape[-1])
    # Casting back to a non-floating input dtype would truncate the taps to
    # integers, so only floating dtypes are restored.
    if np.issubdtype(magnitudes.dtype, np.floating):
        out_dtype = magnitudes.dtype
    else:
        out_dtype = ir.dtype
    return (ir * window).astype(out_dtype)

test/torchaudio_unittest/prototype/functional/functional_test_impl.py

Lines changed: 17 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,7 +8,7 @@
88
from torchaudio.functional import lfilter
99
from torchaudio_unittest.common_utils import nested_params, TestBaseMixin
1010

11-
from .dsp_utils import oscillator_bank as oscillator_bank_np, sinc_ir as sinc_ir_np
11+
from .dsp_utils import freq_ir as freq_ir_np, oscillator_bank as oscillator_bank_np, sinc_ir as sinc_ir_np
1212

1313

1414
def _prod(l):
@@ -470,6 +470,22 @@ def test_preemphasis_deemphasis_roundtrip(self, input_shape, coeff):
470470
deemphasized = F.deemphasis(preemphasized, coeff=coeff)
471471
self.assertEqual(deemphasized, waveform)
472472

473+
def test_freq_ir_warns_negative_values(self):
    """frequency_impulse_response warns when the input contains negative values"""
    magnitudes = -torch.ones((1, 30), device=self.device, dtype=self.dtype)
    # Raw string with an escaped final period: an unescaped "." would match any
    # character, so the pattern would not pin the literal end of the message.
    with self.assertWarnsRegex(UserWarning, r"^.+should not contain negative values\.$"):
        F.frequency_impulse_response(magnitudes)
478+
479+
@parameterized.expand([((2, 3, 4),), ((1000,),)])
def test_freq_ir_reference(self, shape):
    """frequency_impulse_response agrees with the NumPy reference ``freq_ir_np``"""
    mags = torch.rand(shape, device=self.device, dtype=self.dtype)

    # The reference is evaluated on a CPU NumPy copy of the same random input.
    expected = freq_ir_np(mags.cpu().numpy())
    actual = F.frequency_impulse_response(mags)

    self.assertEqual(actual, expected)
488+
473489

474490
class Functional64OnlyTestImpl(TestBaseMixin):
475491
@nested_params(

test/torchaudio_unittest/prototype/functional/torchscript_consistency_test_impl.py

Lines changed: 4 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -98,3 +98,7 @@ def test_deemphasis(self):
9898
waveform = torch.rand(3, 2, 100, device=self.device, dtype=self.dtype)
9999
coeff = 0.9
100100
self._assert_consistency(F.deemphasis, (waveform, coeff))
101+
102+
def test_freq_ir(self):
    """Run frequency_impulse_response through the shared ``_assert_consistency`` helper"""
    magnitudes = torch.tensor([0, 0.5, 1.0], device=self.device, dtype=self.dtype)
    args = (magnitudes,)
    self._assert_consistency(F.frequency_impulse_response, args)

torchaudio/prototype/functional/__init__.py

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
1-
from ._dsp import adsr_envelope, extend_pitch, oscillator_bank, sinc_impulse_response
1+
from ._dsp import adsr_envelope, extend_pitch, frequency_impulse_response, oscillator_bank, sinc_impulse_response
22
from .functional import add_noise, barkscale_fbanks, convolve, deemphasis, fftconvolve, preemphasis, speed
33

4+
45
__all__ = [
56
"add_noise",
67
"adsr_envelope",
@@ -9,6 +10,7 @@
910
"deemphasis",
1011
"extend_pitch",
1112
"fftconvolve",
13+
"frequency_impulse_response",
1214
"oscillator_bank",
1315
"preemphasis",
1416
"sinc_impulse_response",

torchaudio/prototype/functional/_dsp.py

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -286,3 +286,23 @@ def sinc_impulse_response(cutoff: torch.Tensor, window_size: int = 513, high_pas
286286
filt = -filt
287287
filt[..., half] = 1.0 + filt[..., half]
288288
return filt
289+
290+
291+
def frequency_impulse_response(magnitudes: torch.Tensor) -> torch.Tensor:
    """Create filter from desired frequency response

    The magnitudes are treated as a one-sided frequency response: they are
    converted to the time domain with an inverse real FFT, rotated with
    ``fftshift`` along the last dimension, and multiplied by a symmetric
    (non-periodic) Hann window of the same length.

    Args:
        magnitudes (Tensor): The desired frequency responses. Shape: `(..., num_fft_bins)`

    Returns:
        Tensor: Impulse response. Shape `(..., 2 * (num_fft_bins - 1))`
    """
    if magnitudes.min() < 0.0:
        # Negative magnitudes are not meaningful, but they only trigger a warning
        # (not an error) so that autograd keeps working around 0.
        # TODO: decide whether this should raise instead.
        warnings.warn("The input frequency response should not contain negative values.")
    ir = torch.fft.fftshift(torch.fft.irfft(magnitudes), dim=-1)
    # Take device/dtype from `ir`, not `magnitudes`: irfft promotes integer input
    # to floating point, and `hann_window` only supports floating dtypes.
    # Broadcasting applies the 1-D window across all leading dimensions.
    window = torch.hann_window(ir.size(-1), periodic=False, device=ir.device, dtype=ir.dtype)
    return ir * window

0 commit comments

Comments
 (0)