Differentiable filtering using a cascade of second order IIR filters  #3808

Open
@SuperKogito

🚀 The feature

A differentiable PyTorch implementation of sosfilt(), modeled on https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.sosfilt.html, would allow filtering data along one dimension using cascaded second-order sections. This would give better support for stable high-order filtering.
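
For reference, the recursion that a cascade of second-order sections computes can be sketched in plain Python. This is only an illustration of the semantics, not a differentiable or efficient implementation; the function name `sosfilt_reference` and the example coefficients are made up for this sketch, and each section is assumed to use scipy's row layout [b0, b1, b2, a0, a1, a2].

```python
def sosfilt_reference(sos, x):
    """Filter the sequence `x` through cascaded second-order sections.

    Each row of `sos` is [b0, b1, b2, a0, a1, a2] (scipy's layout).
    """
    y = list(x)
    for b0, b1, b2, a0, a1, a2 in sos:
        # Normalize so that the leading denominator coefficient is 1.
        b0, b1, b2, a1, a2 = b0 / a0, b1 / a0, b2 / a0, a1 / a0, a2 / a0
        z1 = z2 = 0.0  # state of this section (direct form II transposed)
        out = []
        for sample in y:
            result = b0 * sample + z1
            z1 = b1 * sample - a1 * result + z2
            z2 = b2 * sample - a2 * result
            out.append(result)
        y = out  # the output of one section is the input of the next
    return y


# Hypothetical single section implementing y[n] = x[n] + 0.5 * y[n-1]
impulse = [1.0, 0.0, 0.0, 0.0, 0.0]
one_pole = [1.0, 0.0, 0.0, 1.0, -0.5, 0.0]
print(sosfilt_reference([one_pole], impulse))
print(sosfilt_reference([one_pole, one_pole], impulse))
```

Because every section only ever handles second-order coefficients, the high-order polynomial is never formed explicitly, which is the property the requested feature relies on.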

Motivation, pitch

The current alternative is to convert the cascade of biquads (second-order IIR filters) into a single high-order filter and then apply it with https://pytorch.org/audio/main/generated/torchaudio.functional.lfilter.html. Unfortunately, this only works up to a certain order (order < 6), because the poles of the expanded polynomial become very sensitive to coefficient rounding as the order grows. The following code illustrates the stability issues that arise when lfilter is used with a high-order filter. An option for cascaded filtering that maintains stability would therefore be a great advantage.

import torch
import scipy.signal as signal
import matplotlib.pyplot as plt
import torch.nn.functional as F
from torchaudio.functional import lfilter


def coeff_product(polynomials):
    # Multiply the polynomials stored in the rows of `polynomials` into one
    # high-order polynomial by recursive pairwise convolution.
    n = polynomials.shape[0]
    if n == 1:
        return polynomials

    c1 = coeff_product(polynomials[n // 2 :])
    c2 = coeff_product(polynomials[: n // 2])
    if c1.shape[1] > c2.shape[1]:
        c1, c2 = c2, c1
    weight = c1.unsqueeze(1).flip(2)
    prod = F.conv1d(
        c2.unsqueeze(0),
        weight,
        padding=weight.shape[2] - 1,
        groups=c2.shape[0],
    ).squeeze(0)
    return prod

if __name__ == "__main__":
    for order in range(2, 12, 2):
        # Design an elliptic low-pass filter in transfer-function (b, a) and SOS form
        b, a = signal.ellip(order, 0.009, 80, 0.05, output='ba')
        sos = signal.ellip(order, 0.009, 80, 0.05, output='sos')
        zeros, poles, gain = signal.sos2zpk(sos)
        
        print("-" * 52)
        print("Zeroes : ", zeros)        
        print("Poles  : ", poles)        
        print("-" * 52)
        
        print("sos : ", sos)
        print("-" * 52)        
        
        print("b : ", b)
        print("a : ", a)
        print("-" * 52)
        # Build the test input: a unit impulse of `fs` samples
        fs = 500
        dirac = torch.tensor(signal.unit_impulse(fs), dtype=torch.float32)
        # PYTORCH IMPLEMENTATION
        # prepare coeffs 
        torch_sos = torch.tensor(sos, dtype=torch.float32)
        torch_a = torch_sos[:, 3:]
        torch_b = torch_sos[:, :3]
        high_order_a = coeff_product(torch_a)
        high_order_b = coeff_product(torch_b)
        
        print("sos : ", torch_sos)
        print("-" * 52)        
        
        print("torch_b : ", torch_b)
        print("torch_a : ", torch_a)
        print("-" * 52)
        
        print("high_order_b : ", high_order_b)
        print("high_order_a : ", high_order_a)
        print("-" * 52)
        
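        # Compute the impulse response with the expanded high-order filter.
        # Note: torchaudio's lfilter takes a_coeffs before b_coeffs.
        y_torch_ba = lfilter(dirac.unsqueeze(0), high_order_a, high_order_b)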
        
        # SCIPY IMPLEMENTATION
        x = signal.unit_impulse(fs)
        # Direct-form filtering with the expanded high-order coefficients
        y_tf = signal.lfilter(
            high_order_b.squeeze(0).detach().numpy(),
            high_order_a.squeeze(0).detach().numpy(),
            x,
        )
        # Cascaded second-order sections, used as the stable reference
        y_sos = signal.sosfilt(sos, x)
        
        # plotting
        plt.figure(figsize=(15, 30))
        plt.subplot(3, 1, 1)
        plt.plot(y_sos, 'g', label='SOS')
        plt.legend(loc='best')
    
        plt.subplot(3, 1, 2)
        plt.plot(y_tf, 'k', label='TF')
        plt.legend(loc='best')
    
        plt.subplot(3, 1, 3)
        plt.plot(y_torch_ba.squeeze(0).detach().numpy(), "r", label="torch")
        plt.legend(loc='best')
        plt.show()

This feature would allow users to apply high-order filtering (order > 6) within loss functions and training loops.

Alternatives

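Since no filtering based on a cascade of biquads is available, the current alternative is the one described above: expand the cascade into a single high-order filter and apply it with lfilter.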

Additional context

https://dsp.stackexchange.com/questions/31457/multiple-biquads-vs-higher-order-filtering
