Description
File: jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py
Repro steps:
1. Change line #661 from
masks.append(q_ids == kv_ids)
to
masks.append(jnp.logical_or(q_ids == kv_ids, kv_ids == 0))
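For reference, the modified mask can be written in plain JAX outside the Pallas kernel. This is a minimal sketch, assuming q_ids and kv_ids are int32 segment ids and that the new term is meant to let every query attend to any key whose segment id is 0. The function name reference_segment_mask and the 1-D input shapes are hypothetical and are not taken from splash_attention_kernel.py. Note that kv_ids == 0 produces a (1, bkv) boolean that logical_or then broadcasts against the (bq, bkv) mask. On CPU this is harmless; it is the kind of boolean broadcast that could plausibly be what Mosaic turns into the failing gather below (unconfirmed).

```python
import jax.numpy as jnp


def reference_segment_mask(q_ids, kv_ids):
    """Hypothetical pure-JAX reference for the modified segment mask.

    q_ids:  int32 query segment ids, shape (bq,)
    kv_ids: int32 key/value segment ids, shape (bkv,)
    Returns a (bq, bkv) boolean mask; True means attention is allowed.
    """
    q = q_ids[:, None]    # (bq, 1)
    kv = kv_ids[None, :]  # (1, bkv)
    same_segment = q == kv  # original mask, broadcast to (bq, bkv)
    kv_is_zero = kv == 0    # extra term, still (1, bkv) booleans
    # logical_or broadcasts the (1, bkv) boolean row across all bq rows.
    return jnp.logical_or(same_segment, kv_is_zero)


q_ids = jnp.array([0, 1, 1, 2], dtype=jnp.int32)
kv_ids = jnp.array([0, 1, 2, 2], dtype=jnp.int32)
mask = reference_segment_mask(q_ids, kv_ids)
print(mask.astype(jnp.int32))
```

Every row now has column 0 unmasked (kv_ids[0] == 0), in addition to the usual same-segment entries. A kernel run with the change from step 1 could be compared against such a reference once it compiles.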
2. Run the test
jax/tests/pallas/tpu_splash_attention_kernel_test.py
It fails with the following Mosaic compilation error:
File "third_party/py/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py", line 772, in flash_attention_kernel.<locals>.body
qk = apply_mask_and_soft_cap()
File "third_party/py/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py", line 661, in _apply_mask_and_soft_cap
masks.append(jnp.logical_or(q_ids == kv_ids, kv_ids == 0))
jax._src.pallas.mosaic.error_handling.MosaicError: INTERNAL: Mosaic failed to compile TPU kernel: failed to legalize operation 'tpu.gather'
The MLIR operation involved:
%925 = "tpu.gather"(%924) <{dimension = 0 : i32, indices = array<i32: 0, 0, 0, 0, 0, 0, 0, 0>}> : (vector<8x128xi1>) -> vector<8x128xi1>
It looks like the kv_ids == 0 term makes Mosaic emit a tpu.gather on a boolean (i1) vector, and Mosaic fails to legalize that gather. The original q_ids == kv_ids mask compiles fine.
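To read the MLIR operation above, here is a jnp model of a static gather. It is a sketch, assuming tpu.gather with dimension = 0 and a static indices array sets output row i to input row indices[i]. That is an assumption about the op's semantics, not Mosaic's implementation, and gather_rows is a hypothetical helper. Under that reading, indices of all zeros replicate row 0 of the 8x128 i1 vector into all 8 rows, which amounts to a sublane broadcast of a single boolean row.

```python
import jax.numpy as jnp


def gather_rows(x, indices, dimension=0):
    """Hypothetical model of a static gather: out[i] = x[indices[i]] along `dimension`.

    Assumption: this mirrors how tpu.gather selects slices; it is only an
    illustration of the semantics, not the Mosaic lowering.
    """
    return jnp.take(x, jnp.asarray(indices), axis=dimension)


# An 8x128 boolean array, shaped like the vector<8x128xi1> operand in the error.
x = (jnp.arange(8 * 128).reshape(8, 128) % 3) == 0

# Same static indices as in the error message: eight zeros along dimension 0.
out = gather_rows(x, [0] * 8, dimension=0)

# With every index equal to 0, the gather copies row 0 into all 8 rows,
# i.e. it is equivalent to broadcasting the first row.
print(bool(jnp.array_equal(out, jnp.broadcast_to(x[0:1], x.shape))))
```

If this reading is right, the failure comes from broadcasting a boolean row rather than from the comparison itself. The source does not confirm which value is being broadcast.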
System info (python version, jaxlib version, accelerator, etc.)
Latest JAX; TPU v6e.