Support mask gather in Mosaic compiler #29092

Closed
@xy12181

Description

File: jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py

Repro steps:

1. Change line 661 from

masks.append(q_ids == kv_ids)

to

masks.append(jnp.logical_or(q_ids == kv_ids, kv_ids == 0))

2. Run the test

jax/tests/pallas/tpu_splash_attention_kernel_test.py

It fails with the following Mosaic compilation error:

File "third_party/py/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py", line 772, in flash_attention_kernel.<locals>.body
    qk = apply_mask_and_soft_cap()
File "third_party/py/jax/experimental/pallas/ops/tpu/splash_attention/splash_attention_kernel.py", line 661, in _apply_mask_and_soft_cap
    masks.append(jnp.logical_or(q_ids == kv_ids, kv_ids == 0))
jax._src.pallas.mosaic.error_handling.MosaicError: INTERNAL: Mosaic failed to compile TPU kernel: failed to legalize operation 'tpu.gather'

The MLIR operation involved:
  %925 = "tpu.gather"(%924) <{dimension = 0 : i32, indices = array<i32: 0, 0, 0, 0, 0, 0, 0, 0>}> : (vector<8x128xi1>) -> vector<8x128xi1>

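It looks like the kv_ids == 0 mask causes Mosaic to emit a tpu.gather on a boolean (i1) mask vector, an operation Mosaic fails to legalize. Since the indices are all 0 along dimension 0, the gather appears to broadcast row 0 of the mask to all 8 rows.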
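A minimal sketch in plain `jax.numpy` (runs on CPU, outside Pallas) of what the step 1 mask computes and of what the failing op appears to do. Everything here is an assumption for illustration: `q_ids` is taken to have shape `(bq, 1)` and `kv_ids` shape `(1, bkv)` so the comparisons broadcast to `(bq, bkv)`; the helpers `reference_mask`, `gather_rows` and `mask_broadcast_ints_first` are hypothetical names, not kernel or Mosaic APIs. `gather_rows` models `tpu.gather` as "output row `i` is input row `indices[i]`", which is a reading of the `dimension` and `indices` attributes rather than documented Mosaic semantics. `mask_broadcast_ints_first` is an untested workaround idea: broadcast the integer `kv_ids` to the full block shape before comparing with 0, so any row broadcast would act on `int32` data instead of the `i1` mask. Whether Mosaic then avoids the `i1` gather has not been verified.

```python
import jax.numpy as jnp


def reference_mask(q_ids, kv_ids):
  """Hypothetical reference for the step 1 mask (not the kernel's code).

  Keeps (q, kv) when the segment IDs match or when kv's segment ID is 0.
  """
  return jnp.logical_or(q_ids == kv_ids, kv_ids == 0)


def gather_rows(x, indices):
  """Assumed model of tpu.gather with dimension = 0 and static indices:
  output row i is a copy of input row indices[i]."""
  return jnp.take(x, jnp.asarray(indices), axis=0)


def mask_broadcast_ints_first(q_ids, kv_ids):
  """Hypothetical workaround (untested on Mosaic): broadcast the int32
  kv_ids to the full block shape first, then compare with 0, so the
  boolean results are already full-size and need no row broadcast."""
  shape = jnp.broadcast_shapes(q_ids.shape, kv_ids.shape)
  kv_full = jnp.broadcast_to(kv_ids, shape)
  return jnp.logical_or(q_ids == kv_full, kv_full == 0)


# Small example: 3 query rows, 4 key/value columns.
q_ids = jnp.array([[1], [2], [2]], dtype=jnp.int32)
kv_ids = jnp.array([[0, 1, 2, 2]], dtype=jnp.int32)
print(reference_mask(q_ids, kv_ids).astype(jnp.int32))
# [[1 1 0 0]
#  [1 0 1 1]
#  [1 0 1 1]]

# With all-zero indices along dimension 0, the modeled gather copies row 0
# into every row of an 8x128 boolean block, i.e. a broadcast of row 0.
block = (jnp.arange(8 * 128).reshape(8, 128) % 3) == 0
assert jnp.array_equal(gather_rows(block, [0] * 8),
                       jnp.broadcast_to(block[:1], block.shape))

# The reordered computation produces the same mask as the reference.
assert jnp.array_equal(mask_broadcast_ints_first(q_ids, kv_ids),
                       reference_mask(q_ids, kv_ids))
```

The column for kv segment ID 0 is kept for every query row, which is the behavior the modified line asks for.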

System info (python version, jaxlib version, accelerator, etc.)

Latest JAX, TPU v6e.


Labels: bug (Something isn't working)
