Description
Array API compatible alternatives to scipy.signal.lfilter / sosfilt
Context
Several ezmsg-sigproc processors (notably ButterworthFilterTransformer and everything that depends on it, including RMSBandPowerTransformer and SquareLawBandPowerTransformer) are blocked from Array API portability by their dependence on scipy.signal.lfilter / scipy.signal.sosfilt. These functions only work with numpy arrays, which means any upstream MLX, CuPy, PyTorch, or JAX arrays get implicitly converted to numpy at the filter step, negating GPU acceleration for the entire pipeline.
This issue tracks the landscape of alternatives and a path forward.
The fundamental problem
IIR filters are inherently sequential: y[n] depends on y[n-1]. The Array API standard has no scan, fold, or sequential control flow primitive, so a pure Array API implementation using a Python loop would be correct but would negate any GPU benefit due to per-sample kernel launch overhead.
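To make the dependency concrete, here is a minimal NumPy sketch of one biquad section in the textbook direct form II transposed structure. The name `biquad_loop` is hypothetical and `a[0]` is assumed to be 1. Every iteration reads the state written by the previous one, so the loop cannot be parallelized as written. A naive `xp` port running eagerly on a GPU would also dispatch several tiny operations per sample.

```python
import numpy as np


def biquad_loop(b, a, x):
    """Sequential direct form II transposed biquad (sketch; assumes a[0] == 1)."""
    b0, b1, b2 = b
    _, a1, a2 = a
    x = np.asarray(x, dtype=float)
    y = np.empty_like(x)
    s1 = s2 = 0.0  # filter state carried from one sample to the next
    for n, xn in enumerate(x):
        yn = b0 * xn + s1  # output needs the state left behind by sample n-1
        s1 = b1 * xn - a1 * yn + s2
        s2 = b2 * xn - a2 * yn
        y[n] = yn
    return y


# One-pole example (b = [1, 0, 0], a = [1, -0.5, 0]) driven by an impulse
print(biquad_loop((1.0, 0.0, 0.0), (1.0, -0.5, 0.0), [1.0, 0.0, 0.0, 0.0]))
```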
Current landscape
| Approach | GPU | IIR | Portable | Status |
|---|---|---|---|---|
| scipy.signal.sosfilt | No | Yes | numpy only | Production |
| SciPy with SCIPY_ARRAY_API=1 | CuPy only | Yes | Experimental | Merged Dec 2024; JAX limited to CPU, no JIT |
| cupyx.scipy.signal.sosfilt | CUDA | Yes | CuPy only | Production; uses novel parallel chunked IIR algorithm |
| torchaudio.functional.lfilter | CUDA | Yes | PyTorch only | Production; differentiable |
| JAX jax.lax.scan workaround | Yes (sequential) | Yes | JAX only | PR #19196 (lfilter/sosfilt) still open |
| MLX | Metal (FIR only via mx.convolve) | No | MLX only | No scan primitive; feature request closed |
MLX is the worst off
MLX has no lfilter, no scan, and no associative_scan. The maintainers closed the scan feature request, suggesting plain loops or vmap instead. A community parallel scan for Mamba exists but the author notes it is "so slow that it is sometimes actually harmful."
Possible approaches
1. Backend-specific dispatch (pragmatic, near-term)
In FilterTransformer._process, check the array namespace and dispatch:
- numpy -> scipy.signal.sosfilt (current behavior)
- CuPy -> cupyx.scipy.signal.sosfilt
- PyTorch -> torchaudio.functional.lfilter
- JAX -> jax.lax.scan-based implementation
This doesn't help MLX but unblocks CuPy/PyTorch/JAX users immediately.
2. Associative scan IIR filter (elegant, medium-term)
The key mathematical insight (Raph Levien, 2019): an IIR filter's state update is an affine map, and composition of affine maps is associative:
state update per sample: (A, b), where y_out = A @ y_in + b
composition, applying (A1, b1) first and then (A2, b2): (A1, b1) . (A2, b2) = (A2 @ A1, A2 @ b1 + b2)
Because composition is associative, the entire IIR recurrence can be computed via parallel prefix scan -- O(log N) depth instead of O(N) sequential steps. For second-order sections, each biquad has a 2x2 state matrix.
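For a single biquad, the (A, b) pair can be read off the direct form II transposed recurrence (assuming a_0 = 1 and state s[n] = (s_1[n], s_2[n])^T). Substituting the output equation into the two state updates gives a constant 2x2 matrix and an input-dependent offset:

$$
\begin{aligned}
y[n] &= b_0\,x[n] + s_1[n-1] \\
s_1[n] &= b_1\,x[n] - a_1\,y[n] + s_2[n-1] \\
s_2[n] &= b_2\,x[n] - a_2\,y[n] \\
\Longrightarrow\quad s[n] &= \underbrace{\begin{pmatrix} -a_1 & 1 \\ -a_2 & 0 \end{pmatrix}}_{A}\, s[n-1] + \underbrace{\begin{pmatrix} b_1 - a_1 b_0 \\ b_2 - a_2 b_0 \end{pmatrix} x[n]}_{b_n}
\end{aligned}
$$

Because the filter is time-invariant, A is identical for every sample and only b_n depends on the input. Once the scan has produced every state s[n], the outputs y[n] = b_0 x[n] + s_1[n-1] can be computed fully in parallel.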
This is the same math behind Mamba/S4 state-space models, validated at scale on GPUs. Building blocks already exist:
- jax.lax.associative_scan (production quality)
- torch._higher_order_ops.associative_scan (emerging, works with torch.compile)
- keras.ops.associative_scan (multi-backend: JAX, TensorFlow, PyTorch)
- accelerated-scan -- 761x faster than sequential for matrix chains
- torch_parallel_scan -- parallel prefix scan for PyTorch
An associative_scan-based sosfilt would be portable to any backend that provides the primitive. The main blocker is MLX, which has no scan primitive at all.
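As a backend-neutral illustration, here is a minimal NumPy sketch of an associative-scan sosfilt. A hand-rolled Hillis-Steele inclusive scan stands in for a backend's associative_scan primitive. All function names are hypothetical, `a[0] == 1` and zero initial state are assumed, and the biquad uses the direct form II transposed state-space matrix A = [[-a1, 1], [-a2, 0]]. The scan performs O(log N) vectorized steps. Hillis-Steele does O(N log N) total work; a work-efficient Blelloch scan would bring that down to O(N).

```python
import numpy as np


def _combine(A_prev, b_prev, A_next, b_next):
    """Compose affine maps: apply (A_prev, b_prev) first, then (A_next, b_next)."""
    A = A_next @ A_prev  # batched 2x2 matmul
    b = (A_next @ b_prev[..., None])[..., 0] + b_next
    return A, b


def affine_prefix_scan(A, b):
    """Inclusive Hillis-Steele scan over the leading (time) axis."""
    A = np.array(A, dtype=float)  # copy so the inputs are not modified
    b = np.array(b, dtype=float)
    offset = 1
    while offset < A.shape[0]:
        # element i absorbs element i - offset; right-hand side is built before assignment
        A_new, b_new = _combine(A[:-offset], b[:-offset], A[offset:], b[offset:])
        A[offset:], b[offset:] = A_new, b_new
        offset *= 2
    return A, b


def biquad_scan(b_coef, a_coef, x):
    """One biquad via prefix scan over per-sample affine maps (zero initial state)."""
    b0, b1, b2 = b_coef
    _, a1, a2 = a_coef
    x = np.asarray(x, dtype=float)
    A = np.broadcast_to(np.array([[-a1, 1.0], [-a2, 0.0]]), (x.shape[0], 2, 2))
    B = np.array([b1 - a1 * b0, b2 - a2 * b0])
    _, states = affine_prefix_scan(A, x[:, None] * B)  # states[n] = s[n]
    s1_prev = np.concatenate(([0.0], states[:-1, 0]))  # s_1[n-1], with s_1[-1] = 0
    return b0 * x + s1_prev


def sosfilt_scan(sos, x):
    """Cascade of second-order sections, each filtered with biquad_scan."""
    y = np.asarray(x, dtype=float)
    for section in np.asarray(sos, dtype=float):
        y = biquad_scan(section[:3], section[3:], y)
    return y
```

With JAX, `affine_prefix_scan` could in principle be swapped for `jax.lax.associative_scan` with a combine function equivalent to `_combine`, which receives the earlier elements as its first argument.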
3. FIR approximation (quick workaround)
Replace the IIR Butterworth with a long FIR filter (its truncated impulse response), then apply it by convolution, which is Array API compatible. Convolution maps to mx.convolve on MLX and to FFT-based convolution through xp.fft (the Array API's optional fft extension) everywhere else.
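A minimal sketch of the FIR workaround follows, with NumPy standing in for `xp`. The function names are hypothetical. The impulse response is computed once with a short sequential loop, which is acceptable because it happens at design time rather than per sample. After that, filtering is a single FFT-based linear convolution using only `xp.fft.rfft`/`irfft`.

```python
import numpy as np


def biquad_impulse_response(b, a, n_taps):
    """First n_taps samples of a DF2T biquad's impulse response (assumes a[0] == 1)."""
    b0, b1, b2 = b
    _, a1, a2 = a
    h = np.empty(n_taps)
    s1 = s2 = 0.0
    for n in range(n_taps):
        xn = 1.0 if n == 0 else 0.0  # unit impulse
        yn = b0 * xn + s1
        s1 = b1 * xn - a1 * yn + s2
        s2 = b2 * xn - a2 * yn
        h[n] = yn
    return h


def fir_filter_fft(h, x, xp=np):
    """Causal linear convolution of x with taps h via FFT, truncated to len(x)."""
    n = x.shape[-1]
    nfft = n + h.shape[-1] - 1  # zero-pad so circular convolution equals linear convolution
    spectrum = xp.fft.rfft(x, n=nfft) * xp.fft.rfft(h, n=nfft)
    return xp.fft.irfft(spectrum, n=nfft)[..., :n]


h = biquad_impulse_response((0.2, 0.4, 0.2), (1.0, -0.5, 0.1), 200)
y = fir_filter_fft(h, np.random.default_rng(0).standard_normal(1000))
```

This one-shot version treats each call independently. A streaming processor would also need overlap-add (carrying the last `len(h) - 1` output samples into the next chunk) to match the stateful IIR across chunk boundaries.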
Tradeoffs:
- Much longer kernel required (hundreds of taps vs. a few SOS coefficients)
- Loses precision for filters with long memory
- Not suitable for all filter types
- But it works on every backend today, including MLX
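One rough way to size the FIR kernel: assume the slowest-decaying pole has radius r < 1, and ignore the constant and polynomial factors that repeated poles introduce. The impulse response envelope then decays like r^n, so keeping the truncated tail below a relative tolerance ε needs roughly

$$
|h[n]| \lesssim C\,r^{n} \quad\Longrightarrow\quad N \approx \frac{\ln \varepsilon}{\ln r}
$$

With an illustrative r = 0.999 and ε = 10^{-6}, N ≈ (-13.82) / (-0.0010005) ≈ 13,800 taps, versus six coefficients per SOS section. Poles this close to the unit circle arise with low cutoff frequencies relative to the sample rate, which is where the "long memory" tradeoff hits hardest.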
Recommendation
Short-term: backend dispatch for CuPy/PyTorch/JAX users, keeping scipy for numpy.
Medium-term: associative scan implementation of sosfilt, portable to any backend with a scan primitive. Unlike the FIR workaround, it computes the exact recurrence (up to floating-point rounding order), and the technique is proven at scale by the Mamba/S4 ecosystem.
For MLX specifically: either lobby Apple to add associative_scan to MLX, or use the FIR approximation as a stopgap.
Related work
- SciPy PR #21713 -- Array API support for lfilter et al. (merged Dec 2024)
- JAX PR #19196 -- Add lfilter/sosfilt (open, not merged)
- MLX scan feature request (closed)
- CuPy parallel IIR algorithm
- Raph Levien -- Parallel IIR filters
- Mamba: Linear-Time Sequence Modeling with Selective State Spaces
- TorchFX -- GPU Audio DSP (DAFx 2025)