@@ -100,10 +100,9 @@ void lfilter_core_generic_loop(
   }
 }
 
-class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
+class DifferentiableIIR {
  public:
   static torch::Tensor forward(
-      torch::autograd::AutogradContext* ctx,
       const torch::Tensor& waveform,
       const torch::Tensor& a_coeffs_normalized) {
     auto device = waveform.device();
@@ -139,14 +138,14 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
         torch::indexing::Slice(),
         torch::indexing::Slice(n_order - 1, torch::indexing::None)});
 
-    ctx->save_for_backward({waveform, a_coeffs_normalized, output});
-    return output;
+    auto stuff_for_backward = {waveform, a_coeffs_normalized, output};
+    return output, stuff_for_backward;
   }
 
-  static torch::autograd::tensor_list backward(
-      torch::autograd::AutogradContext* ctx,
-      torch::autograd::tensor_list grad_outputs) {
-    auto saved = ctx->get_saved_variables();
+  static tuple backward(
+      auto stuff_for_iir_backward, // technically this should be 3 parameters.
+      torch::Tensor grad_output) {
+    auto saved = stuff_for_iir_backward;
     auto x = saved[0];
     auto a_coeffs_normalized = saved[1];
     auto y = saved[2];
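
Aside from the diff: "return output, stuff_for_backward;" above is shorthand rather than valid C++ (the comma operator would discard "output"). A minimal compilable sketch of the same idea, with the IIR loop itself elided and the helper name "iir_forward_sketch" invented purely for illustration, bundles the two results into a std::tuple:

#include <tuple>
#include <vector>
#include <torch/torch.h>

// Hypothetical helper, not part of the commit: forward returns the output
// together with the tensors that backward will need, replacing
// ctx->save_for_backward().
std::tuple<torch::Tensor, std::vector<torch::Tensor>> iir_forward_sketch(
    const torch::Tensor& waveform,
    const torch::Tensor& a_coeffs_normalized) {
  torch::Tensor output = waveform.clone(); // stand-in for the real IIR loop
  std::vector<torch::Tensor> stuff_for_backward = {
      waveform, a_coeffs_normalized, output};
  return {output, stuff_for_backward};
}

// Callers unpack both pieces with a structured binding:
//   auto [output, saved] = iir_forward_sketch(waveform, a_coeffs_normalized);
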
@@ -156,7 +155,7 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
 
     auto dx = torch::Tensor();
     auto da = torch::Tensor();
-    auto dy = grad_outputs[0];
+    auto dy = grad_output;
 
     namespace F = torch::nn::functional;
@@ -182,10 +181,9 @@ class DifferentiableIIR : public torch::autograd::Function<DifferentiableIIR> {
   }
 };
 
-class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {
+class DifferentiableFIR {
  public:
   static torch::Tensor forward(
-      torch::autograd::AutogradContext* ctx,
       const torch::Tensor& waveform,
       const torch::Tensor& b_coeffs) {
     int64_t n_order = b_coeffs.size(1);
@@ -201,14 +199,14 @@ class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {
         b_coeff_flipped.unsqueeze(1),
         F::Conv1dFuncOptions().groups(n_channel));
 
-    ctx->save_for_backward({waveform, b_coeffs, output});
-    return output;
+    auto stuff_for_backward = {waveform, b_coeffs, output};
+    return output, stuff_for_backward;
   }
 
-  static torch::autograd::tensor_list backward(
-      torch::autograd::AutogradContext* ctx,
-      torch::autograd::tensor_list grad_outputs) {
-    auto saved = ctx->get_saved_variables();
+  static tuple backward(
+      auto stuff_for_backward_fir, // technically this should be 3 parameters
+      torch::Tensor grad_output) {
+    auto saved = stuff_for_backward_fir;
     auto x = saved[0];
     auto b_coeffs = saved[1];
     auto y = saved[2];
@@ -219,7 +217,7 @@ class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {
 
     auto dx = torch::Tensor();
     auto db = torch::Tensor();
-    auto dy = grad_outputs[0];
+    auto dy = grad_output;
 
     namespace F = torch::nn::functional;
@@ -245,7 +243,7 @@ class DifferentiableFIR : public torch::autograd::Function<DifferentiableFIR> {
   }
 };
 
-torch::Tensor lfilter_core(
+torch::Tensor lfilter_core_forward(
     const torch::Tensor& waveform,
     const torch::Tensor& a_coeffs,
     const torch::Tensor& b_coeffs) {
@@ -261,18 +259,28 @@ torch::Tensor lfilter_core(
 
   TORCH_INTERNAL_ASSERT(n_order > 0);
 
-  auto filtered_waveform = DifferentiableFIR::apply(
+  auto filtered_waveform, stuff_for_backward_fir = DifferentiableFIR::forward(
       waveform,
       b_coeffs /
          a_coeffs.index(
              {torch::indexing::Slice(), torch::indexing::Slice(0, 1)}));
 
-  auto output = DifferentiableIIR::apply(
+  auto output, stuff_for_backward_iir = DifferentiableIIR::forward(
       filtered_waveform,
       a_coeffs /
          a_coeffs.index(
              {torch::indexing::Slice(), torch::indexing::Slice(0, 1)}));
-  return output;
+  return output, stuff_for_backward_fir, stuff_for_backward_iir;
+}
+
+torch::Tensor lfilter_core_backward(
+    auto stuff_for_backward_fir,
+    auto stuff_for_backward_iir,
+    auto grad_output
+) {
+  // not sure that's really correct, I'm just winging it.
+  auto out = DifferentiableIIR::backward(stuff_for_backward_iir, grad_output);
+  return DifferentiableFIR::backward(stuff_for_backward_fir, out);
 }
 
 } // namespace
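
Aside from the diff: typed concretely, the chained backward sketched in lfilter_core_backward might look like the following. The helper names fir_backward/iir_backward and the {grad_input, grad_coeffs} return layout are assumptions for illustration, not torchaudio API:

#include <vector>
#include <torch/torch.h>

// Hypothetical per-stage backwards, each returning {grad_input, grad_coeffs}.
std::vector<torch::Tensor> fir_backward(
    const std::vector<torch::Tensor>& saved, const torch::Tensor& grad_output);
std::vector<torch::Tensor> iir_backward(
    const std::vector<torch::Tensor>& saved, const torch::Tensor& grad_output);

std::vector<torch::Tensor> lfilter_backward_sketch(
    const std::vector<torch::Tensor>& stuff_for_backward_fir,
    const std::vector<torch::Tensor>& stuff_for_backward_iir,
    const torch::Tensor& grad_output) {
  // Gradients flow in reverse order of the forward pass: the IIR stage ran
  // last, so its backward runs first, and the gradient it produces for its
  // input feeds the FIR stage's backward.
  auto iir_grads = iir_backward(stuff_for_backward_iir, grad_output);
  auto fir_grads = fir_backward(stuff_for_backward_fir, iir_grads[0]);
  // {grad wrt waveform, grad wrt b_coeffs, grad wrt a_coeffs}
  return {fir_grads[0], fir_grads[1], iir_grads[1]};
}
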
@@ -288,6 +296,13 @@ TORCH_LIBRARY(torchaudio, m) {
       "torchaudio::_lfilter(Tensor waveform, Tensor a_coeffs, Tensor b_coeffs) -> Tensor");
 }
 
-TORCH_LIBRARY_IMPL(torchaudio, CompositeImplicitAutograd, m) {
-  m.impl("torchaudio::_lfilter", lfilter_core);
+TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
+  m.def("torchaudio::_lfilter_forward", &lfilter_core_forward);
 }
+TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
+  m.def("torchaudio::_lfilter_backward", &lfilter_core_backward);
+}
+
+// TODO need schema of input/output for both forward and backward, e.g.
+// m.def(
+//     "torchaudio::_lfilter_forward(Tensor waveform, Tensor a_coeffs, Tensor b_coeffs) -> Tensor + <the stuff needed for backward here>");