Port lfilter_core_loop wrapper to python [WIP] #3954

Draft · wants to merge 1 commit into main
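
For context, the loop being ported backs the public `lfilter` API. A minimal usage sketch of that API (shapes and coefficient values are illustrative, not taken from the diff):

```python
import torch
from torchaudio.functional import lfilter

waveform = torch.randn(1, 8000)             # (channel, time)
a_coeffs = torch.tensor([1.0, -0.5, 0.25])  # denominator; a_coeffs[0] must be non-zero
b_coeffs = torch.tensor([0.4, 0.2, 0.4])    # numerator
filtered = lfilter(waveform, a_coeffs, b_coeffs)
```
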
195 changes: 11 additions & 184 deletions src/libtorchaudio/lfilter.cpp
@@ -100,194 +100,21 @@ void lfilter_core_generic_loop(
}
}

class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
public:
static torch::Tensor forward(
torch::autograd::AutogradContext* ctx,
const torch::Tensor& waveform,
const torch::Tensor& a_coeffs_normalized) {
auto device = waveform.device();
auto dtype = waveform.dtype();
int64_t n_batch = waveform.size(0);
int64_t n_channel = waveform.size(1);
int64_t n_sample = waveform.size(2);
int64_t n_order = a_coeffs_normalized.size(1);
int64_t n_sample_padded = n_sample + n_order - 1;

auto a_coeff_flipped = a_coeffs_normalized.flip(1).contiguous();

auto options = torch::TensorOptions().dtype(dtype).device(device);
auto padded_output_waveform =
torch::zeros({n_batch, n_channel, n_sample_padded}, options);

if (device.is_cpu()) {
cpu_lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform);
} else if (device.is_cuda()) {
#ifdef USE_CUDA
cuda_lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform);
#else
lfilter_core_generic_loop(
waveform, a_coeff_flipped, padded_output_waveform);
#endif
} else {
lfilter_core_generic_loop(
waveform, a_coeff_flipped, padded_output_waveform);
}

auto output = padded_output_waveform.index(
{torch::indexing::Slice(),
torch::indexing::Slice(),
torch::indexing::Slice(n_order - 1, torch::indexing::None)});

ctx->save_for_backward({waveform, a_coeffs_normalized, output});
return output;
}

static torch::autograd::tensor_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::tensor_list grad_outputs) {
auto saved = ctx->get_saved_variables();
auto x = saved[0];
auto a_coeffs_normalized = saved[1];
auto y = saved[2];

int64_t n_channel = x.size(1);
int64_t n_order = a_coeffs_normalized.size(1);

auto dx = torch::Tensor();
auto da = torch::Tensor();
auto dy = grad_outputs[0];

namespace F = torch::nn::functional;

auto tmp =
DifferentiableIIR::apply(dy.flip(2).contiguous(), a_coeffs_normalized)
.flip(2);

if (x.requires_grad()) {
dx = tmp;
}

if (a_coeffs_normalized.requires_grad()) {
da = -torch::matmul(
tmp.transpose(0, 1).reshape({n_channel, 1, -1}),
F::pad(y, F::PadFuncOptions({n_order - 1, 0}))
.unfold(2, n_order, 1)
.transpose(0, 1)
.reshape({n_channel, -1, n_order}))
.squeeze(1)
.flip(1);
}
return {dx, da};
}
};
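
For reviewers: with the leading coefficient normalized so that $a_0 = 1$, the recursion this class computes and the standard adjoint identities its `backward` relies on are

$$y[t] = x[t] - \sum_{k=1}^{K-1} a_k\, y[t-k], \qquad \frac{\partial L}{\partial x} = \operatorname{flip}\!\left(\operatorname{IIR}_a\!\left(\operatorname{flip}\!\left(\frac{\partial L}{\partial y}\right)\right)\right), \qquad \frac{\partial L}{\partial a_k} = -\sum_t \frac{\partial L}{\partial x[t]}\, y[t-k],$$

which is why `backward` can reuse `DifferentiableIIR::apply` on the time-reversed gradient, plus a single matmul against unfolded windows of the saved output.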

class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {
public:
static torch::Tensor forward(
torch::autograd::AutogradContext* ctx,
const torch::Tensor& waveform,
const torch::Tensor& b_coeffs) {
int64_t n_order = b_coeffs.size(1);
int64_t n_channel = b_coeffs.size(0);

namespace F = torch::nn::functional;
auto b_coeff_flipped = b_coeffs.flip(1).contiguous();
auto padded_waveform =
F::pad(waveform, F::PadFuncOptions({n_order - 1, 0}));

auto output = F::conv1d(
padded_waveform,
b_coeff_flipped.unsqueeze(1),
F::Conv1dFuncOptions().groups(n_channel));

ctx->save_for_backward({waveform, b_coeffs, output});
return output;
}

static torch::autograd::tensor_list backward(
torch::autograd::AutogradContext* ctx,
torch::autograd::tensor_list grad_outputs) {
auto saved = ctx->get_saved_variables();
auto x = saved[0];
auto b_coeffs = saved[1];
auto y = saved[2];

int64_t n_batch = x.size(0);
int64_t n_channel = x.size(1);
int64_t n_order = b_coeffs.size(1);

auto dx = torch::Tensor();
auto db = torch::Tensor();
auto dy = grad_outputs[0];

namespace F = torch::nn::functional;

if (b_coeffs.requires_grad()) {
db = F::conv1d(
F::pad(x, F::PadFuncOptions({n_order - 1, 0}))
.view({1, n_batch * n_channel, -1}),
dy.view({n_batch * n_channel, 1, -1}),
F::Conv1dFuncOptions().groups(n_batch * n_channel))
.view({n_batch, n_channel, -1})
.sum(0)
.flip(1);
}

if (x.requires_grad()) {
dx = F::conv1d(
F::pad(dy, F::PadFuncOptions({0, n_order - 1})),
b_coeffs.unsqueeze(1),
F::Conv1dFuncOptions().groups(n_channel));
}

return {dx, db};
}
};
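
Similarly, the FIR stage computes $y[t] = \sum_{k=0}^{K-1} b_k\, x[t-k]$ as a grouped `conv1d` with flipped kernels, and the two `conv1d` calls in its `backward` implement

$$\frac{\partial L}{\partial b_k} = \sum_t \frac{\partial L}{\partial y[t]}\, x[t-k], \qquad \frac{\partial L}{\partial x[t]} = \sum_{k=0}^{K-1} b_k\, \frac{\partial L}{\partial y[t+k]},$$

i.e. correlations of the saved input with the incoming gradient, and of the gradient with the unflipped kernels.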

torch::Tensor lfilter_core(
const torch::Tensor& waveform,
const torch::Tensor& a_coeffs,
const torch::Tensor& b_coeffs) {
TORCH_CHECK(waveform.device() == a_coeffs.device());
TORCH_CHECK(b_coeffs.device() == a_coeffs.device());
TORCH_CHECK(a_coeffs.sizes() == b_coeffs.sizes());

TORCH_INTERNAL_ASSERT(waveform.sizes().size() == 3);
TORCH_INTERNAL_ASSERT(a_coeffs.sizes().size() == 2);
TORCH_INTERNAL_ASSERT(a_coeffs.size(0) == waveform.size(1));

int64_t n_order = b_coeffs.size(1);

TORCH_INTERNAL_ASSERT(n_order > 0);

auto filtered_waveform = DifferentiableFIR::apply(
waveform,
b_coeffs /
a_coeffs.index(
{torch::indexing::Slice(), torch::indexing::Slice(0, 1)}));

auto output = DifferentiableIIR::apply(
filtered_waveform,
a_coeffs /
a_coeffs.index(
{torch::indexing::Slice(), torch::indexing::Slice(0, 1)}));
return output;
}
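
The two `apply` calls implement the usual normalization of the difference equation: dividing both coefficient vectors by $a_0$ factors the filter into an FIR stage followed by a normalized IIR stage,

$$\sum_{k=0}^{K-1} a_k\, y[n-k] = \sum_{k=0}^{K-1} b_k\, x[n-k] \;\Longleftrightarrow\; y = \operatorname{IIR}_{a/a_0}\!\big(\operatorname{FIR}_{b/a_0}(x)\big).$$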

} // namespace

-// Note: We want to avoid using "catch-all" kernel.
-// The following registration should be replaced with CPU specific registration.
-TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
-  m.def("torchaudio::_lfilter_core_loop", &cpu_lfilter_core_loop);
-}
-
 TORCH_LIBRARY(torchaudio, m) {
   m.def(
-      "torchaudio::_lfilter(Tensor waveform, Tensor a_coeffs, Tensor b_coeffs) -> Tensor");
+      "torchaudio::_lfilter_core_loop(Tensor input_signal_windows, Tensor a_coeff_flipped, Tensor(a!) padded_output_waveform) -> ()");
 }
 
-TORCH_LIBRARY_IMPL(torchaudio, CompositeImplicitAutograd, m) {
-  m.impl("torchaudio::_lfilter", lfilter_core);
+TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
+  m.impl("torchaudio::_lfilter_core_loop", &cpu_lfilter_core_loop);
 }
+
+TORCH_LIBRARY_IMPL(torchaudio, CUDA, m) {
+  m.impl("torchaudio::_lfilter_core_loop", &cuda_lfilter_core_loop);
+}
+
+// TORCH_LIBRARY_IMPL(torchaudio, CompositeExplicitAutograd, m) {
+//   m.impl("torchaudio::_lfilter_core_loop", &lfilter_core_generic_loop);
+// }
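
With the registrations above, the loop is reachable from Python as `torch.ops.torchaudio._lfilter_core_loop`. A hedged sketch of the calling convention (shapes and the in-place contract follow from the schema; the coefficient values are illustrative):

```python
import torch
import torchaudio  # noqa: F401  (importing torchaudio loads the extension that registers the op)

n_batch, n_channel, n_sample, n_order = 1, 2, 1000, 2
x = torch.randn(n_batch, n_channel, n_sample)
# Per-channel a = [1.0, -0.95], already normalized and flipped, as the op expects.
a_flipped = torch.tensor([[-0.95, 1.0], [-0.95, 1.0]])
out = torch.zeros(n_batch, n_channel, n_sample + n_order - 1)
torch.ops.torchaudio._lfilter_core_loop(x, a_flipped, out)  # mutates `out` in place
y = out[:, :, n_order - 1 :]  # drop the zeroed warm-up region
```
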
132 changes: 69 additions & 63 deletions src/torchaudio/functional/filtering.py
@@ -4,6 +4,7 @@

import torch
from torch import Tensor
import torch.nn.functional as F

from torchaudio._extension import _IS_TORCHAUDIO_EXT_AVAILABLE

@@ -932,69 +933,75 @@ def _lfilter_core_generic_loop(input_signal_windows: Tensor, a_coeffs_flipped: T


 if _IS_TORCHAUDIO_EXT_AVAILABLE:
-    _lfilter_core_cpu_loop = torch.ops.torchaudio._lfilter_core_loop
+    _lfilter_core_loop = torch.ops.torchaudio._lfilter_core_loop
 else:
-    _lfilter_core_cpu_loop = _lfilter_core_generic_loop
-
-
-def _lfilter_core(
-    waveform: Tensor,
-    a_coeffs: Tensor,
-    b_coeffs: Tensor,
-) -> Tensor:
-
-    if a_coeffs.size() != b_coeffs.size():
-        raise ValueError(
-            "Expected coeffs to be the same size."
-            f"Found a_coeffs size: {a_coeffs.size()}, b_coeffs size: {b_coeffs.size()}"
-        )
-    if waveform.ndim != 3:
-        raise ValueError(f"Expected waveform to be 3 dimensional. Found: {waveform.ndim}")
-    if not (waveform.device == a_coeffs.device == b_coeffs.device):
-        raise ValueError(
-            "Expected waveform and coeffs to be on the same device."
-            f"Found: waveform device:{waveform.device}, a_coeffs device: {a_coeffs.device}, "
-            f"b_coeffs device: {b_coeffs.device}"
-        )
-
-    n_batch, n_channel, n_sample = waveform.size()
-    n_order = a_coeffs.size(1)
-    if n_order <= 0:
-        raise ValueError(f"Expected n_order to be positive. Found: {n_order}")
-
-    # Pad the input and create output
-
-    padded_waveform = torch.nn.functional.pad(waveform, [n_order - 1, 0])
-    padded_output_waveform = torch.zeros_like(padded_waveform)
-
-    # Set up the coefficients matrix
-    # Flip coefficients' order
-    a_coeffs_flipped = a_coeffs.flip(1)
-    b_coeffs_flipped = b_coeffs.flip(1)
-
-    # calculate windowed_input_signal in parallel using convolution
-    input_signal_windows = torch.nn.functional.conv1d(padded_waveform, b_coeffs_flipped.unsqueeze(1), groups=n_channel)
-
-    input_signal_windows.div_(a_coeffs[:, :1])
-    a_coeffs_flipped.div_(a_coeffs[:, :1])
-
-    if (
-        input_signal_windows.device == torch.device("cpu")
-        and a_coeffs_flipped.device == torch.device("cpu")
-        and padded_output_waveform.device == torch.device("cpu")
-    ):
-        _lfilter_core_cpu_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform)
-    else:
-        _lfilter_core_generic_loop(input_signal_windows, a_coeffs_flipped, padded_output_waveform)
-
-    output = padded_output_waveform[:, :, n_order - 1 :]
-    return output
-
-
-if _IS_TORCHAUDIO_EXT_AVAILABLE:
-    _lfilter = torch.ops.torchaudio._lfilter
-else:
-    _lfilter = _lfilter_core
+    _lfilter_core_loop = _lfilter_core_generic_loop
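
For builds without the extension, `_lfilter_core_generic_loop` (defined earlier in this file and unchanged by this PR) performs the same recursion in pure Python. A minimal sketch of what that fallback computes, under the hypothetical name `_iir_loop_sketch`:

```python
import torch
from torch import Tensor


def _iir_loop_sketch(input_signal_windows: Tensor, a_coeffs_flipped: Tensor, padded_output_waveform: Tensor) -> None:
    # y[i] = x[i] - sum_{k=1}^{K-1} a_k * y[i-k]; writes into the padded buffer in place.
    n_order = a_coeffs_flipped.size(1)
    for i in range(input_signal_windows.size(2)):
        prev = padded_output_waveform[:, :, i : i + n_order - 1]  # y[i-K+1] .. y[i-1]
        o0 = input_signal_windows[:, :, i] - (prev * a_coeffs_flipped[:, :-1]).sum(-1)
        padded_output_waveform[:, :, i + n_order - 1] = o0
```
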


class DifferentiableFIR(torch.autograd.Function):
@staticmethod
def forward(ctx, waveform, b_coeffs):
n_order = b_coeffs.size(1)
n_channel = b_coeffs.size(0)
b_coeff_flipped = b_coeffs.flip(1).contiguous()
padded_waveform = F.pad(waveform, (n_order - 1, 0))
output = F.conv1d(padded_waveform, b_coeff_flipped.unsqueeze(1), groups=n_channel)
ctx.save_for_backward(waveform, b_coeffs, output)
return output

@staticmethod
def backward(ctx, dy):
x, b_coeffs, y = ctx.saved_tensors
n_batch = x.size(0)
n_channel = x.size(1)
n_order = b_coeffs.size(1)
        # Input viewed as (1, batch*channel, time) so that each batch-channel pair
        # forms its own group, matching groups=n_batch * n_channel below.
        db = (
            F.conv1d(
                F.pad(x, (n_order - 1, 0)).view(1, n_batch * n_channel, -1),
                dy.view(n_batch * n_channel, 1, -1),
                groups=n_batch * n_channel,
            )
            .view(n_batch, n_channel, -1)
            .sum(0)
            .flip(1)
            if b_coeffs.requires_grad
            else None
        )
        dx = (
            F.conv1d(
                F.pad(dy, (0, n_order - 1)),
                b_coeffs.unsqueeze(1),
                groups=n_channel,
            )
            if x.requires_grad
            else None
        )
        return (dx, db)

class DifferentiableIIR(torch.autograd.Function):
@staticmethod
    def forward(ctx, waveform, a_coeffs_normalized):
        n_batch, n_channel, n_sample = waveform.shape
        n_order = a_coeffs_normalized.size(1)
        n_sample_padded = n_sample + n_order - 1

        a_coeff_flipped = a_coeffs_normalized.flip(1).contiguous()
        padded_output_waveform = torch.zeros(
            n_batch, n_channel, n_sample_padded, device=waveform.device, dtype=waveform.dtype
        )
        _lfilter_core_loop(waveform, a_coeff_flipped, padded_output_waveform)
        output = padded_output_waveform[:, :, n_order - 1 :]
        # Save the unflipped coefficients: backward re-enters DifferentiableIIR.apply,
        # which performs its own flip.
        ctx.save_for_backward(waveform, a_coeffs_normalized, output)
        return output

@staticmethod
    def backward(ctx, dy):
        x, a_coeffs_normalized, y = ctx.saved_tensors
        n_channel = x.size(1)
        n_order = a_coeffs_normalized.size(1)
        tmp = DifferentiableIIR.apply(dy.flip(2).contiguous(), a_coeffs_normalized).flip(2)
        dx = tmp if x.requires_grad else None
        da = (
            -(
                tmp.transpose(0, 1).reshape(n_channel, 1, -1)
                @ F.pad(y, (n_order - 1, 0)).unfold(2, n_order, 1).transpose(0, 1).reshape(n_channel, -1, n_order)
            )
            .squeeze(1)
            .flip(1)
            if a_coeffs_normalized.requires_grad
            else None
        )
        return (dx, da)


def _lfilter(waveform, a_coeffs, b_coeffs):
    filtered_waveform = DifferentiableFIR.apply(waveform, b_coeffs / a_coeffs[:, 0:1])
    return DifferentiableIIR.apply(filtered_waveform, a_coeffs / a_coeffs[:, 0:1])
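
Since both backward passes are hand-written, they can be validated numerically with `gradcheck`, run inside this module against the private `_lfilter` (double precision, as `gradcheck` expects; shapes and values are illustrative):

```python
import torch

x = torch.randn(1, 2, 32, dtype=torch.double, requires_grad=True)
a = torch.tensor([[1.0, -0.3, 0.1], [1.0, -0.2, 0.05]], dtype=torch.double, requires_grad=True)
b = torch.tensor([[0.5, 0.2, 0.1], [0.4, 0.1, 0.05]], dtype=torch.double, requires_grad=True)
torch.autograd.gradcheck(_lfilter, (x, a, b))  # raises if analytic and numeric grads disagree
```
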


def lfilter(waveform: Tensor, a_coeffs: Tensor, b_coeffs: Tensor, clamp: bool = True, batching: bool = True) -> Tensor:
@@ -1066,7 +1073,6 @@ def lfilter(waveform: Tensor, a_coeffs: Tensor, b_coeffs: Tensor, clamp: bool =

return output


def lowpass_biquad(waveform: Tensor, sample_rate: int, cutoff_freq: float, Q: float = 0.707) -> Tensor:
r"""Design biquad lowpass filter and perform filtering. Similar to SoX implementation.
