Move lfilter autograd logic to Python #3957

Draft · wants to merge 1 commit into base: main
63 changes: 39 additions & 24 deletions src/libtorchaudio/lfilter.cpp
@@ -100,10 +100,9 @@ void lfilter_core_generic_loop(
   }
 }

-class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
+class DifferentiableIIR {
  public:
-  static torch::Tensor forward(
-      torch::autograd::AutogradContext* ctx,
+  static std::tuple<torch::Tensor, std::vector<torch::Tensor>> forward(
       const torch::Tensor& waveform,
       const torch::Tensor& a_coeffs_normalized) {
     auto device = waveform.device();
@@ -139,14 +138,14 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
         torch::indexing::Slice(),
         torch::indexing::Slice(n_order - 1, torch::indexing::None)});

-    ctx->save_for_backward({waveform, a_coeffs_normalized, output});
-    return output;
+    std::vector<torch::Tensor> stuff_for_backward = {waveform, a_coeffs_normalized, output};
+    return std::make_tuple(output, stuff_for_backward);
   }

-  static torch::autograd::tensor_list backward(
-      torch::autograd::AutogradContext* ctx,
-      torch::autograd::tensor_list grad_outputs) {
-    auto saved = ctx->get_saved_variables();
+  static torch::autograd::tensor_list backward(
+      const std::vector<torch::Tensor>& stuff_for_iir_backward, // technically this should be 3 parameters.
+      const torch::Tensor& grad_output) {
+    auto saved = stuff_for_iir_backward;
     auto x = saved[0];
     auto a_coeffs_normalized = saved[1];
     auto y = saved[2];
@@ -156,7 +155,7 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {

     auto dx = torch::Tensor();
     auto da = torch::Tensor();
-    auto dy = grad_outputs[0];
+    auto dy = grad_output;

     namespace F = torch::nn::functional;

@@ -182,10 +181,9 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
   }
 };

-class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {
+class DifferentiableFIR {
  public:
-  static torch::Tensor forward(
-      torch::autograd::AutogradContext* ctx,
+  static std::tuple<torch::Tensor, std::vector<torch::Tensor>> forward(
       const torch::Tensor& waveform,
       const torch::Tensor& b_coeffs) {
     int64_t n_order = b_coeffs.size(1);
@@ -201,14 +199,14 @@ class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {
         b_coeff_flipped.unsqueeze(1),
         F::Conv1dFuncOptions().groups(n_channel));

-    ctx->save_for_backward({waveform, b_coeffs, output});
-    return output;
+    std::vector<torch::Tensor> stuff_for_backward = {waveform, b_coeffs, output};
+    return std::make_tuple(output, stuff_for_backward);
   }

-  static torch::autograd::tensor_list backward(
-      torch::autograd::AutogradContext* ctx,
-      torch::autograd::tensor_list grad_outputs) {
-    auto saved = ctx->get_saved_variables();
+  static torch::autograd::tensor_list backward(
+      const std::vector<torch::Tensor>& stuff_for_backward_fir, // technically this should be 3 parameters.
+      const torch::Tensor& grad_output) {
+    auto saved = stuff_for_backward_fir;
     auto x = saved[0];
     auto b_coeffs = saved[1];
     auto y = saved[2];
@@ -219,7 +217,7 @@ class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {

     auto dx = torch::Tensor();
     auto db = torch::Tensor();
-    auto dy = grad_outputs[0];
+    auto dy = grad_output;

     namespace F = torch::nn::functional;

@@ -245,7 +243,7 @@ class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {
   }
 };

-torch::Tensor lfilter_core(
+std::tuple<torch::Tensor, std::vector<torch::Tensor>, std::vector<torch::Tensor>> lfilter_core_forward(
     const torch::Tensor& waveform,
     const torch::Tensor& a_coeffs,
     const torch::Tensor& b_coeffs) {
@@ -261,18 +259,28 @@ torch::Tensor lfilter_core(

   TORCH_INTERNAL_ASSERT(n_order > 0);

-  auto filtered_waveform = DifferentiableFIR::apply(
+  auto [filtered_waveform, stuff_for_backward_fir] = DifferentiableFIR::forward(
       waveform,
       b_coeffs /
           a_coeffs.index(
               {torch::indexing::Slice(), torch::indexing::Slice(0, 1)}));

-  auto output = DifferentiableIIR::apply(
+  auto [output, stuff_for_backward_iir] = DifferentiableIIR::forward(
       filtered_waveform,
       a_coeffs /
           a_coeffs.index(
               {torch::indexing::Slice(), torch::indexing::Slice(0, 1)}));
-  return output;
+  return std::make_tuple(output, stuff_for_backward_fir, stuff_for_backward_iir);
 }
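+
+// Chain-rule sketch (unverified): with x -> FIR -> IIR -> y,
+// DifferentiableIIR::backward(grad_y) yields {grad_fir_out, grad_a}, and
+// feeding grad_fir_out into DifferentiableFIR::backward yields
+// {grad_x, grad_b}. A complete implementation would also need to surface
+// grad_a, which the function below drops.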

+torch::autograd::tensor_list lfilter_core_backward(
+    const std::vector<torch::Tensor>& stuff_for_backward_fir,
+    const std::vector<torch::Tensor>& stuff_for_backward_iir,
+    const torch::Tensor& grad_output) {
+  // not sure that's really correct, I'm just winging it.
+  auto out = DifferentiableIIR::backward(stuff_for_backward_iir, grad_output);
+  return DifferentiableFIR::backward(stuff_for_backward_fir, out[0]);
+}

 } // namespace
@@ -288,6 +296,13 @@ TORCH_LIBRARY(torchaudio, m) {
       "torchaudio::_lfilter(Tensor waveform, Tensor a_coeffs, Tensor b_coeffs) -> Tensor");
 }

-TORCH_LIBRARY_IMPL(torchaudio, CompositeImplicitAutograd, m) {
-  m.impl("torchaudio::_lfilter", lfilter_core);
+TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
+  m.def("torchaudio::_lfilter_forward", &lfilter_core_forward);
 }
+TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
+  m.def("torchaudio::_lfilter_backward", &lfilter_core_backward);
+}

+// TODO need schema of input/output for both forward and backward, e.g.
+// m.def(
+//     "torchaudio::_lfilter_forward(Tensor waveform, Tensor a_coeffs, Tensor b_coeffs) -> Tensor + <the stuff needed for backward here>);
19 changes: 18 additions & 1 deletion src/torchaudio/functional/filtering.py
@@ -990,9 +990,26 @@ def _lfilter_core(
     output = padded_output_waveform[:, :, n_order - 1 :]
     return output

+class _LfilterInCppAndAutogradInPython(torch.autograd.Function):
+    # This class calls the C++ implementation of lfilter for both forward and
+    # backward, while handling the autograd logic of saving the relevant
+    # context tensors in Python.
+    # This requires updating the C++ ops to remove any call to
+    # `save_for_backward(stuff_for_backward)`. We now need the op to return
+    # `stuff_for_backward` all the way back to Python.
+
+    @staticmethod
+    def forward(ctx, waveform, a_coeffs, b_coeffs):
+        output, stuff_for_backward_fir, stuff_for_backward_iir = (
+            torch.ops.torchaudio._lfilter_forward(waveform, a_coeffs, b_coeffs)
+        )
+        # save_for_backward() may only be called once per forward(), so save
+        # both groups together and remember where to split them.
+        ctx.save_for_backward(*stuff_for_backward_fir, *stuff_for_backward_iir)
+        ctx.n_fir = len(stuff_for_backward_fir)
+        return output
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        # backward() is expected to return one gradient per forward() input.
+        saved = ctx.saved_tensors
+        return torch.ops.torchaudio._lfilter_backward(
+            saved[: ctx.n_fir], saved[ctx.n_fir :], grad_output
+        )
+

 if _IS_TORCHAUDIO_EXT_AVAILABLE:
-    _lfilter = torch.ops.torchaudio._lfilter
+    _lfilter = _LfilterInCppAndAutogradInPython.apply
 else:
     _lfilter = _lfilter_core
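+
+# A quick numerical check of the rewired path (sketch, not part of this PR;
+# shapes assume the core op's (batch, channel, time) waveform layout and
+# (channel, order) coefficients):
+#
+#   waveform = torch.randn(1, 1, 16, dtype=torch.float64, requires_grad=True)
+#   a_coeffs = torch.tensor([[1.0, -0.5]], dtype=torch.float64, requires_grad=True)
+#   b_coeffs = torch.tensor([[0.4, 0.2]], dtype=torch.float64, requires_grad=True)
+#   torch.autograd.gradcheck(
+#       _LfilterInCppAndAutogradInPython.apply, (waveform, a_coeffs, b_coeffs))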
