Description
I am implementing the tridiagonal matrix algorithm (TDMA) to solve many tridiagonal systems of the same shape in two sweeps (one forward and one backward pass).
The shape of each diagonal is something like (100_000, 100), and I vectorize over the leading axis, so this should be reasonably efficient.
In pure NumPy, I would do it like this:
import numpy as np

def tdma_naive(a, b, c, d):
    """
    Solves many tridiagonal matrix systems with diagonals a, b, c and RHS vectors d.

    Note: b and d are overwritten in place.
    """
    assert a.shape == b.shape and a.shape == c.shape and a.shape == d.shape
    n = a.shape[-1]

    # Forward sweep: eliminate the lower diagonal.
    for i in range(1, n):
        w = a[..., i] / b[..., i - 1]
        b[..., i] += -w * c[..., i - 1]
        d[..., i] += -w * d[..., i - 1]

    # Backward sweep: back-substitute from the last row up.
    out = np.empty_like(a)
    out[..., -1] = d[..., -1] / b[..., -1]
    for i in range(n - 2, -1, -1):
        out[..., i] = (d[..., i] - c[..., i] * out[..., i + 1]) / b[..., i]
    return out
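To compare backends against each other, it can help to check the residual of the original systems rather than trusting any single solver. The following is a minimal sketch under the same convention as above (system axis last, `a[..., 0]` and `c[..., -1]` unused); the helper name `tridiag_residual` and the test data are assumptions made for illustration.

```python
import numpy as np


def tridiag_residual(a, b, c, d, x):
    """Max absolute residual of the batched tridiagonal systems A x = d.

    a, b, c are the lower, main and upper diagonals; the system axis is last.
    a[..., 0] and c[..., -1] lie outside the matrix and are ignored.
    """
    r = b * x - d
    r[..., 1:] += a[..., 1:] * x[..., :-1]
    r[..., :-1] += c[..., :-1] * x[..., 1:]
    return np.abs(r).max()


# Build a small diagonally dominant batch with a known solution.
rng = np.random.default_rng(0)
shape = (5, 8)
a, c = rng.uniform(-1, 1, shape), rng.uniform(-1, 1, shape)
b = 4.0 + rng.uniform(0, 1, shape)
x_true = rng.normal(size=shape)

# Right-hand side d = A @ x_true, so x_true should have a near-zero residual.
d = b * x_true
d[..., 1:] += a[..., 1:] * x_true[..., :-1]
d[..., :-1] += c[..., :-1] * x_true[..., 1:]

print(tridiag_residual(a, b, c, d, x_true))
```

Since `tdma_naive` overwrites `b` and `d`, the solver should be given copies (for example `b.copy()` and `d.copy()`) so that the residual is computed against the unmodified system.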
The JAX implementation looks like this:
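import jax
import jax.numpy as jnp

def tdma_jax_kernel(a, b, c, d):
    def compute_primes(last_primes, x):
        # Forward sweep: one step of the c' and d' recurrence.
        last_cp, last_dp = last_primes
        a, b, c, d = x
        denom = 1. / (b - a * last_cp)
        cp = c * denom
        dp = (d - a * last_dp) * denom
        new_primes = (cp, dp)
        return new_primes, new_primes

    # scan iterates over the leading axis, so move the system axis to the front.
    diags = (a.T, b.T, c.T, d.T)
    # The carry must have the shape of one slice along the scanned axis.
    init = jnp.zeros(a.T.shape[1:], dtype=a.dtype)
    _, (cp, dp) = jax.lax.scan(compute_primes, (init, init), diags)

    def backsubstitution(last_x, x):
        # Backward sweep: x_i = d'_i - c'_i * x_{i+1}.
        cp, dp = x
        new_x = dp - cp * last_x
        return new_x, new_x

    _, sol = jax.lax.scan(backsubstitution, init, (cp[::-1], dp[::-1]))
    return sol[::-1].T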
I implemented the algorithm in a handful of backends (including a sloppily written CUDA kernel). You can see the results in this Gist:
https://gist.github.com/dionhaefner/a97ef80b77e02b36e4b248bb97541161
The executive summary is that JAX is 2.5x slower than Numba on CPU, and 3x slower than my amateurish CUDA kernel on GPU (where it is on par with Numba).
If I eliminate the transposes from the JAX implementation and transpose the inputs beforehand, the implementation runs 2x faster on GPU, so it would be nice if scan supported scanning over arbitrary axes.
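A minimal sketch of the pre-transposed variant described above, assuming the caller already stores every diagonal with the system axis first, i.e. with shape (n, batch). The function name `tdma_jax_pretransposed` is hypothetical, and the `reverse=True` argument of `jax.lax.scan` is used here in place of the `[::-1]` slicing; this is an illustration of the layout, not a benchmarked implementation.

```python
import jax
import jax.numpy as jnp


@jax.jit
def tdma_jax_pretransposed(a, b, c, d):
    """Solve batched tridiagonal systems; the system axis is axis 0.

    Hypothetical variant: all inputs have shape (n, ...batch), so no
    transposes are needed inside the kernel.
    """
    def compute_primes(last_primes, x):
        last_cp, last_dp = last_primes
        a_i, b_i, c_i, d_i = x
        denom = 1.0 / (b_i - a_i * last_cp)
        cp = c_i * denom
        dp = (d_i - a_i * last_dp) * denom
        return (cp, dp), (cp, dp)

    # The carry has the shape of one row of the (n, ...batch) inputs.
    init = jnp.zeros_like(a[0])
    _, (cp, dp) = jax.lax.scan(compute_primes, (init, init), (a, b, c, d))

    def backsubstitution(last_x, x):
        cp_i, dp_i = x
        new_x = dp_i - cp_i * last_x
        return new_x, new_x

    # reverse=True walks the rows from last to first but stacks the outputs
    # in the original order, so no [::-1] slicing is required.
    _, sol = jax.lax.scan(backsubstitution, init, (cp, dp), reverse=True)
    return sol


# Tiny usage example: 3 systems of size 4, stored as (n, batch).
n, batch = 4, 3
a = jnp.full((n, batch), -1.0)
b = jnp.full((n, batch), 4.0)
c = jnp.full((n, batch), -1.0)
d = jnp.ones((n, batch))
x = tdma_jax_pretransposed(a, b, c, d)
print(x.shape)  # (4, 3)
```

The result comes back in the same (n, batch) layout as the inputs, so any transpose back to (batch, n) is left to the caller.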
Is this performance expected, and is there anything else I can do to make the JAX implementation more efficient?