
[TPU] XLA fails to fuse embedding lookup / array indexing #20899

Open
@neel04

Description

https://github.com/patrick-kidger/equinox/blob/7ee4ca944d75c33d1403122f7ccf141bc390a55e/equinox/nn/_embedding.py#L100

I'm using equinox, and internally eqx.nn.Embedding just indexes the weight array naively (see the link above). This is suboptimal: XLA is unable to fuse the vmap(embed_layer) calls and instead performs hundreds of thousands of dynamic slice updates over the weight array:

[Screenshot: profiler trace of the naive-indexing version]

Zooming in, the same block pattern repeats thousands of times:
[Screenshot: zoomed-in view of the repeated block]
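For readers without the linked source at hand, the pattern can be approximated in plain JAX. This is a minimal sketch, not equinox's code: `embed_naive` is a hypothetical stand-in for the embedding layer's call, and the sizes are arbitrary. Each call receives a single scalar token id and indexes `weight` with it; `jax.vmap` then maps that per-token lookup over a whole sequence.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes, chosen only for illustration.
VOCAB_SIZE, EMBED_DIM, SEQ_LEN = 1000, 64, 128

key_w, key_ids = jax.random.split(jax.random.PRNGKey(0))
weight = jax.random.normal(key_w, (VOCAB_SIZE, EMBED_DIM))
token_ids = jax.random.randint(key_ids, (SEQ_LEN,), 0, VOCAB_SIZE)


def embed_naive(weight, x):
    # Stand-in for the linked code: NumPy-style indexing
    # with a single scalar integer token id.
    return weight[x]


# Map the per-token lookup over the sequence, as vmap(embed_layer) does.
embed_seq = jax.jit(jax.vmap(embed_naive, in_axes=(None, 0)))
out = embed_seq(weight, token_ids)
print(out.shape)  # (128, 64)
```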

Instead, we can force XLA to fuse the lookup by switching to jnp.take:

```diff
- return self.weight[x]
+ return jnp.take(self.weight, x, axis=0)
```

[Screenshot: profiler trace of the jnp.take version]

This fixes the issue and yields a ~25% improvement in throughput.
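A quick way to check that the one-line change preserves behavior, and to get a rough feel for the speed difference, is sketched below. The function names and sizes are assumptions for illustration. The sketch checks that both variants return identical embeddings for in-range token ids (out-of-range ids are not exercised), then times each jitted function, calling `block_until_ready()` so that JAX's asynchronous dispatch does not skew the measurement. Timings on CPU or GPU will not necessarily reproduce the TPU figure reported here.

```python
import time

import jax
import jax.numpy as jnp

# Hypothetical sizes, for illustration only.
VOCAB_SIZE, EMBED_DIM, NUM_TOKENS = 32_000, 256, 8192

key_w, key_ids = jax.random.split(jax.random.PRNGKey(0))
weight = jax.random.normal(key_w, (VOCAB_SIZE, EMBED_DIM))
token_ids = jax.random.randint(key_ids, (NUM_TOKENS,), 0, VOCAB_SIZE)


def embed_naive(weight, x):
    # Original lookup: NumPy-style indexing.
    return weight[x]


def embed_take(weight, x):
    # Suggested replacement: explicit take along the vocabulary axis.
    return jnp.take(weight, x, axis=0)


naive_fn = jax.jit(jax.vmap(embed_naive, in_axes=(None, 0)))
take_fn = jax.jit(jax.vmap(embed_take, in_axes=(None, 0)))


def time_fn(fn, *args, iters=20):
    fn(*args).block_until_ready()  # compile and warm up outside the timed loop
    start = time.perf_counter()
    for _ in range(iters):
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / iters


naive_out = naive_fn(weight, token_ids)
take_out = take_fn(weight, token_ids)
# Both variants should return exactly the same rows for in-range ids.
assert jnp.array_equal(naive_out, take_out)

print(f"naive indexing: {time_fn(naive_fn, weight, token_ids) * 1e3:.3f} ms")
print(f"jnp.take:       {time_fn(take_fn, weight, token_ids) * 1e3:.3f} ms")
```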

Here's a simple Colab repro that records two TensorBoard traces. Note that the blocks for the naive lookup are very small, so you may have to zoom into the trace to see them.
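A sketch of how two such traces could be recorded is below; it is not the original notebook. It uses `jax.profiler.trace`, whose output TensorBoard's profile plugin can open, and writes one log directory per variant so each shows up as a separate run. The directory names, sizes, and the `variants` mapping are assumptions for illustration.

```python
import os
import tempfile

import jax
import jax.numpy as jnp

# Hypothetical sizes, for illustration only.
VOCAB_SIZE, EMBED_DIM, NUM_TOKENS = 32_000, 256, 8192

key_w, key_ids = jax.random.split(jax.random.PRNGKey(0))
weight = jax.random.normal(key_w, (VOCAB_SIZE, EMBED_DIM))
token_ids = jax.random.randint(key_ids, (NUM_TOKENS,), 0, VOCAB_SIZE)

# The two lookups being compared, each vmapped over token ids and jitted.
variants = {
    "naive": jax.jit(jax.vmap(lambda w, x: w[x], in_axes=(None, 0))),
    "take": jax.jit(jax.vmap(lambda w, x: jnp.take(w, x, axis=0), in_axes=(None, 0))),
}

log_root = tempfile.mkdtemp(prefix="embedding_traces_")
for name, fn in variants.items():
    fn(weight, token_ids).block_until_ready()  # compile before tracing
    log_dir = os.path.join(log_root, name)
    with jax.profiler.trace(log_dir):
        for _ in range(10):
            fn(weight, token_ids).block_until_ready()
    print(f"wrote {name} trace to {log_dir}")
```

The traces can then be viewed with, for example, `tensorboard --logdir <log_root>` once the profile plugin is installed.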

Why does XLA fail to fuse/parallelize naive indexing when it can do so for jnp.take?
Why does the jaxpr generated by jnp.take contain a pjit, while the one for naive indexing does not?

If the two ops are equivalent, shouldn't XLA be able to optimize them the same way? 🤔
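To inspect the jaxpr difference raised in the questions, both vmapped lookups can be traced with `jax.make_jaxpr` and printed side by side. This sketch only prints the jaxprs; it makes no claim about why they differ. Whether `jnp.take` appears wrapped in a `pjit` (or `jit`) equation, and which primitives the indexing path emits, depends on the JAX version, so no expected output is given. The tiny sizes are arbitrary and only keep the printout short.

```python
import jax
import jax.numpy as jnp

# Hypothetical small sizes so the printed jaxprs stay short.
weight = jnp.zeros((10, 4))
token_ids = jnp.arange(3)


def embed_naive(weight, x):
    return weight[x]


def embed_take(weight, x):
    return jnp.take(weight, x, axis=0)


jaxprs = {}
for name, fn in [("naive indexing", embed_naive), ("jnp.take", embed_take)]:
    # Trace the vmapped lookup, mirroring vmap(embed_layer) in the issue.
    jaxprs[name] = jax.make_jaxpr(jax.vmap(fn, in_axes=(None, 0)))(weight, token_ids)
    print(f"--- {name} ---")
    print(jaxprs[name])
```

Comparing the two printouts shows which primitives (for example `gather` or `dynamic_slice`) each path produces in the installed JAX version.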
