[ROCm][Perf] New design on ROCm AITER MHA backend Implementation #25763
Conversation
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs do not trigger a full CI run by default; only a limited subset of checks runs, and you can ask your reviewers to trigger select CI tests on top of it. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. If you have any questions, please reach out to us on Slack at https://slack.vllm.ai. 🚀
Code Review
This pull request introduces a new, more performant MHA backend implementation for ROCm. The changes include removing redundant KV fetches, introducing phase-aware execution (decode, pure prefill, chunk prefill), reordering inputs for better memory access, and rewriting the Triton kernel for fetching KV cache. The performance improvements demonstrated are significant. However, I have identified a few critical bugs in the implementation that need to be addressed. These include incorrect scaling in a Triton kernel, an invalid tensor view operation that will lead to a runtime error, and a logical error in the batch reordering logic. Addressing these issues is crucial for the correctness and stability of the new backend.
Could you please clarify which Qwen3 model and datatype you are using? Could you please also append the accuracy results?
Thanks for the suggestion, just updated the PR description with the model and accuracy verification.
Signed-off-by: ganyi <ygan@amd.com>
Signed-off-by: ganyi <ygan@amd.com>
Signed-off-by: ganyi <ygan@amd.com>
Signed-off-by: ganyi <ygan@amd.com>
Force-pushed from 2734924 to 2f977bc
@ganyi1996ppo can you contain the changes to the ROCm backend only?
@HAIAI If we want to keep the changes only in the ROCm backend, we can hardly rewrite this
Apologies for the delay. Did another pass; it's looking much better, thank you! Still need to review reorder_batch_to_split_decodes_prefills_and_extends.
Would be good to get @benchislett and @heheda12345's opinions on ReorderSpec (see #25763 (comment))
num_decodes: The number of decode requests.
num_prefills: The number of prefill requests.
num_decode_tokens: The number of tokens in the decode requests.
num_prefill_tokens: The number of tokens in the prefill requests.
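To make the relationship between these four fields concrete, here is a minimal, hypothetical sketch (the function name and the assumption that decodes have already been reordered to the front are mine, not code from this PR) of how the counts could be derived from per-request query lengths:

```python
# Minimal, hypothetical sketch (not the vLLM code itself): derive the four
# counts documented above from per-request query lengths, assuming the batch
# has already been reordered so that all decode requests come first and that
# any request with query length <= decode_threshold counts as a decode.
def count_decodes_and_prefills(
    query_lens: list[int], decode_threshold: int = 1
) -> tuple[int, int, int, int]:
    num_decodes = 0
    num_decode_tokens = 0
    for q_len in query_lens:
        if q_len > decode_threshold:
            break  # decodes are contiguous at the front after reordering
        num_decodes += 1
        num_decode_tokens += q_len
    num_prefills = len(query_lens) - num_decodes
    num_prefill_tokens = sum(query_lens) - num_decode_tokens
    return num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens


# Example: two decode requests (1 token each) followed by two prefills.
print(count_decodes_and_prefills([1, 1, 7, 12]))  # (2, 2, 2, 19)
```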
update doc-string please
Updated
@LucasWilkinson does this look good to you?
vllm/v1/attention/backends/utils.py
Outdated
have different ReorderSpec in a single model, high tolerance one will
be selected."""

reorder_batch_threshold: int | None = None
nit: since this is normally used as reorder_spec.reorder_batch_threshold, the repeated "reorder" is unneeded; we could just do reorder_spec.decode_threshold
I much prefer this, FWIW
Got it, reorder_spec.decode_threshold looks good to me too.
vllm/v1/attention/backends/utils.py
Outdated
- UNIFORM: Supports uniform multi-token queries (spec decode with uniform lengths)
- VARLEN: Supports variable-length queries (spec decode with mixed lengths)
If set to UNIFORM or VARLEN, this will increase `reorder_batch_threshold` when
speculative decoding is enabled."""
I think we could improve this documentation quite a bit; here is an example of something more visual (this could probably be improved further):
@dataclass
class ReorderSpec:
"""
Defines how the model runner reorders requests within a batch for attention
backends that distinguish between prefill, extend, and decode phases.
Core controls
- decode_threshold: Query lengths ≤ this value are treated as decode.
- split_extend: If True, split prefill into [extend_prefill, pure_prefill].
- query_len_support:
- SINGLE_ONLY: single-token only (no spec decode)
- UNIFORM: uniform multi-token queries (spec decode, equal lengths)
- VARLEN: variable-length queries (spec decode, mixed lengths)
Example input
query_len: [7, 10, 3, 1, 2, 5, 2, 15]
seq_len: [10, 10, 8, 9, 8, 5, 10, 20]
Case 1: decode_threshold=3
query_len: [3, 1, 2, 2, 7, 10, 5, 15]
seq_len: [8, 9, 8, 10, 10, 10, 5, 20]
└──── dec ────┘└──── pre ─────┘
→ Reordered as [decode, prefill].
Case 2: decode_threshold=3, split_extend=True,
query_len: [3, 1, 2, 2, 7, 15, 10, 5]
seq_len: [8, 9, 8, 10, 10, 20, 10, 5]
└────── dec ──────┘└ ext ─┘└pre┘
→ Reordered as [decode, extend_prefill, pure_prefill].
Case 3 (Future/TODO): decode_threshold=3, split_extend=True, query_len_support=UNIFORM
(Move the most common ≤ decode_threshold to the front to form the largest
*uniform* decode region. Here, the uniform decode region is the two q_len=2’s.)
query_len: [2, 2, 3, 1, 7, 15, 10, 5]
seq_len: [8, 10, 8, 9, 10, 20, 10, 5]
└u dec┘ └─dec─┘└──── pre ─────┘
→ Reordered as [uniform-decode(2s), decode, prefill].
"""
decode_threshold: int | None = None
split_extend: bool = False
query_len_support: QueryLenSupport = QueryLenSupport.VARLEN
NOTE: we may want to rename query_len_support to decode_query_len_support
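For concreteness, here is a tiny sketch of the Case 1 reordering above; the helper name is hypothetical and only illustrates the stable decode-first partition, not the actual implementation:

```python
# Hypothetical sketch of Case 1 above (decode_threshold=3): build a permutation
# that moves decode requests (query_len <= threshold) in front of prefill
# requests while preserving the relative order within each group.
def split_decodes_first(query_lens: list[int], decode_threshold: int) -> list[int]:
    decode_idx = [i for i, q in enumerate(query_lens) if q <= decode_threshold]
    prefill_idx = [i for i, q in enumerate(query_lens) if q > decode_threshold]
    return decode_idx + prefill_idx


query_len = [7, 10, 3, 1, 2, 5, 2, 15]
order = split_decodes_first(query_len, decode_threshold=3)
print([query_len[i] for i in order])  # [3, 1, 2, 2, 7, 10, 5, 15]
```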
Thank you so much for your patient comments! Doc-string updated, as well as the variable name in ReorderSpec; please take a look.
Signed-off-by: ganyi <ygan@amd.com>
vllm/v1/attention/backends/utils.py
Outdated
Case 2: decode_threshold=3, split_extend=True,
query_len: [3, 1, 2, 2, 7, 15, 10, 5]
seq_len: [8, 9, 8, 10, 10, 20, 10, 5]
└────── dec ──────┘└ ext ─┘└pre┘
nit - the left └ is off by one for ext and pre.
Updated
Signed-off-by: ganyi <ygan@amd.com>
Signed-off-by: ganyi <ygan@amd.com>
LGTM
Small update: followed up with @benchislett and @heheda12345 offline; feedback is that ReorderSpec may be too complicated, and I tend to agree now that I see it. There's some talk about separating extends by default using this simple logic (#27367), so we can leave reorder_batch_threshold untouched. Apologies for the review thrash! Will follow up here again soon.
Sure, we would much prefer to leave the
Signed-off-by: ganyi <ygan@amd.com>
Signed-off-by: ganyi <ygan@amd.com>
Signed-off-by: ganyi <ygan@amd.com>
@LucasWilkinson I made some slight changes to the request role calculation part, please take a look again, thanks!
Purpose
The current `AiterFlashAttentionImpl` fetches K/V on every run, which creates unnecessary memory pressure and non-trivial latency, especially with long prompts. This PR reworks that path.

Design and implementation
Phase-aware path:
Input reordering to [decode:cp:pp] ensures tokens are contiguous in memory, improving kernel locality and occupancy. The reorder occurs in both the Scheduler's scheduling phase and the ModelRunner's state-updating phase. We add this `split_prefill_from_chunk` to the `SchedulerConfig` to control the behavior, which will be turned on when both `VLLM_ROCM_USE_AITER` and `VLLM_ROCM_USE_AITER_MHA` are set. A rough sketch of this phase split is shown below.
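As an illustration only (the helper and argument names here are assumptions, not code from this PR), the [decode:cp:pp] split can be thought of as classifying each request by its query length and cached context length, then laying the batch out in that order:

```python
# Hypothetical illustration of the [decode : chunked-prefill : pure-prefill]
# ordering described above; names and thresholds are assumptions, not the
# actual implementation in this PR.
def classify_requests(
    query_lens: list[int], context_lens: list[int], decode_threshold: int = 1
) -> dict[str, list[int]]:
    groups: dict[str, list[int]] = {
        "decode": [],
        "chunked_prefill": [],
        "pure_prefill": [],
    }
    for i, (q_len, ctx_len) in enumerate(zip(query_lens, context_lens)):
        if q_len <= decode_threshold:
            groups["decode"].append(i)           # decoding a few new tokens
        elif ctx_len > 0:
            groups["chunked_prefill"].append(i)  # prefill with cached context
        else:
            groups["pure_prefill"].append(i)     # prefill with no prior KV cache
    return groups


# New batch order = decode requests, then chunked prefills, then pure prefills.
order = classify_requests([1, 9, 6, 1], [12, 4, 0, 30])
print(order)  # {'decode': [0, 3], 'chunked_prefill': [1], 'pure_prefill': [2]}
```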
Compared with the old implementation, this solution is more memory efficient and faster, especially in long-prompt scenarios. Here is the performance measured on Qwen3, MI308:
- Long prompt, short output (2k prompt, 16 output): ~4.x throughput improvement.
- Short prompt, long output (128 prompt, 1k output): ~2.x throughput improvement.
- Extremely long prompt (192k prompt, 2k output): ~5.x throughput improvement.

Test Plan
acc : lm_eval test for accuracy verification
perf : vllm bench test
Test Result
2k prompt 16 output case:
old impl

new impl

128 prompt 1k output case:
old impl

new impl

acc verification
We tested this PR with Qwen3-30B-A3B-FP8 on gsm8k with lm_eval, and here is the result: