Skip to content

Add filter_waveform #2928

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 6 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/source/prototype.functional.rst
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ DSP
:nosignatures:

adsr_envelope
extend_pitch
filter_waveform
oscillator_bank
sinc_impulse_response
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,3 +91,8 @@ def test_deemphasis(self):
def test_freq_ir(self):
    """Gradients propagate through frequency_impulse_response (numerical gradcheck)."""
    magnitudes = torch.tensor(
        [0, 0.5, 1.0], device=self.device, dtype=self.dtype, requires_grad=True
    )
    assert gradcheck(F.frequency_impulse_response, (magnitudes,))

def test_filter_waveform(self):
    """Gradients propagate through filter_waveform (numerical gradcheck)."""
    tensor_opts = {"device": self.device, "dtype": self.dtype, "requires_grad": True}
    waveform = torch.rand(3, 1, 2, 10, **tensor_opts)
    filters = torch.rand(3, 2, **tensor_opts)
    assert gradcheck(F.filter_waveform, (waveform, filters))
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,110 @@ def test_freq_ir_reference(self, shape):

self.assertEqual(hyp, ref)

@parameterized.expand(
    [
        # fmt: off
        # INPUT: single-dim waveform and 2D filter
        # The number of frames is divisible by the number of filters (15 % 3 == 0),
        # thus waveform must be split into chunks without padding
        ((15, ), (3, 3)),  # filter size (3) is shorter than chunk size (15 // 3 == 5)
        ((15, ), (3, 5)),  # filter size (5) matches the chunk size
        ((15, ), (3, 7)),  # filter size (7) is longer than chunk size
        # INPUT: single-dim waveform and 2D filter
        # The number of frames is NOT divisible by the number of filters (15 % 4 != 0),
        # thus waveform must be padded before being split into chunks
        ((15, ), (4, 3)),  # filter size (3) is shorter than chunk size (16 // 4 == 4)
        ((15, ), (4, 4)),  # filter size (4) matches the chunk size
        ((15, ), (4, 5)),  # filter size (5) is longer than chunk size
        # INPUT: multi-dim waveform and 2D filter
        # The number of frames is divisible by the number of filters (15 % 3 == 0),
        # thus waveform must be split into chunks without padding
        ((7, 2, 15), (3, 3)),
        ((7, 2, 15), (3, 5)),
        ((7, 2, 15), (3, 7)),
        # INPUT: multi-dim waveform and 2D filter
        # The number of frames is NOT divisible by the number of filters (15 % 4 != 0),
        # thus waveform must be padded before being split into chunks
        ((7, 2, 15), (4, 3)),
        ((7, 2, 15), (4, 4)),
        ((7, 2, 15), (4, 5)),
        # INPUT: multi-dim waveform and multi-dim filter
        # The number of frames is divisible by the number of filters (15 % 3 == 0),
        # thus waveform must be split into chunks without padding
        ((7, 2, 15), (7, 2, 3, 3)),
        ((7, 2, 15), (7, 2, 3, 5)),
        ((7, 2, 15), (7, 2, 3, 7)),
        # INPUT: multi-dim waveform and multi-dim filter
        # The number of frames is NOT divisible by the number of filters (15 % 4 != 0),
        # thus waveform must be padded before being split into chunks
        ((7, 2, 15), (7, 2, 4, 3)),
        ((7, 2, 15), (7, 2, 4, 4)),
        ((7, 2, 15), (7, 2, 4, 5)),
        # INPUT: multi-dim waveform and (broadcast) multi-dim filter
        # The number of frames is divisible by the number of filters (15 % 3 == 0),
        # thus waveform must be split into chunks without padding
        ((7, 2, 15), (1, 1, 3, 3)),
        ((7, 2, 15), (1, 1, 3, 5)),
        ((7, 2, 15), (1, 1, 3, 7)),
        # INPUT: multi-dim waveform and (broadcast) multi-dim filter
        # The number of frames is NOT divisible by the number of filters (15 % 4 != 0),
        # thus waveform must be padded before being split into chunks
        ((7, 2, 15), (1, 1, 4, 3)),
        ((7, 2, 15), (1, 1, 4, 4)),
        ((7, 2, 15), (1, 1, 4, 5)),
        # fmt: on
    ]
)
def test_filter_waveform_shape(self, waveform_shape, filter_shape):
    """filter_waveform returns a waveform of the same shape as the input.

    Regardless of the internal padding / overlap-add / cropping, the output
    must have exactly as many samples as the input waveform.
    """
    waveform = torch.randn(waveform_shape, dtype=self.dtype, device=self.device)
    filters = torch.randn(filter_shape, dtype=self.dtype, device=self.device)

    filtered = F.filter_waveform(waveform, filters)

    assert filtered.shape == waveform.shape

