
Conversation

@alexm-redhat (Collaborator) commented Sep 18, 2025

This PR removes the need to pad the cutlass MLA inputs (q_nope and q_pe) to max_heads==128 by pre-padding the buffers used by the preceding operations. It also improves v_proj output handling by reusing the output buffer directly inside torch.bmm. For DeepSeek-R1 on 8xB200 with batch_size==32, decode-iteration TPOT improves from 18.87ms to 18.25ms, about 3.3%.
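To make the pre-padding idea concrete, here is a minimal PyTorch sketch. The names and shapes (MAX_HEADS, num_local_heads, the stand-in projections, the commented-out kernel call) are illustrative assumptions, not the actual vLLM identifiers.

```python
import torch

MAX_HEADS = 128        # head count the CUTLASS MLA decode kernel expects
num_local_heads = 16   # heads actually present on this TP rank (e.g. 128 / TP=8)
bs, qk_nope_dim, qk_rope_dim = 32, 128, 64
device = "cuda" if torch.cuda.is_available() else "cpu"

# Allocate the kernel-facing query buffers at 128 heads once, zero-padded.
q_nope_padded = torch.zeros(bs, MAX_HEADS, qk_nope_dim, device=device)
q_pe_padded = torch.zeros(bs, MAX_HEADS, qk_rope_dim, device=device)

# The preceding projections write into the leading num_local_heads slice; the
# slices share storage with the padded buffers, so no extra pad/copy is needed
# right before the attention kernel is invoked with the full 128-head tensors.
q_nope = q_nope_padded[:, :num_local_heads]
q_pe = q_pe_padded[:, :num_local_heads]
q_nope.copy_(torch.randn_like(q_nope))  # stand-in for the real q_nope projection
q_pe.copy_(torch.randn_like(q_pe))      # stand-in for the real q_pe projection

# cutlass_mla_decode(q_nope_padded, q_pe_padded, ...)  # hypothetical call; kernel sees 128 heads directly
```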

Verified correctness with: lm_eval --model vllm --model_args pretrained=deepseek-ai/DeepSeek-R1-0528,tensor_parallel_size=8 --tasks gsm8k --num_fewshot 5 --batch_size auto

| Tasks | Version | Filter           | n-shot | Metric      |  Value | Stderr   |
|-------|--------:|------------------|-------:|-------------|-------:|----------|
| gsm8k |       3 | flexible-extract |      5 | exact_match | 0.9538 | ± 0.0058 |
|       |         | strict-match     |      5 | exact_match | 0.9515 | ± 0.0059 |

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces performance optimizations for MLA by pre-padding inputs and optimizing the v_proj output handling. The changes are well-aligned with the performance improvement goals. I have identified two areas for improvement. First, in vllm/v1/attention/backends/mla/common.py, the _v_up_proj method uses an unsafe resize_ on a tensor view, which should be refactored for safety and clarity. Second, vllm/v1/attention/backends/mla/cutlass_mla.py contains temporary, commented-out code that should be removed before merging to improve maintainability. Addressing these points will enhance the code's quality and robustness.
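For reference, one safe pattern along the lines suggested above is to allocate the flat output buffer up front and hand torch.bmm a view of it, rather than resizing a view in place. The helper below is a hypothetical sketch under assumed shapes, not the actual vLLM refactor.

```python
import torch

def v_up_proj_into(attn_out: torch.Tensor, W_UV: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: write the bmm result straight into a pre-allocated
    flat buffer through a view, avoiding any in-place resize_ of shared storage."""
    bs, num_heads, kv_lora_rank = attn_out.shape                    # (B, N, L)
    v_head_dim = W_UV.shape[-1]                                     # W_UV: (N, L, V)
    out = attn_out.new_empty(bs, num_heads * v_head_dim)            # final (B, N*V) buffer
    out_view = out.view(bs, num_heads, v_head_dim).transpose(0, 1)  # (N, B, V) view of it
    torch.bmm(attn_out.transpose(0, 1), W_UV, out=out_view)         # (N, B, L) x (N, L, V)
    return out
```

Writing through the (N, B, V) view leaves the result in the caller's flat (B, N*V) layout without a separate reshape or copy afterwards.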

mergify bot commented Sep 18, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @alexm-redhat.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mgoin (Member) left a comment

Looks good to me, just one nit on types

…tput reshape

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
mergify bot removed the needs-rebase label Sep 22, 2025
@mgoin added the ready (ONLY add when PR is ready to merge/full CI is needed) label Sep 22, 2025
@mgoin merged commit 0b7bed9 into main Sep 23, 2025
55 of 56 checks passed
@mgoin deleted the opt_mla_2 branch September 23, 2025 01:20
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
…tput handling (vllm-project#25184)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
charlifu pushed a commit to ROCm/vllm that referenced this pull request Sep 25, 2025
…tput handling (vllm-project#25184)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
Signed-off-by: charlifu <charlifu@amd.com>
yewentao256 pushed a commit that referenced this pull request Oct 3, 2025
…tput handling (#25184)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
Signed-off-by: yewentao256 <zhyanwentao@126.com>
gjc0824 pushed a commit to gjc0824/vllm that referenced this pull request Oct 10, 2025
…tput handling (vllm-project#25184)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
Signed-off-by: gaojc <1055866782@qq.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
…tput handling (vllm-project#25184)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
choprahetarth pushed a commit to Tandemn-Labs/vllm that referenced this pull request Oct 11, 2025
…tput handling (vllm-project#25184)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
lywa1998 pushed a commit to lywa1998/vllm that referenced this pull request Oct 20, 2025
…tput handling (vllm-project#25184)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
…tput handling (vllm-project#25184)

Signed-off-by: Alexander Matveev <amatveev@redhat.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>

Labels

documentation (Improvements or additions to documentation), ready (ONLY add when PR is ready to merge/full CI is needed), v1
