Skip to content

Add frequency_impulse_response #2879

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/prototype.functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -50,3 +50,4 @@ DSP
extend_pitch
oscillator_bank
sinc_impulse_response
frequency_impulse_response
Original file line number Diff line number Diff line change
Expand Up @@ -87,3 +87,7 @@ def test_deemphasis(self):
coeff = 0.9
self.assertTrue(gradcheck(F.deemphasis, (waveform, coeff)))
self.assertTrue(gradgradcheck(F.deemphasis, (waveform, coeff)))

def test_freq_ir(self):
    """Numerically verify the gradient of ``frequency_impulse_response``."""
    # The first bin is exactly 0, so the finite-difference perturbation used by
    # gradcheck pushes it slightly negative; the function tolerates that.
    mags = torch.tensor([0, 0.5, 1.0], device=self.device, dtype=self.dtype, requires_grad=True)
    # Use the TestCase assertion (as the neighbouring autograd tests do) instead of a
    # bare ``assert``, which is removed when Python runs with ``-O``.
    self.assertTrue(gradcheck(F.frequency_impulse_response, (mags,)))
6 changes: 6 additions & 0 deletions test/torchaudio_unittest/prototype/functional/dsp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,3 +35,9 @@ def sinc_ir(cutoff: ArrayLike, window_size: int = 513, high_pass: bool = False):
filt *= -1
filt[..., half] = 1.0 + filt[..., half]
return filt


def freq_ir(magnitudes):
    """NumPy reference: build a Hann-windowed impulse response from magnitudes.

    Args:
        magnitudes: ``numpy.ndarray`` of frequency-bin magnitudes; the last
            axis is the frequency axis.

    Returns:
        ``numpy.ndarray`` with the same dtype as ``magnitudes``.
    """
    # Inverse real FFT over the last axis, then rotate it with fftshift.
    impulse = np.fft.irfft(magnitudes)
    centered = np.fft.fftshift(impulse, axes=-1)
    # Taper with a Hann window matching the length of the last axis.
    tapered = centered * np.hanning(centered.shape[-1])
    # irfft yields double precision; cast back to the caller's dtype.
    return tapered.astype(magnitudes.dtype)
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from torchaudio.functional import lfilter
from torchaudio_unittest.common_utils import nested_params, TestBaseMixin

from .dsp_utils import oscillator_bank as oscillator_bank_np, sinc_ir as sinc_ir_np
from .dsp_utils import freq_ir as freq_ir_np, oscillator_bank as oscillator_bank_np, sinc_ir as sinc_ir_np


def _prod(l):
Expand Down Expand Up @@ -470,6 +470,22 @@ def test_preemphasis_deemphasis_roundtrip(self, input_shape, coeff):
deemphasized = F.deemphasis(preemphasized, coeff=coeff)
self.assertEqual(deemphasized, waveform)

def test_freq_ir_warns_negative_values(self):
    """A ``UserWarning`` is emitted when the magnitudes contain negative values."""
    # Every entry is -1, so the negative-value check is guaranteed to trigger.
    negative_response = torch.ones((1, 30), device=self.device, dtype=self.dtype).neg()
    expected_message = "^.+should not contain negative values.$"
    with self.assertWarnsRegex(UserWarning, expected_message):
        F.frequency_impulse_response(negative_response)

@parameterized.expand([((2, 3, 4),), ((1000,),)])
def test_freq_ir_reference(self, shape):
    """``frequency_impulse_response`` agrees with the NumPy reference implementation."""
    response = torch.rand(shape, device=self.device, dtype=self.dtype)

    # The reference runs on the host, so hand it a CPU NumPy copy of the same input.
    expected = freq_ir_np(response.cpu().numpy())
    actual = F.frequency_impulse_response(response)

    self.assertEqual(actual, expected)


class Functional64OnlyTestImpl(TestBaseMixin):
@nested_params(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,7 @@ def test_deemphasis(self):
waveform = torch.rand(3, 2, 100, device=self.device, dtype=self.dtype)
coeff = 0.9
self._assert_consistency(F.deemphasis, (waveform, coeff))

def test_freq_ir(self):
    """Run ``frequency_impulse_response`` through the shared consistency check."""
    magnitudes = torch.tensor([0, 0.5, 1.0], device=self.device, dtype=self.dtype)
    call_args = (magnitudes,)
    self._assert_consistency(F.frequency_impulse_response, call_args)
4 changes: 3 additions & 1 deletion torchaudio/prototype/functional/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from ._dsp import adsr_envelope, extend_pitch, oscillator_bank, sinc_impulse_response
from ._dsp import adsr_envelope, extend_pitch, frequency_impulse_response, oscillator_bank, sinc_impulse_response
from .functional import add_noise, barkscale_fbanks, convolve, deemphasis, fftconvolve, preemphasis, speed


__all__ = [
"add_noise",
"adsr_envelope",
Expand All @@ -9,6 +10,7 @@
"deemphasis",
"extend_pitch",
"fftconvolve",
"frequency_impulse_response",
"oscillator_bank",
"preemphasis",
"sinc_impulse_response",
Expand Down
20 changes: 20 additions & 0 deletions torchaudio/prototype/functional/_dsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,3 +286,23 @@ def sinc_impulse_response(cutoff: torch.Tensor, window_size: int = 513, high_pas
filt = -filt
filt[..., half] = 1.0 + filt[..., half]
return filt


def frequency_impulse_response(magnitudes):
"""Create a filter from a desired frequency response.

Takes the inverse real FFT of ``magnitudes``, applies ``fftshift`` along the
last dimension, and tapers the result with a symmetric (``periodic=False``)
Hann window.

If ``magnitudes`` contains a negative value, a warning is issued instead of
raising an error.

Args:
magnitudes: The desired frequency responses. Shape: `(..., num_fft_bins)`

Returns:
Tensor: Impulse response. Shape `(..., 2 * (num_fft_bins - 1))`
"""
if magnitudes.min() < 0.0:
# Negative magnitude does not make sense, but it is allowed (with a warning)
# so that autograd works around 0.
# Open question from the author: should this raise an error instead?
warnings.warn("The input frequency response should not contain negative values.")
# Inverse real FFT over the last dimension, then fftshift along that dimension.
ir = torch.fft.fftshift(torch.fft.irfft(magnitudes), dim=-1)
# Create the window on the input's device and dtype.
device, dtype = magnitudes.device, magnitudes.dtype
# Symmetric Hann window as long as the last dimension, expanded to the shape of ``ir``.
window = torch.hann_window(ir.size(-1), periodic=False, device=device, dtype=dtype).expand_as(ir)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

does it have to be a hann window specifically?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No, this is no different than the standard approach of windowing. Other windows can be used. Hann window is the standard one.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

By "standard approach", I mean smoothing out boundary effects.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Got it. In that case, perhaps we can consider adding an argument, similar to window_fn in Spectrogram, that allows customizing the window.

return ir * window