[Performance] Remove input pads in cutlass_mla and optimize v_proj output handling #25184
Conversation
Code Review
This pull request introduces performance optimizations for MLA by pre-padding inputs and optimizing the v_proj output handling. The changes are well-aligned with the performance improvement goals. I have identified two areas for improvement. First, in vllm/v1/attention/backends/mla/common.py, the _v_up_proj method uses an unsafe resize_ on a tensor view, which should be refactored for safety and clarity. Second, vllm/v1/attention/backends/mla/cutlass_mla.py contains temporary, commented-out code that should be removed before merging to improve maintainability. Addressing these points will enhance the code's quality and robustness.
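For context on the `resize_` concern, here is a minimal sketch of the hazard (not the actual `_v_up_proj` code; the tensor names `base` and `flat` are hypothetical):

```python
import torch

base = torch.empty(16, 128, 512)
flat = base.view(16, 128 * 512)  # a view sharing base's storage

# Unsafe pattern: resize_() rewrites the view's shape metadata and may
# grow the shared storage, silently invalidating `base` (and tripping
# autograd checks when gradients are involved).
# flat.resize_(16, 64, 1024)

# Safer: express shape changes as views, which never touch storage.
reshaped = base.view(16, 64, 1024)

# Or, if a genuinely different size is needed, allocate a fresh buffer.
fresh = torch.empty(16, 64, 1024, dtype=base.dtype, device=base.device)
```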
This pull request has merge conflicts that must be resolved before it can be merged.
Looks good to me, just one nit on types
This PR removes the need to pad the cutlass_mla inputs (q_nope and q_pe) to max_heads==128 by pre-padding the buffers used by the preceding operations. It also improves v_proj output handling by reusing the final output buffer earlier, passing it directly into torch.bmm. For DeepSeek-R1 on 8xB200 with batch_size==32, decode-iteration TPOT improves from 18.87 ms to 18.25 ms, about 3.3%.
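As an illustration of the output-buffer reuse described above, here is a minimal sketch of the `out=` pattern with made-up shapes; the names (`attn_out`, `w_uv`, etc.) are illustrative, not the PR's actual `_v_up_proj` code:

```python
import torch

num_heads, num_tokens = 128, 32
kv_lora_rank, v_head_dim = 512, 128

# Per-head attention output and per-head V up-projection weights.
attn_out = torch.empty(num_heads, num_tokens, kv_lora_rank)
w_uv = torch.empty(num_heads, kv_lora_rank, v_head_dim)

# Allocate the final (num_tokens, num_heads * v_head_dim) output once and
# hand a transposed view of it to torch.bmm via out=, so the batched
# matmul writes its result in place instead of producing a temporary
# that must later be transposed and copied into the output tensor.
output = torch.empty(num_tokens, num_heads * v_head_dim)
out_view = output.view(num_tokens, num_heads, v_head_dim).transpose(0, 1)
torch.bmm(attn_out, w_uv, out=out_view)  # fills `output` directly
```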
Verified correctness with: `lm_eval --model vllm --model_args pretrained=deepseek-ai/DeepSeek-R1-0528,tensor_parallel_size=8 --tasks gsm8k --num_fewshot 5 --batch_size auto`