
Jax scans are slower than expected #2491

Open
@dionhaefner

Description

I am implementing the tridiagonal matrix algorithm (TDMA) to solve many tridiagonal systems of the same shape in two sweeps (one forward and one backward pass).

The shape of each diagonal is something like (100_000, 100), and I vectorize over the leading axis, so this should be reasonably efficient.

In pure NumPy, I would do it like this:

import numpy as np


def tdma_naive(a, b, c, d):
    """
    Solves many tridiagonal matrix systems with diagonals a, b, c and RHS vectors d.

    Note: b and d are modified in place by the forward sweep.
    """
    assert a.shape == b.shape and a.shape == c.shape and a.shape == d.shape

    n = a.shape[-1]

    # Forward sweep: eliminate the sub-diagonal a.
    for i in range(1, n):
        w = a[..., i] / b[..., i - 1]
        b[..., i] += -w * c[..., i - 1]
        d[..., i] += -w * d[..., i - 1]

    out = np.empty_like(a)
    out[..., -1] = d[..., -1] / b[..., -1]

    # Backward substitution.
    for i in range(n - 2, -1, -1):
        out[..., i] = (d[..., i] - c[..., i] * out[..., i + 1]) / b[..., i]

    return out
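
As a quick sanity check, the result can be compared against a dense solve (a minimal sketch with made-up, diagonally dominant test data; tdma_naive overwrites b and d, so copies are passed):

import numpy as np

np.random.seed(0)

a = np.random.rand(4, 8)       # sub-diagonal (first entry of each system unused)
b = np.random.rand(4, 8) + 4.  # main diagonal, kept dominant for stability
c = np.random.rand(4, 8)       # super-diagonal (last entry of each system unused)
d = np.random.rand(4, 8)       # right-hand sides

x = tdma_naive(a.copy(), b.copy(), c.copy(), d.copy())

# Rebuild each dense matrix and check the solution.
for k in range(a.shape[0]):
    A = np.diag(b[k]) + np.diag(a[k, 1:], -1) + np.diag(c[k, :-1], 1)
    assert np.allclose(A @ x[k], d[k])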

The JAX implementation looks like this:

import jax
import jax.numpy as jnp


def tdma_jax_kernel(a, b, c, d):
    def compute_primes(last_primes, x):
        last_cp, last_dp = last_primes
        a, b, c, d = x

        denom = 1. / (b - a * last_cp)
        cp = c * denom
        dp = (d - a * last_dp) * denom

        new_primes = (cp, dp)
        return new_primes, new_primes

    # scan only iterates over the leading axis, so transpose to bring the
    # system dimension to the front (and transpose back at the end).
    diags = (a.T, b.T, c.T, d.T)
    init = jnp.zeros((a.shape[1], a.shape[0]))
    _, (cp, dp) = jax.lax.scan(compute_primes, (init, init), diags)

    def backsubstitution(last_x, x):
        cp, dp = x
        new_x = dp - cp * last_x
        return new_x, new_x

    _, sol = jax.lax.scan(backsubstitution, init, (cp[::-1], dp[::-1]))

    return sol[::-1].T
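
For reference, calling and timing the jitted kernel might look like this (a minimal sketch; the shapes are purely illustrative, with the system dimension as the last axis, and block_until_ready is used so JAX's asynchronous dispatch doesn't skew the timing):

import timeit

import jax
import jax.numpy as jnp
import numpy as np

shape = (100, 100, 100)  # illustrative only: 100 * 100 systems of length 100
rng = np.random.RandomState(0)

a = jnp.asarray(rng.rand(*shape))
b = jnp.asarray(rng.rand(*shape) + 4.)  # keep the systems diagonally dominant
c = jnp.asarray(rng.rand(*shape))
d = jnp.asarray(rng.rand(*shape))

tdma_jit = jax.jit(tdma_jax_kernel)

# Warm up once so compilation time is excluded from the measurement.
tdma_jit(a, b, c, d).block_until_ready()

runtime = timeit.timeit(
    lambda: tdma_jit(a, b, c, d).block_until_ready(), number=10
)
print(f"10 runs took {runtime:.3f} s")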

I implemented the algorithm in a handful of backends (including a sloppily written CUDA kernel). You can see the results in this Gist:

https://gist.github.com/dionhaefner/a97ef80b77e02b36e4b248bb97541161

The executive summary is that JAX is 2.5x slower than Numba on CPU, and 3x slower than my amateurish CUDA kernel on GPU (though on par with Numba there).

If I eliminate the transposes from the JAX implementation and transpose the inputs beforehand (see the sketch below), the implementation gains a factor of 2 in performance on GPU, so it would be nice if scan supported scanning over arbitrary axes.
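
A sketch of what that variant could look like (assuming the inputs are transposed once outside the kernel so the system dimension is the leading axis; the solution is returned in the same layout, and the _pretransposed name is just for illustration):

def tdma_jax_kernel_pretransposed(a, b, c, d):
    # Same algorithm as above, but the inputs already have the system
    # dimension first, so both scans consume them without any transposes.
    def compute_primes(last_primes, x):
        last_cp, last_dp = last_primes
        a, b, c, d = x
        denom = 1. / (b - a * last_cp)
        cp = c * denom
        dp = (d - a * last_dp) * denom
        return (cp, dp), (cp, dp)

    init = jnp.zeros(a.shape[1:])
    _, (cp, dp) = jax.lax.scan(compute_primes, (init, init), (a, b, c, d))

    def backsubstitution(last_x, x):
        cp, dp = x
        new_x = dp - cp * last_x
        return new_x, new_x

    _, sol = jax.lax.scan(backsubstitution, init, (cp[::-1], dp[::-1]))
    return sol[::-1]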

Is this behavior expected, and is there anything else I can do to make the JAX implementation more efficient?

Metadata

Labels

NVIDIA GPU: Issues specific to NVIDIA GPUs
P1 (soon): Assignee is working on this now, among other tasks. (Assignee required)
performance: make things lean and fast