@nested_params([1, 3, 5], [3, 5, 7, 4, 6, 8])
def test_filter_waveform_delta(self, num_filters, kernel_size):
    """Filtering with delta (unit impulse) kernels preserves the original waveform."""
    waveform = torch.arange(-10, 10, dtype=self.dtype, device=self.device)
    # A delta kernel (1 at the center tap, 0 elsewhere) is the identity of convolution.
    kernels = torch.zeros((num_filters, kernel_size), dtype=self.dtype, device=self.device)
    kernels[:, kernel_size // 2] = 1

    self.assertEqual(waveform, F.filter_waveform(waveform, kernels))

def test_filter_waveform_same(self, kernel_size=5):
    """Repeating one filter N times gives the same result as passing it once."""
    waveform = torch.arange(-10, 10, dtype=self.dtype, device=self.device)
    single = torch.randn((1, kernel_size), dtype=self.dtype, device=self.device)
    # Three copies of the identical kernel: every chunk gets the same filter.
    repeated = torch.cat([single, single, single])

    self.assertEqual(
        F.filter_waveform(waveform, single),
        F.filter_waveform(waveform, repeated),
    )

def test_filter_waveform_diff(self):
    """Filters are applied from the first to the last.

    The 20-sample waveform is split into two 10-sample chunks; kernels[0]
    must govern the first chunk and kernels[1] the second.
    """
    kernel_size = 3
    waveform = torch.arange(-10, 10, dtype=self.dtype, device=self.device)
    kernels = torch.randn((2, kernel_size), dtype=self.dtype, device=self.device)

    # use both filters.
    mix = F.filter_waveform(waveform, kernels)
    # use only one of them: each half filtered in isolation serves as reference.
    ref1 = F.filter_waveform(waveform[:10], kernels[0:1])
    ref2 = F.filter_waveform(waveform[10:], kernels[1:2])

    # NOTE: pytest captures stdout by default, so these prints only surface
    # when the test fails, where they help diagnose which region mismatched.
    print("mix:", mix)
    print("ref1:", ref1)
    print("ref2:", ref2)
Comment on lines +584 to +586
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

remove print statements

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why? It is useful to have the print statements when test fails.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I see, I thought it was accidentally leftover from debugging. Can we print only if it's failing to avoid extra print clutter and make it more clear the function that is being tested in the print? Otherwise if we start adding generic print/debug statements in tests it can lead to extra clutter.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We use pytest as runner, and it by default captures and hides the stdout/stderr. So it's shown only when told so or failed.

# The first filter is effective in the first half
self.assertEqual(mix[:10], ref1[:10])
# The second filter is effective in the second half
self.assertEqual(mix[-9:], ref2[-9:])
# The middle portion (around the chunk boundary) is where the two filters'
# responses overlap, so there is no single-filter reference to compare against.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this missing another assert for the middle portion?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no, the middle portion is mixture of two so we can't assert.



class Functional64OnlyTestImpl(TestBaseMixin):
@nested_params(
Expand Down
10 changes: 9 additions & 1 deletion torchaudio/prototype/functional/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,11 @@
from ._dsp import adsr_envelope, extend_pitch, frequency_impulse_response, oscillator_bank, sinc_impulse_response
from ._dsp import (
adsr_envelope,
extend_pitch,
filter_waveform,
frequency_impulse_response,
oscillator_bank,
sinc_impulse_response,
)
from .functional import add_noise, barkscale_fbanks, convolve, deemphasis, fftconvolve, preemphasis, speed


Expand All @@ -10,6 +17,7 @@
"deemphasis",
"extend_pitch",
"fftconvolve",
"filter_waveform",
"frequency_impulse_response",
"oscillator_bank",
"preemphasis",
Expand Down
98 changes: 98 additions & 0 deletions torchaudio/prototype/functional/_dsp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

import torch

from .functional import fftconvolve


def oscillator_bank(
frequencies: torch.Tensor,
Expand Down Expand Up @@ -306,3 +308,99 @@ def frequency_impulse_response(magnitudes):
device, dtype = magnitudes.device, magnitudes.dtype
window = torch.hann_window(ir.size(-1), periodic=False, device=device, dtype=dtype).expand_as(ir)
return ir * window


def _overlap_and_add(waveform, stride):
num_frames, frame_size = waveform.shape[-2:]
numel = (num_frames - 1) * stride + frame_size
buffer = torch.zeros(waveform.shape[:-2] + (numel,), device=waveform.device, dtype=waveform.dtype)
for i in range(num_frames):
start = i * stride
end = start + frame_size
buffer[..., start:end] += waveform[..., i, :]
return buffer


def filter_waveform(waveform: torch.Tensor, kernels: torch.Tensor, delay_compensation: int = -1):
    """Applies filters along time axis of the given waveform.

    This function applies the given filters along time axis in the following manner:

    1. Split the given waveform into chunks. The number of chunks is equal to the number of given filters.
    2. Filter each chunk with corresponding filter.
    3. Place the filtered chunks at the original indices while adding up the overlapping parts.
    4. Crop the resulting waveform so that delay introduced by the filter is removed and its length
       matches that of the input waveform.

    The following figure illustrates this.

    .. image:: https://download.pytorch.org/torchaudio/doc-assets/filter_waveform.png

    .. note::

       If the number of filters is one, then the operation becomes stationary.
       i.e. the same filtering is applied across the time axis.

    Args:
        waveform (Tensor): Shape `(..., time)`.
        kernels (Tensor): Impulse responses.
            Valid inputs are 2D tensor with shape `(num_filters, filter_length)` or
            `(N+1)`-D tensor with shape `(..., num_filters, filter_length)`, where `N` is
            the dimension of waveform.

            In case of 2D input, the same set of filters is used across channels and batches.
            Otherwise, different sets of filters are applied. In this case, the shape of
            the first `N-1` dimensions of filters must match (or be broadcastable to) that of waveform.

        delay_compensation (int): Control how the waveform is cropped after full convolution.
            If the value is zero or positive, it is interpreted as the length of crop at the
            beginning of the waveform. The value cannot be larger than the size of filter kernel.
            Otherwise the initial crop is ``filter_size // 2``.
            When cropping happens, the waveform is also cropped from the end so that the
            length of the resulting waveform matches the input waveform.

    Returns:
        Tensor: `(..., time)`.
    """
    if kernels.ndim not in [2, waveform.ndim + 1]:
        raise ValueError(
            "`kernels` must be 2 or N+1 dimension where "
            f"N is the dimension of waveform. Found: {kernels.ndim} (N={waveform.ndim})"
        )

    num_filters, filter_size = kernels.shape[-2:]
    num_frames = waveform.size(-1)

    if delay_compensation > filter_size:
        raise ValueError(
            "When `delay_compensation` is provided, it cannot be larger than the size of filters. "
            f"Found: delay_compensation={delay_compensation}, filter_size={filter_size}"
        )

    # Transform waveform's time axis into (num_filters x chunk_length) with optional padding.
    # Zero-pad at the end so that the time axis divides evenly into num_filters chunks.
    chunk_length = num_frames // num_filters
    if num_frames % num_filters > 0:
        chunk_length += 1
        num_pad = chunk_length * num_filters - num_frames
        waveform = torch.nn.functional.pad(waveform, [0, num_pad], "constant", 0)
    chunked = waveform.unfold(-1, chunk_length, chunk_length)
    assert chunked.numel() >= waveform.numel()

    # Broadcast 2D kernels so that each (batch, channel, ...) position gets its own filter set.
    if waveform.ndim + 1 > kernels.ndim:
        expand_shape = waveform.shape[:-1] + kernels.shape
        kernels = kernels.expand(expand_shape)

    # Full convolution of each chunk, then stitch the chunks back at their
    # original offsets, summing the convolution tails that spill into the
    # following chunk.
    convolved = fftconvolve(chunked, kernels)
    restored = _overlap_and_add(convolved, chunk_length)

    # Trim in a way that the number of samples are same as input,
    # and the filter delay is compensated
    if delay_compensation >= 0:
        start = delay_compensation
    else:
        start = filter_size // 2
    # Slice with an explicit end index rather than ``restored[..., start:-end]``:
    # when the amount to trim from the end is zero (e.g. filter_size == 1, or
    # delay_compensation equal to the total excess length), ``-end`` would be
    # ``-0`` and the negative-index slice would return an EMPTY tensor.
    return restored[..., start : start + num_frames]
return result