Skip to content

Commit 982aa2e

Browse files
committed
Move lfilter autograd logic to Python
1 parent 70caf76 commit 982aa2e

File tree

2 files changed

+57
-25
lines changed

2 files changed

+57
-25
lines changed

src/libtorchaudio/lfilter.cpp

Lines changed: 39 additions & 24 deletions
Original file line number | Diff line number | Diff line change
@@ -100,10 +100,9 @@ void lfilter_core_generic_loop(
100100
}
101101
}
102102

103-
class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
103+
class DifferentiableIIR{
104104
public:
105105
static torch::Tensor forward(
106-
torch::autograd::AutogradContext* ctx,
107106
const torch::Tensor& waveform,
108107
const torch::Tensor& a_coeffs_normalized) {
109108
auto device = waveform.device();
@@ -139,14 +138,14 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
139138
torch::indexing::Slice(),
140139
torch::indexing::Slice(n_order - 1, torch::indexing::None)});
141140

142-
ctx->save_for_backward({waveform, a_coeffs_normalized, output});
143-
return output;
141+
auto stuff_for_backward = {waveform, a_coeffs_normalized, output};
142+
return output, stuff_for_backward;
144143
}
145144

146-
static torch::autograd::tensor_list backward(
147-
torch::autograd::AutogradContext* ctx,
148-
torch::autograd::tensor_list grad_outputs) {
149-
auto saved = ctx->get_saved_variables();
145+
static tuple backward(
146+
auto stuff_for_iir_backward, // technically this should be 3 parameters.
147+
torch::Tensor grad_output) {
148+
auto saved = stuff_for_iir_backward,
150149
auto x = saved[0];
151150
auto a_coeffs_normalized = saved[1];
152151
auto y = saved[2];
@@ -156,7 +155,7 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
156155

157156
auto dx = torch::Tensor();
158157
auto da = torch::Tensor();
159-
auto dy = grad_outputs[0];
158+
auto dy = grad_output;
160159

161160
namespace F = torch::nn::functional;
162161

@@ -182,10 +181,9 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
182181
}
183182
};
184183

