Description
🚀 The feature
A differentiable PyTorch implementation of sosfilt(), like the one in https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.sosfilt.html, would allow filtering data along one dimension using cascaded second-order sections. This would give better support for stable high-order filtering.
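For reference, the cascade applied by sosfilt() can be written as a product of biquad transfer functions. The notation below is an assumption for illustration: K is the number of sections, and row k of the sos array is taken to be [b0, b1, b2, a0, a1, a2], which is the layout SciPy documents (with a0 usually normalized to 1).

```latex
\[
H(z) \;=\; \prod_{k=1}^{K} H_k(z)
\;=\; \prod_{k=1}^{K}
\frac{b_{0,k} + b_{1,k}\, z^{-1} + b_{2,k}\, z^{-2}}
     {a_{0,k} + a_{1,k}\, z^{-1} + a_{2,k}\, z^{-2}}
\]
```

Each factor is filtered separately, so only second-order recursions are ever evaluated, instead of one recursion of order 2K.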
Motivation, pitch
The current alternative is to convert the cascade of biquads (2nd order IIR filters) to a single high-order filter and then apply it with https://pytorch.org/audio/main/generated/torchaudio.functional.lfilter.html. Unfortunately, this only works up to a certain order (order<6). The following code illustrates the stability issues encountered when using lfilter with a high-order filter. An option for cascaded filtering that maintains stability would therefore be a great advantage.
import torch
import scipy.signal as signal
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchaudio.functional import lfilter

hardware = "cpu"
device = torch.device(hardware)
eps = 1e-8


def coeff_product(polynomials):
    # Multiply the polynomials stored in the rows of `polynomials` by splitting
    # the rows in half recursively and convolving the two partial products.
    n = polynomials.shape[0]
    if n == 1:
        return polynomials

    c1 = coeff_product(polynomials[n // 2 :])
    c2 = coeff_product(polynomials[: n // 2])
    if c1.shape[1] > c2.shape[1]:
        c1, c2 = c2, c1

    # conv1d computes a cross-correlation, so the kernel is flipped
    # to obtain a true convolution (polynomial product).
    weight = c1.unsqueeze(1).flip(2)
    prod = F.conv1d(
        c2.unsqueeze(0),
        weight,
        padding=weight.shape[2] - 1,
        groups=c2.shape[0],
    ).squeeze(0)
    return prod


if __name__ == "__main__":
    for order in range(2, 12, 2):
        # Design the elliptic filter and print its zeros, poles, and coefficients
        b, a = signal.ellip(order, 0.009, 80, 0.05, output='ba')
        sos = signal.ellip(order, 0.009, 80, 0.05, output='sos')
        zeros, poles, gain = signal.sos2zpk(sos)
        print("-" * 52)
        print("Zeroes : ", zeros)
        print("Poles : ", poles)
        print("-" * 52)
        print("sos : ", sos)
        print("-" * 52)
        print("b : ", b)
        print("a : ", a)
        print("-" * 52)

        # init var
        fs = 500
        dirac = torch.tensor(signal.unit_impulse(fs), dtype=torch.float32)

        # PYTORCH IMPLEMENTATION
        # prepare coeffs: each sos row is [b0, b1, b2, a0, a1, a2]
        torch_sos = torch.tensor(sos, dtype=torch.float32)
        torch_a = torch_sos[:, 3:]
        torch_b = torch_sos[:, :3]
        high_order_a = coeff_product(torch_a)
        high_order_b = coeff_product(torch_b)
        print("sos : ", torch_sos)
        print("-" * 52)
        print("torch_b : ", torch_b)
        print("torch_a : ", torch_a)
        print("-" * 52)
        print("high_order_b : ", high_order_b)
        print("high_order_a : ", high_order_a)
        print("-" * 52)

        # compute filter response (torchaudio's lfilter takes a_coeffs before b_coeffs)
        y_torch_ba = lfilter(dirac.unsqueeze(0), high_order_a, high_order_b)

        # SCIPY IMPLEMENTATION
        freq, freq_response = signal.sosfreqz(sos)
        x = signal.unit_impulse(fs)
        y_tf = signal.lfilter(
            high_order_b.squeeze(0).detach().numpy(),
            high_order_a.squeeze(0).detach().numpy(),
            x,
        )
        y_sos = signal.sosfilt(sos, x)

        # plotting
        plt.figure(figsize=(15, 30))
        plt.subplot(3, 1, 1)
        plt.plot(y_sos, 'g', label='SOS')
        plt.legend(loc='best')
        plt.subplot(3, 1, 2)
        plt.plot(y_tf, 'k', label='TF')
        plt.legend(loc='best')
        plt.subplot(3, 1, 3)
        plt.plot(y_torch_ba.squeeze(0).detach().numpy(), "r", label="torch")
        plt.legend(loc='best')
        plt.show()
This feature would allow users to apply high order filtering (order>6) within loss functions and training loops.
Alternatives
Since no filtering based on a cascade of biquads is available, the current alternatives are:
- Use a for loop and feed the biquad coefficients, one section at a time, to https://pytorch.org/audio/main/generated/torchaudio.functional.lfilter.html. Unfortunately, this is very slow, which makes it impractical within a loss function or a training loop.
- Convert the cascade of biquads to a high-order filter and use it with https://pytorch.org/audio/main/generated/torchaudio.functional.lfilter.html. This results in an unstable output, as illustrated above, which is expected and also happens with SciPy when using lfilter instead of sosfilt().
Additional context
https://dsp.stackexchange.com/questions/31457/multiple-biquads-vs-higher-order-filtering