[TPU] XLA fails to fuse embedding lookup / array indexing #20899
Description
I'm using equinox, and internally `eqx.nn.Embedding` just indexes the weight naively (as shown in the link above). However, this is suboptimal: XLA is unable to fuse `vmap(embed_layer)` calls and instead performs hundreds of thousands of dynamic slice updates over the `weight` array:
Zooming in, we see this block pattern repeated thousands of times:
Instead, we can force XLA to fuse the lookup by replacing the indexing with `jnp.take`:

```diff
- return self.weight[x]
+ return jnp.take(self.weight, x, axis=0)
```
This fixes the issue and yields a ~25% improvement in throughput.
Here's a simple Colab repro that records two TensorBoard traces. Note that the blocks for the naive lookup are very small, so you may have to zoom into the trace to see them.
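The Colab notebook itself isn't reproduced here. As a rough stand-in, the sketch below times `jit(vmap(...))` of both lookups using plain JAX rather than equinox. The table sizes and the `naive_embed`, `take_embed`, and `bench` helpers are hypothetical, chosen only for illustration, and wall-clock numbers on CPU or GPU will not necessarily reflect the TPU behaviour described above.

```python
import time

import jax
import jax.numpy as jnp

# Hypothetical sizes for illustration; not the model from the issue.
VOCAB, DIM, SEQ = 8_192, 256, 2_048


def naive_embed(weight, x):
    # Mirrors `self.weight[x]` in eqx.nn.Embedding for a single token id.
    return weight[x]


def take_embed(weight, x):
    # Mirrors the patched `jnp.take(self.weight, x, axis=0)`.
    return jnp.take(weight, x, axis=0)


def bench(fn, weight, tokens, iters=20):
    # vmap over the token axis (like vmap(embed_layer)), then jit the result.
    f = jax.jit(jax.vmap(fn, in_axes=(None, 0)))
    f(weight, tokens).block_until_ready()  # compile and warm up
    start = time.perf_counter()
    for _ in range(iters):
        out = f(weight, tokens)
    out.block_until_ready()  # wait for asynchronous dispatch to finish
    return (time.perf_counter() - start) / iters


key_w, key_t = jax.random.split(jax.random.PRNGKey(0))
weight = jax.random.normal(key_w, (VOCAB, DIM))
tokens = jax.random.randint(key_t, (SEQ,), 0, VOCAB)

t_naive = bench(naive_embed, weight, tokens)
t_take = bench(take_embed, weight, tokens)
print(f"naive indexing: {t_naive * 1e3:.3f} ms/iter")
print(f"jnp.take:       {t_take * 1e3:.3f} ms/iter")
```

To get a TensorBoard trace closer to the original repro, the timed loop could be wrapped in `with jax.profiler.trace(log_dir):`, where `log_dir` is any writable directory.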
Why does XLA fail to fuse/parallelize naive indexing, while it handles `jnp.take` fine?
Why does the jaxpr generated by `jnp.take` contain a `pjit`, while the one for naive indexing does not?
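One way to compare the two programs is to print their jaxprs side by side. The sketch below does this for a small, hypothetical 8×4 table. A likely source of the nested `pjit` equation is that `jnp.take` is itself wrapped in `jax.jit` inside JAX's implementation; this is an assumption about JAX internals, so it is worth checking against the installed version (newer releases may print the equation as `jit`).

```python
import jax
import jax.numpy as jnp


def naive_embed(weight, x):
    # `self.weight[x]`: NumPy-style indexing on the embedding table.
    return weight[x]


def take_embed(weight, x):
    # `jnp.take(self.weight, x, axis=0)`: the patched lookup.
    return jnp.take(weight, x, axis=0)


# Hypothetical table: 8 token ids x 4 features, plus a batch of ids.
weight = jnp.arange(32, dtype=jnp.float32).reshape(8, 4)
tokens = jnp.array([1, 5, 2, 7], dtype=jnp.int32)

# Batch over tokens only; the table is shared (in_axes=None).
batched_naive = jax.vmap(naive_embed, in_axes=(None, 0))
batched_take = jax.vmap(take_embed, in_axes=(None, 0))

# The issue reports a pjit equation only in the jnp.take version.
print(jax.make_jaxpr(batched_naive)(weight, tokens))
print(jax.make_jaxpr(batched_take)(weight, tokens))
```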
If those ops are equivalent, surely XLA should be able to optimize both of them? 🤔
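For in-range token ids the two spellings return the same rows, but they are not fully equivalent. JAX's NumPy-style indexing clamps out-of-bounds retrievals, whereas `jnp.take` defaults to `mode="fill"` (NaN for floating-point outputs). The minimal sketch below, using a hypothetical 8×4 table and an out-of-range id of 10, shows this difference. Whether it actually explains the different XLA programs is not established here.

```python
import jax
import jax.numpy as jnp

# Hypothetical table with 8 valid token ids (0..7).
weight = jnp.arange(32, dtype=jnp.float32).reshape(8, 4)


@jax.jit
def naive_lookup(weight, x):
    # NumPy-style indexing: out-of-bounds retrievals are clamped.
    return weight[x]


@jax.jit
def take_lookup(weight, x):
    # jnp.take with its default mode: out-of-bounds rows are filled (NaN here).
    return jnp.take(weight, x, axis=0)


bad_id = jnp.array(10, dtype=jnp.int32)  # deliberately out of range
print(naive_lookup(weight, bad_id))  # clamped to the last row, weight[7]
print(take_lookup(weight, bad_id))   # a row of NaNs
```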