185-
class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {
184+
class DifferentiableFIR {
186185
public:
187186
static torch::Tensor forward(
188-
torch::autograd::AutogradContext* ctx,
189187
const torch::Tensor& waveform,
190188
const torch::Tensor& b_coeffs) {
191189
int64_t n_order = b_coeffs.size(1);
@@ -201,14 +199,14 @@ class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {
201199
b_coeff_flipped.unsqueeze(1),
202200
F::Conv1dFuncOptions().groups(n_channel));
203201

204-
ctx->save_for_backward({waveform, b_coeffs, output});
205-
return output;
202+
auto stuff_for_backward = {waveform, b_coeffs, output};
203+
return output, stuff_for_backward;
206204
}
207205

208-
static torch::autograd::tensor_list backward(
209-
torch::autograd::AutogradContext* ctx,
210-
torch::autograd::tensor_list grad_outputs) {
211-
auto saved = ctx->get_saved_variables();
206+
static tuple backward(
207+
auto stuff_for_backward_fir, // technically this should be 3 parameters
208+
torch::Tensor grad_output) {
209+
auto saved = stuff_for_backward_fir,
212210
auto x = saved[0];
213211
auto b_coeffs = saved[1];
214212
auto y = saved[2];
@@ -219,7 +217,7 @@ class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {
219217

220218
auto dx = torch::Tensor();
221219
auto db = torch::Tensor();
222-
auto dy = grad_outputs[0];
220+
auto dy = grad_output;
223221

224222
namespace F = torch::nn::functional;
225223

@@ -245,7 +243,7 @@ class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {
245243
}
246244
};
247245

248-
torch::Tensor lfilter_core(
246+
torch::Tensor lfilter_core_forward(
249247
const torch::Tensor& waveform,
250248
const torch::Tensor& a_coeffs,
251249
const torch::Tensor& b_coeffs) {
@@ -261,18 +259,28 @@ torch::Tensor lfilter_core(
261259

262260
TORCH_INTERNAL_ASSERT(n_order > 0);
263261

264-
auto filtered_waveform = DifferentiableFIR::apply(
262+
auto filtered_waveform, stuff_for_backward_fir = DifferentiableFIR::forward(
265263
waveform,
266264
b_coeffs /
267265
a_coeffs.index(
268266
{torch::indexing::Slice(), torch::indexing::Slice(0, 1)}));
269267

270-
auto output = DifferentiableIIR::apply(
268+
auto output, stuff_for_backward_iir = DifferentiableIIR::forward(
271269
filtered_waveform,
272270
a_coeffs /
273271
a_coeffs.index(
274272
{torch::indexing::Slice(), torch::indexing::Slice(0, 1)}));
275-
return output;
273+
return output, stuff_for_backward_fir, stuff_for_backward_iir;
274+
}
275+
276+
torch::Tensor lfilter_core_backward(
277+
auto stuff_for_backward_fir
278+
auto stuff_for_backward_iir
279+
auto grad_output,
280+
) {
281+
// not sure that's really correct, I'm just winging it.
282+
auto out = DifferentiableIIR::backward(stuff_for_backward_iir, grad_output)
283+
return DifferentiableFIR::backward(stuff_for_backward_fir, out)
276284
}
277285

278286
} // namespace
@@ -288,6 +296,13 @@ TORCH_LIBRARY(torchaudio, m) {
288296
"torchaudio::_lfilter(Tensor waveform, Tensor a_coeffs, Tensor b_coeffs) -> Tensor");
289297
}
290298

291-
TORCH_LIBRARY_IMPL(torchaudio, CompositeImplicitAutograd, m) {
292-
m.impl("torchaudio::_lfilter", lfilter_core);
299+
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
300+
m.def("torchaudio::_lfilter_forward", &lfilter_forward);
293301
}
302+
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
303+
m.def("torchaudio::_lfilter_backward", &lfilter_backward);
304+
}
305+
306+
// TODO need schema of input/output for both forward and backward, e.g.
307+
// m.def(
308+
// "torchaudio::_lfilter_forward(Tensor waveform, Tensor a_coeffs, Tensor b_coeffs) -> Tensor + <the stuff needed for backward here>);

src/torchaudio/functional/filtering.py

Lines changed: 18 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -990,9 +990,26 @@ def _lfilter_core(
990990
output = padded_output_waveform[:, :, n_order - 1 :]
991991
return output
992992

993+
class _LfilterInCppAndAutogradInPython(torch.autograd.Function):
994+
# This class calls the C++ implementation of lfilter for both forward and
995+
# backward, while handling the autograd logic of saving the relevant context
996+
# tensors in Python.
997+
# This requires updating the C++ ops to remove any call to
998+
# `save_for_backward(stuff_for_backward)`. We now need the op to return
999+
# `stuff_for_backward` all the way back to Python.
1000+
1001+
def forward(self, waveform, a_coeffs, b_coeffs):
1002+
output, stuff_for_backward_fir, stuff_for_backward_iir = torch.ops.torchaudio._lfilter_forward(waveform, a_coeffs, b_coeffs)
1003+
ctx.save_for_backward(stuff_for_backward_fir)
1004+
ctx.save_for_backward(stuff_for_backward_iir)
1005+
return output
1006+
1007+
def backward(self, ctx, grad_outputs):
1008+
return torch.ops.torchaudio._lfilter_backward(*ctx.saved_tensors, grad_outputs[0])
1009+
9931010

9941011
if _IS_TORCHAUDIO_EXT_AVAILABLE:
995-
_lfilter = torch.ops.torchaudio._lfilter
1012+
_lfilter = _LfilterInCppAndAutogradInPython()
9961013
else:
9971014
_lfilter = _lfilter_core
9981015

0 commit comments

Comments
 (0)