Array API compatibility for IIR filter #129

@cboulay

Array API compatible alternatives to scipy.signal.lfilter / sosfilt

Context

Several ezmsg-sigproc processors (notably ButterworthFilterTransformer and everything that depends on it, including RMSBandPowerTransformer and SquareLawBandPowerTransformer) are blocked from Array API portability by their dependence on scipy.signal.lfilter / scipy.signal.sosfilt. By default these functions work only with NumPy arrays, so any upstream MLX, CuPy, PyTorch, or JAX arrays get implicitly converted to NumPy at the filter step, negating GPU acceleration for the entire pipeline.

This issue tracks the landscape of alternatives and a path forward.

The fundamental problem

IIR filters are inherently sequential: y[n] depends on y[n-1]. The Array API standard has no general scan, fold, or sequential control-flow primitive, so a pure Array API implementation has to fall back to a Python loop. Such a loop is correct, but the per-sample kernel launch overhead negates any GPU benefit.
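To make the dependency concrete, here is a minimal sketch of a single second-order section written as a plain per-sample loop (direct form II transposed). The function name `biquad_loop` is hypothetical, and the sketch assumes a 1-D float input and normalized coefficients (a[0] == 1). On a GPU backend every iteration of this loop would turn into a handful of tiny kernel launches.

```python
import numpy as np


def biquad_loop(b, a, x):
    """Filter 1-D x with one biquad, one sample at a time.

    Hypothetical helper for illustration; assumes a[0] == 1.
    """
    b0, b1, b2 = b
    _, a1, a2 = a
    y = np.zeros(len(x), dtype=float)
    z1 = z2 = 0.0  # direct form II transposed state
    for n in range(len(x)):
        y[n] = b0 * x[n] + z1              # output needs the previous state
        z1 = b1 * x[n] - a1 * y[n] + z2    # state update needs y[n]
        z2 = b2 * x[n] - a2 * y[n]
    return y
```

Each iteration reads the state written by the previous one, which is why the loop cannot simply be vectorized over n.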

Current landscape

| Approach                      | GPU                              | IIR | Portable     | Status                                                 |
|-------------------------------|----------------------------------|-----|--------------|--------------------------------------------------------|
| scipy.signal.sosfilt          | No                               | Yes | numpy only   | Production                                             |
| SciPy with SCIPY_ARRAY_API=1  | CuPy only                        | Yes | Experimental | Merged Dec 2024; JAX limited to CPU, no JIT            |
| cupyx.scipy.signal.sosfilt    | CUDA                             | Yes | CuPy only    | Production; uses novel parallel chunked IIR algorithm  |
| torchaudio.functional.lfilter | CUDA                             | Yes | PyTorch only | Production; differentiable                             |
| JAX jax.lax.scan workaround   | Yes (sequential)                 | Yes | JAX only     | PR #19196 (lfilter/sosfilt) still open                 |
| MLX                           | Metal (FIR only via mx.convolve) | No  | MLX only     | No scan primitive; feature request closed              |

MLX is the worst off

MLX has no lfilter, no scan, and no associative_scan. The maintainers closed the scan feature request, suggesting plain loops or vmap instead. A community parallel scan for Mamba exists but the author notes it is "so slow that it is sometimes actually harmful."

Possible approaches

1. Backend-specific dispatch (pragmatic, near-term)

In FilterTransformer._process, check the array namespace and dispatch:

  • numpy -> scipy.signal.sosfilt (current behavior)
  • CuPy -> cupyx.scipy.signal.sosfilt
  • PyTorch -> torchaudio.functional.lfilter
  • JAX -> jax.lax.scan-based implementation

This doesn't help MLX but unblocks CuPy/PyTorch/JAX users immediately.
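A rough sketch of what that dispatch could look like. Everything here is an assumption rather than the existing ezmsg-sigproc code: the helper names `backend_name` and `sosfilt_dispatch` are hypothetical, the backend is detected by a simple heuristic on the array type's module name (a namespace helper such as array_api_compat could be used instead), filter state carried across chunks (zi) is omitted, and the JAX branch is left as a placeholder.

```python
import numpy as np


def backend_name(x):
    """Root module of the array's type, e.g. 'numpy', 'cupy', 'torch' (heuristic)."""
    return type(x).__module__.split(".")[0]


def sosfilt_dispatch(sos, x):
    """Filter x along its last axis with the cascade of sections in sos.

    Hypothetical helper; backend libraries are imported lazily so that
    only the one actually in use needs to be installed.
    """
    sos = np.asarray(sos, dtype=float)  # coefficients are designed on the CPU
    name = backend_name(x)
    if name == "numpy":
        from scipy.signal import sosfilt
        return sosfilt(sos, x, axis=-1)
    if name == "cupy":
        import cupy as cp
        from cupyx.scipy.signal import sosfilt
        return sosfilt(cp.asarray(sos), x, axis=-1)
    if name == "torch":
        import torch
        from torchaudio.functional import lfilter
        y = x
        for section in sos:  # lfilter takes (b, a), so cascade the sections
            b = torch.as_tensor(section[:3], dtype=x.dtype, device=x.device)
            a = torch.as_tensor(section[3:], dtype=x.dtype, device=x.device)
            # clamp=False: lfilter clips the output to [-1, 1] by default
            y = lfilter(y, a_coeffs=a, b_coeffs=b, clamp=False)
        return y
    # JAX would need a jax.lax.scan-based implementation here; MLX has no option.
    raise NotImplementedError(f"no IIR backend for arrays from {name!r}")
```

The torch branch assumes a floating-point tensor with time on the last axis. A real processor would also need to thread filter state through each branch to support streaming.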

2. Associative scan IIR filter (elegant, medium-term)

The key mathematical insight (Raph Levien, 2019): an IIR filter's state update is an affine map, and affine maps compose associatively:

state update per sample: (A, b) where y_out = A @ y_in + b
composition (apply (A1, b1) first, then (A2, b2)): (A1, b1) . (A2, b2) = (A2 @ A1, A2 @ b1 + b2)

Because composition is associative, the entire IIR recurrence can be computed via parallel prefix scan -- O(log N) depth instead of O(N) sequential steps. For second-order sections, each biquad has a 2x2 state matrix.
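As a sanity check on the idea, here is a NumPy-only sketch of a single biquad computed with a prefix scan over affine maps. The names `combine`, `affine_scan`, and `biquad_scan` are hypothetical, a[0] is assumed to be 1, and the scan is a simple Hillis-Steele scheme (log2(N) vectorized passes, but O(N log N) total work), not a tuned implementation. The state is the direct form II transposed pair [z1, z2], and `combine(first, second)` means "apply first, then second".

```python
import numpy as np


def combine(first, second):
    """Compose batched affine maps: apply `first`, then `second`."""
    A1, b1 = first
    A2, b2 = second
    return A2 @ A1, (A2 @ b1[..., None])[..., 0] + b2


def affine_scan(A, b):
    """Inclusive prefix scan over N affine maps (Hillis-Steele)."""
    A, b = A.copy(), b.copy()
    d = 1
    while d < len(b):
        # Element i absorbs the prefix ending at i - d, which is applied first.
        A_new, b_new = combine((A[:-d], b[:-d]), (A[d:], b[d:]))
        A[d:], b[d:] = A_new, b_new
        d *= 2
    return A, b


def biquad_scan(b, a, x):
    """One biquad on 1-D x via the scan (hypothetical sketch)."""
    b0, b1, b2 = b
    _, a1, a2 = a
    A = np.array([[-a1, 1.0], [-a2, 0.0]])         # state transition matrix
    B = np.array([b1 - a1 * b0, b2 - a2 * b0])     # input-to-state weights
    A_n = np.broadcast_to(A, (len(x), 2, 2))
    u_n = x[:, None] * B                           # per-sample offsets
    _, z = affine_scan(A_n, u_n)                   # z[n] = state after sample n
    z1_prev = np.concatenate(([0.0], z[:-1, 0]))   # zero initial state
    return b0 * x + z1_prev
```

With a real associative_scan primitive the `affine_scan` function would be replaced by a call to the backend's scan with `combine` as the operator. Note that the scan forms powers of A, so numerical behavior for long signals or poles near the unit circle would need checking before relying on this.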

This is the same math behind Mamba/S4 state-space models, validated at scale on GPUs. Building blocks already exist:

  • jax.lax.associative_scan (production quality)
  • torch._higher_order_ops.associative_scan (emerging, works with torch.compile)
  • keras.ops.associative_scan (multi-backend: JAX, TensorFlow, PyTorch)
  • accelerated-scan -- 761x faster than sequential for matrix chains
  • torch_parallel_scan -- parallel prefix scan for PyTorch

An associative_scan-based sosfilt would be portable to any backend that provides the primitive. The main blocker is MLX, which has no scan primitive at all.

3. FIR approximation (quick workaround)

Replace the IIR Butterworth with a long FIR filter (its truncated impulse response), then apply it by convolution, which is Array API compatible. Convolution maps to mx.convolve on MLX and to xp.fft operations everywhere else.

Tradeoffs:

  • Much longer kernel required (hundreds of taps vs. a few SOS coefficients)
  • Loses precision for filters with long memory
  • Not suitable for all filter types
  • But it works on every backend today, including MLX
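A minimal sketch of this workaround, using NumPy as a stand-in for any namespace that provides the fft extension. The helper names `truncated_impulse_response` and `fft_convolve` are hypothetical, a[0] is assumed to be 1, and the number of taps is something the caller has to choose from how quickly the impulse response decays. The kernel is computed once on the CPU at design time, so only the convolution runs on the target backend.

```python
import numpy as np


def truncated_impulse_response(b, a, n_taps):
    """First n_taps samples of the impulse response of b/a (a[0] assumed 1)."""
    h = np.zeros(n_taps)
    for n in range(n_taps):
        acc = b[n] if n < len(b) else 0.0
        for k in range(1, len(a)):
            if n - k >= 0:
                acc -= a[k] * h[n - k]
        h[n] = acc
    return h


def fft_convolve(x, h):
    """Causal linear convolution of 1-D x with h via FFT, cut to len(x)."""
    n = len(x) + len(h) - 1  # zero-pad so circular convolution equals linear
    y = np.fft.irfft(np.fft.rfft(x, n=n) * np.fft.rfft(h, n=n), n=n)
    return y[: len(x)]
```

For example, `h = truncated_impulse_response((0.2, 0.4, 0.2), (1.0, -0.5, 0.25), 64)` followed by `fft_convolve(x, h)` approximates the corresponding biquad. For streaming use, the convolution would additionally need overlap-add or overlap-save handling between chunks, which is not shown here.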

Recommendation

Short-term: backend dispatch for CuPy/PyTorch/JAX users, keeping scipy for numpy.

Medium-term: associative scan implementation of sosfilt, portable to any backend with a scan primitive. This is the mathematically correct generalization, and the same approach has been validated at scale by the Mamba/S4 ecosystem.

For MLX specifically: either lobby Apple to add associative_scan to MLX, or use the FIR approximation as a stopgap.

Related work
