
Integrate ragged paged attention v2 #8791


Merged: 3 commits merged into pytorch:master on Mar 5, 2025

Conversation

@bythew3i (Contributor) commented Mar 5, 2025

Tested:

python test/test_pallas.py -v -k PallasTest.test_ragged_paged_attention_wrapper

Please Read

This PR adds validation of the ragged attention inputs to torch.ops.xla.ragged_paged_attention, and that validation is expected to run at runtime (eagerly). If we ever need to compile something like the snippet below, please move the validation code out of the op (or simply avoid compiling it).

def ragged_paged_attention_wrapper(...):
    ...
    return torch.ops.xla.ragged_paged_attention(...)


compiled_paged_attention = torch.compile(
        ragged_paged_attention_wrapper, backend="openxla")
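
For context, here is a minimal sketch of the kind of runtime shape validation the note above refers to. The argument names (q, kv_lens, page_indices, cu_q_lens, num_seqs) and the specific checks are assumptions for illustration, not the op's actual signature or checks; Python-level, data-dependent checks like these are exactly why the validation should run eagerly rather than be traced by torch.compile.

def _check_ragged_paged_attention_inputs(q, kv_lens, page_indices, cu_q_lens, num_seqs):
    # Hypothetical checks for illustration only; see the PR diff for the real validation.
    if cu_q_lens.shape[0] < num_seqs + 1:
        raise ValueError(
            f"cu_q_lens needs at least num_seqs + 1 entries, got {cu_q_lens.shape[0]}")
    if kv_lens.shape[0] != page_indices.shape[0]:
        raise ValueError("kv_lens and page_indices must describe the same number of sequences")
    if cu_q_lens[num_seqs] > q.shape[0]:
        raise ValueError("cumulative query length exceeds the number of query tokens")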

Key Features in Ragged Paged Attention V2

  • Supports mixed prefill and decode to increase inference throughput (e.g., a 5x speedup over the padded Multi-Queries Paged Attention implementation for llama-3-8b).
  • No explicit swapaxes of seq_len and num_head before/after the kernel. The kernel takes num_head in the second-minor dimension, as it naturally is; the swapaxes is folded into strided loads/stores inside the kernel, applying the transpose on the fly.
  • No GMM (Grouped Matmul) metadata required: it is computed on the fly inside the kernel, which gives roughly a 10% speedup.
  • Increases MXU utilization up to 8x for GQA decode by grouping the q heads that share a kv head into a single matmul (see the sketch after this list).
  • Minimizes recompilation: in the mixed-engine setting, the only factors that can trigger recompilation are the model specs, max_num_batched_tokens, and max_num_seqs.
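
The following is a minimal sketch of the GQA grouping idea only; the shapes, names, and plain-PyTorch matmuls below are illustrative assumptions, not the kernel's actual interface or implementation. During decode each token contributes a single query row per q head, so computing scores head by head under-fills the MXU; folding the q heads that share a kv head into one matmul keeps the MXU rows busy.

import torch

# Hypothetical GQA shapes for illustration (llama-3-8b-like: 32 q heads, 8 kv heads).
num_q_heads, num_kv_heads, head_dim, kv_len = 32, 8, 128, 512
q_per_kv = num_q_heads // num_kv_heads  # q heads sharing each kv head

q = torch.randn(num_q_heads, head_dim)           # one decode token's queries
k = torch.randn(num_kv_heads, kv_len, head_dim)  # cached keys for that sequence

# Naive decode: one (1, head_dim) x (head_dim, kv_len) matmul per q head,
# which leaves most MXU rows idle.
scores_naive = torch.stack([q[h] @ k[h // q_per_kv].T for h in range(num_q_heads)])

# Grouped decode: the q heads sharing a kv head form the matmul's row dimension,
# so each kv head does one (q_per_kv, head_dim) x (head_dim, kv_len) matmul and
# MXU utilization improves by roughly q_per_kv.
q_grouped = q.view(num_kv_heads, q_per_kv, head_dim)
scores_grouped = torch.einsum('gqd,gsd->gqs', q_grouped, k)

assert torch.allclose(scores_naive.view_as(scores_grouped), scores_grouped, atol=1e-4)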

Note: this PR does not include tests for the Ragged Paged Attention kernel itself, because the kernel is already tested in jax-ml/jax#26920 and we will eventually import it directly from that source instead of keeping a duplicated implementation.

@yaochengji yaochengji self-requested a review March 5, 2025 17:49
@yaochengji (Collaborator) left a comment


LGTM, thanks!

@yaochengji yaochengji enabled auto-merge (squash) March 5, 2025 17:50
@yaochengji yaochengji merged commit 5644f44 into pytorch:master Mar 5, 2025
22 of 23 checks passed
pgmoka pushed a commit that referenced this pull request Mar 5, 2025
@zpcore (Member) commented Mar 5, 2025

The test test_ragged_paged_attention_wrapper_with_padding_with_dynamo2 is failing. Can someone help with a fix? Thanks

@yaochengji (Collaborator)

@zpcore, thanks, it is fixed in #8797
