Skip to content

Add sinc_impulse_response op #2875

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/prototype.functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ DSP
adsr_envelope
extend_pitch
oscillator_bank
sinc_impulse_response
Original file line number Diff line number Diff line change
Expand Up @@ -62,3 +62,8 @@ def test_extend_pitch(self):

assert gradcheck(F.extend_pitch, (input, num_pitches))
assert gradcheck(F.extend_pitch, (input, pattern))

def test_sinc_ir(self):
cutoff = torch.tensor([0, 0.5, 1.0], device=self.device, dtype=self.dtype, requires_grad=True)
assert gradcheck(F.sinc_impulse_response, (cutoff, 513, False))
assert gradcheck(F.sinc_impulse_response, (cutoff, 513, True))
17 changes: 17 additions & 0 deletions test/torchaudio_unittest/prototype/functional/dsp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,20 @@ def oscillator_bank(

waveform = amplitudes * np.sin(phases)
return waveform


def sinc_ir(cutoff: ArrayLike, window_size: int = 513, high_pass: bool = False):
if window_size % 2 == 0:
raise ValueError(f"`window_size` must be odd. Given: {window_size}")
half = window_size // 2
dtype = cutoff.dtype
idx = np.linspace(-half, half, window_size, dtype=dtype)

filt = np.sinc(cutoff[..., None] * idx[None, ...])
filt *= np.hamming(window_size).astype(dtype)[None, ...]
filt /= np.abs(filt.sum(axis=-1, keepdims=True))

if high_pass:
filt *= -1
filt[..., half] = 1.0 + filt[..., half]
return filt
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,14 @@
from scipy import signal
from torchaudio_unittest.common_utils import nested_params, TestBaseMixin

from .dsp_utils import oscillator_bank as oscillator_bank_np
from .dsp_utils import oscillator_bank as oscillator_bank_np, sinc_ir as sinc_ir_np


def _prod(l):
r = 1
for p in l:
r *= p
return r


class FunctionalTestImpl(TestBaseMixin):
Expand Down Expand Up @@ -311,6 +318,49 @@ def test_extend_pitch(self):
output = F.extend_pitch(input, pat)
self.assertEqual(output, expected)

@nested_params(
# fmt: off
[(1,), (10,), (2, 5), (3, 5, 7)],
[1, 3, 65, 129, 257, 513, 1025],
[True, False],
# fmt: on
)
def test_sinc_ir_shape(self, input_shape, window_size, high_pass):
"""The shape of sinc_impulse_response is correct"""
numel = _prod(input_shape)
cutoff = torch.linspace(1, numel, numel).reshape(input_shape)
cutoff = cutoff.to(dtype=self.dtype, device=self.device)

filt = F.sinc_impulse_response(cutoff, window_size, high_pass)
assert filt.shape == input_shape + (window_size,)

@nested_params([True, False])
def test_sinc_ir_size(self, high_pass):
"""Increasing window size expand the filter at the ends. Core parts must stay same"""
cutoff = torch.tensor([200, 300, 400, 500, 600, 700])
cutoff = cutoff.to(dtype=self.dtype, device=self.device)

filt_5 = F.sinc_impulse_response(cutoff, 5, high_pass)
filt_3 = F.sinc_impulse_response(cutoff, 3, high_pass)

self.assertEqual(filt_3, filt_5[..., 1:-1])

@nested_params(
# fmt: off
[0, 0.1, 0.5, 0.9, 1.0],
[1, 3, 5, 65, 129, 257, 513, 1025, 2049],
[False, True],
# fmt: on
)
def test_sinc_ir_reference(self, cutoff, window_size, high_pass):
"""sinc_impulse_response produces the same result as reference implementation"""
cutoff = torch.tensor([cutoff], device=self.device, dtype=self.dtype)

hyp = F.sinc_impulse_response(cutoff, window_size, high_pass)
ref = sinc_ir_np(cutoff.cpu().numpy(), window_size, high_pass)

self.assertEqual(hyp, ref)


class Functional64OnlyTestImpl(TestBaseMixin):
@nested_params(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,3 +76,8 @@ def test_extend_pitch(self):
self._assert_consistency(F.extend_pitch, (input, num_pitches))
self._assert_consistency(F.extend_pitch, (input, pattern))
self._assert_consistency(F.extend_pitch, (input, torch.tensor(pattern)))

def test_sinc_ir(self):
cutoff = torch.tensor([0, 0.5, 1.0], device=self.device, dtype=self.dtype)
self._assert_consistency(F.sinc_impulse_response, (cutoff, 513, False))
self._assert_consistency(F.sinc_impulse_response, (cutoff, 513, True))
3 changes: 2 additions & 1 deletion torchaudio/prototype/functional/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from ._dsp import adsr_envelope, extend_pitch, oscillator_bank
from ._dsp import adsr_envelope, extend_pitch, oscillator_bank, sinc_impulse_response
from .functional import add_noise, barkscale_fbanks, convolve, fftconvolve

__all__ = [
Expand All @@ -9,4 +9,5 @@
"extend_pitch",
"fftconvolve",
"oscillator_bank",
"sinc_impulse_response",
]
39 changes: 39 additions & 0 deletions torchaudio/prototype/functional/_dsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,3 +247,42 @@ def extend_pitch(
mult = torch.tensor(pattern, dtype=base.dtype, device=base.device)
h_freq = base @ mult.unsqueeze(0)
return h_freq


def sinc_impulse_response(cutoff: torch.Tensor, window_size: int = 513, high_pass: bool = False):
"""Create windowed-sinc impulse response for given cutoff frequencies.

.. devices:: CPU CUDA

.. properties:: Autograd TorchScript

Args:
cutoff (Tensor): Cutoff frequencies for low-pass sinc filter.

window_size (int, optional): Size of the Hamming window to apply. Must be odd.
(Default: 513)

high_pass (bool, optional):
If ``True``, convert the resulting filter to high-pass.
Otherwise low-pass filter is returned. Default: ``False``.

Returns:
Tensor: A series of impulse responses. Shape: `(..., window_size)`.
"""
if window_size % 2 == 0:
raise ValueError(f"`window_size` must be odd. Given: {window_size}")

half = window_size // 2
device, dtype = cutoff.device, cutoff.dtype
idx = torch.linspace(-half, half, window_size, device=device, dtype=dtype)

filt = torch.special.sinc(cutoff.unsqueeze(-1) * idx.unsqueeze(0))
filt = filt * torch.hamming_window(window_size, device=device, dtype=dtype, periodic=False).unsqueeze(0)
filt = filt / filt.sum(dim=-1, keepdim=True).abs()

# High pass IR is obtained by subtracting low_pass IR from delta function.
# https://courses.engr.illinois.edu/ece401/fa2020/slides/lec10.pdf
if high_pass:
filt = -filt
filt[..., half] = 1.0 + filt[..., half]
return filt