[Bugfix] FA2 inf workaround with minimum overhead for MLA #15742

Closed

Conversation

@DefTruth (Contributor) commented on Mar 29, 2025

Reopening #15688 here as a final workaround; fixes #15689 and #15530.

The performance downgrade is due to #15492 for MLA + chunked prefill (tested on AWQ + PP2 + TP8), where decode throughput dropped from 390 TPS to 290 TPS. FIXME(DefTruth): to avoid affecting the generation of the Triton kernel, I think we should handle this processing outside merge_attn_states_kernel; dynamically checking each value for infinity within the kernel might hurt the execution efficiency of the generated Triton kernel.

INFO 03-28 13:16:42 [metrics.py:481] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 379.5 tokens/s, Running: 32 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 27.8%, CPU KV cache usage: 0.0%.
INFO 03-28 13:16:47 [metrics.py:481] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 382.2 tokens/s, Running: 32 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 28.2%, CPU KV cache usage: 0.0%.
INFO 03-28 13:16:52 [metrics.py:481] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 379.0 tokens/s, Running: 32 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 28.6%, CPU KV cache usage: 0.0%.
INFO 03-28 13:16:57 [metrics.py:481] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 376.1 tokens/s, Running: 32 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 28.9%, CPU KV cache usage: 0.0%.
INFO 03-28 13:17:02 [metrics.py:481] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 375.3 tokens/s, Running: 32 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 29.3%, CPU KV cache usage: 0.0%.
INFO 03-28 13:17:07 [metrics.py:481] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 375.0 tokens/s, Running: 32 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 29.7%, CPU KV cache usage: 0.0%.
INFO 03-28 13:17:12 [metrics.py:481] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 375.8 tokens/s, Running: 32 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 30.1%, CPU KV cache usage: 0.0%.
# commit: 4d0ec372 285TPS 
INFO 03-28 12:45:07 [metrics.py:481] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 285.2 tokens/s, Running: 32 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 27.9%, CPU KV cache usage: 0.0%.
INFO 03-28 12:45:12 [metrics.py:481] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 286.2 tokens/s, Running: 32 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 28.2%, CPU KV cache usage: 0.0%.
INFO 03-28 12:45:17 [metrics.py:481] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 287.6 tokens/s, Running: 32 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 28.5%, CPU KV cache usage: 0.0%.
INFO 03-28 12:45:22 [metrics.py:481] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 286.3 tokens/s, Running: 32 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 28.7%, CPU KV cache usage: 0.0%.
INFO 03-28 12:45:27 [metrics.py:481] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 281.7 tokens/s, Running: 32 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 29.0%, CPU KV cache usage: 0.0%.
INFO 03-28 12:45:32 [metrics.py:481] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 288.3 tokens/s, Running: 32 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 29.4%, CPU KV cache usage: 0.0%.
INFO 03-28 12:45:37 [metrics.py:481] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 290.1 tokens/s, Running: 32 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 29.6%, CPU KV cache usage: 0.0%.
INFO 03-28 12:45:42 [metrics.py:481] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 290.9 tokens/s, Running: 32 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 29.9%, CPU KV cache usage: 0.0%.
  • w/ this fix, all is well again (390 TPS):
# 390TPS
INFO 03-28 15:24:42 [metrics.py:481] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 391.0 tokens/s, Running: 32 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 28.1%, CPU KV cache usage: 0.0%.
INFO 03-28 15:24:47 [metrics.py:481] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 388.7 tokens/s, Running: 32 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 28.5%, CPU KV cache usage: 0.0%.
INFO 03-28 15:24:52 [metrics.py:481] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 385.3 tokens/s, Running: 32 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 28.9%, CPU KV cache usage: 0.0%.
INFO 03-28 15:24:57 [metrics.py:481] Avg prompt throughput: 0.0 tokens/s, Avg generation throughput: 380.8 tokens/s, Running: 32 reqs, Swapped: 0 reqs, Pending: 0 reqs, GPU KV cache usage: 29.3%, CPU KV cache usage: 0.0%.
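
As a minimal sketch of the "handle it outside the kernel" idea in the FIXME above (the helper name is hypothetical, not the actual vLLM code): map the +inf LSE entries FA2 can return to -inf before they reach the merge, so those entries carry zero weight in the exp/rescale instead of producing inf - inf = nan, and the Triton kernel itself stays untouched:

import torch

def sanitize_lse(lse: torch.Tensor) -> torch.Tensor:
    """Hypothetical pre-pass: +inf LSE entries (returned by FA2 for fully
    masked rows) become -inf, so they contribute zero weight when the
    attention states are merged, instead of propagating NaN."""
    return torch.where(lse == float("inf"),
                       torch.full_like(lse, float("-inf")),
                       lse)

# e.g. attn_softmax_lse = sanitize_lse(attn_softmax_lse) before merging states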

@LucasWilkinson @robertgshaw2-redhat

Signed-off-by: DefTruth <qiustudent_r@163.com>

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small, essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@DefTruth DefTruth changed the title [Bugfix][MLA] Add workaround for FA2 inf MLA with mini overhead [Bugfix][MLA] Add workaround for FA2 inf with mini overhead Mar 29, 2025
@DefTruth DefTruth changed the title [Bugfix][MLA] Add workaround for FA2 inf with mini overhead [Bugfix][MLA] Add workaround for FA2 inf with minimum overhead Mar 29, 2025
@DefTruth DefTruth changed the title [Bugfix][MLA] Add workaround for FA2 inf with minimum overhead [Bugfix][MLA] Workaround for FA2 inf with minimum overhead Mar 29, 2025
@DefTruth DefTruth changed the title [Bugfix][MLA] Workaround for FA2 inf with minimum overhead [Bugfix] FA2 inf workaround with minimum overhead for MLA Mar 29, 2025
@DefTruth (Contributor, Author) commented on Mar 30, 2025

@LucasWilkinson Hi~ can you take a look at this fix?

@robertgshaw2-redhat (Collaborator) replied:

Lucas has been OOO on PTO since Friday; he will be back tomorrow.

@DefTruth DefTruth marked this pull request as draft April 2, 2025 07:45
@DefTruth DefTruth closed this Apr 2, 2025
@LucasWilkinson (Collaborator) commented:

@DefTruth Apologies for the delays! Is there a reason you closed this PR? It might be good to note that here for posterity.

In response to the PR, though, I think we should see if we can optimize the kernel with the conditional in place, since this seems more robust than doing something outside the kernel.
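
For reference, a toy sketch of what such an in-kernel guard could look like (illustrative only: the kernel name, signature, and standalone-pass structure are made up and this is not the real merge_attn_states_kernel; in the actual kernel the tl.where would be applied to the loaded LSE values before the exp/rescale step):

import torch
import triton
import triton.language as tl

@triton.jit
def guard_inf_kernel(lse_ptr, n_elements, BLOCK: tl.constexpr):
    # Replace +inf LSE entries with -inf in place so the downstream
    # exp/rescale in the merge yields a zero weight instead of NaN.
    pid = tl.program_id(0)
    offs = pid * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    lse = tl.load(lse_ptr + offs, mask=mask, other=0.0)
    lse = tl.where(lse == float("inf"), float("-inf"), lse)
    tl.store(lse_ptr + offs, lse, mask=mask)

lse = torch.tensor([1.0, float("inf"), 2.5], device="cuda")
guard_inf_kernel[(triton.cdiv(lse.numel(), 128),)](lse, lse.numel(), BLOCK=128)
print(lse)  # tensor([1.0000, -inf, 2.5000], device='cuda:0')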

@DefTruth (Contributor, Author) commented on Apr 2, 2025 via email

@DefTruth (Contributor, Author) commented on Apr 2, 2025

@LucasWilkinson In issue #15933, I found that context_chunk_max_seq_lens[i] (i.e., max_seqlen_k) is not 0, yet we still get infinity (inf). Therefore, simply checking whether context_chunk_max_seq_lens[i] is 0 cannot completely solve this problem.

When the concurrency is greater than or equal to 16, attn_softmax_lse in _compute_prefill_context contains inf values even though context_chunk_max_seq_lens is greater than 0.
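
A hedged sketch of the kind of isinf debug check that would produce the printout below (the helper name is made up):

import torch

def report_inf(name: str, lse: torch.Tensor) -> None:
    """Hypothetical debug helper matching the printout format below."""
    print(f"torch.isinf({name}).any(): {torch.isinf(lse).any().item()}")
    print(f"{name} num inf: {(lse == float('inf')).sum().item()}")
    print(f"{name} num -inf: {(lse == float('-inf')).sum().item()}")

# report_inf("suffix_lse", suffix_lse)
# report_inf("attn_softmax_lse", attn_softmax_lse)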

torch.isinf(suffix_lse).any(): False
context_chunk_max_seq_lens[0]:613
torch.isinf(attn_softmax_lse).any(): True
attn_softmax_lse num inf: 15408
attn_softmax_lse num -inf: 0
(RayWorkerWrapper pid=1883420) torch.isinf(suffix_lse).any(): False
(RayWorkerWrapper pid=1883420) context_chunk_max_seq_lens[0]:613
(RayWorkerWrapper pid=1883420) torch.isinf(attn_softmax_lse).any(): True
(RayWorkerWrapper pid=1883420) attn_softmax_lse num inf: 15408
(RayWorkerWrapper pid=1883420) attn_softmax_lse num -inf: 0
(RayWorkerWrapper pid=1883376) context_chunk_max_seq_lens[0]:613
(RayWorkerWrapper pid=1883436) context_chunk_max_seq_lens[0]:613
(RayWorkerWrapper pid=1883866) context_chunk_max_seq_lens[0]:613
(RayWorkerWrapper pid=1883473) context_chunk_max_seq_lens[0]:613
(RayWorkerWrapper pid=1883378) context_chunk_max_seq_lens[0]:613
(RayWorkerWrapper pid=1883397) context_chunk_max_seq_lens[0]:613
max_prefill_seq_len:1199
torch.isinf(suffix_lse).any(): False
context_chunk_max_seq_lens[0]:613
torch.isinf(attn_softmax_lse).any(): True
attn_softmax_lse num inf: 15408
attn_softmax_lse num -inf: 0
max_prefill_seq_len:1199
torch.isinf(suffix_lse).any(): False
context_chunk_max_seq_lens[0]:613
torch.isinf(attn_softmax_lse).any(): True
attn_softmax_lse num inf: 15408
attn_softmax_lse num -inf: 0
max_prefill_seq_len:1199
torch.isinf(suffix_lse).any(): False
context_chunk_max_seq_lens[0]:613
torch.isinf(attn_softmax_lse).any(): True
attn_softmax_lse num inf: 15408
attn_softmax_lse num -inf: 0
max_prefill_seq_len:1199
torch.isinf(suffix_lse).any(): False
context_chunk_max_seq_lens[0]:613
torch.isinf(attn_softmax_lse).any(): True
attn_softmax_lse num inf: 15408
attn_softmax_lse num -inf: 0
max_prefill_seq_len:1199
torch.isinf(suffix_lse).any(): False
context_chunk_max_seq_lens[0]:613
torch.isinf(attn_softmax_lse).any(): True
attn_softmax_lse num inf: 15408
attn_softmax_lse num -inf: 0
max_prefill_seq_len:1199
torch.isinf(suffix_lse).any(): False
context_chunk_max_seq_lens[0]:613
torch.isinf(attn_softmax_lse).any(): True
attn_softmax_lse num inf: 15408
attn_softmax_lse num -inf: 0

The results are really bad at high concurrency:

  • launch
export VLLM_USE_V1=0
python3 -m vllm.entrypoints.openai.api_server \
        --model=/workspace/dev/hf_models/DeepSeek-R1 \
        --dtype=auto \
        --block-size 32 \
        --tokenizer-mode=slow \
        --max-model-len 32768 \
        --max-num-batched-tokens 2048 \
        --tensor-parallel-size 8 \
        --pipeline-parallel-size 3 \
        --gpu-memory-utilization 0.90 \
        --max-num-seqs 48 \
        --trust-remote-code \
        --no-enable-prefix-caching \
        --enable-chunked-prefill=True \
        --disable-custom-all-reduce \
        --port 8862
  • for parallel=1/2/4/8/12, I got normal results:
****************************************************************************************************
{'model': '/workspace/dev/hf_models/DeepSeek-R1', 'max_tokens': 128, 'stream': True, 'stream_options': {'include_usage': True}, 'temperature': 0.6, 'top_p': 0.95, 'messages': [{'role': 'user', 'content': 'idea,Java文件位于模块源根之外,因此不会被编译'}]}
----------------------------------------------------------------------------------------------------
<think>
Okay, I've run into a problem: IntelliJ IDEA is telling me that my Java file is located outside the module's source root and therefore won't be compiled. I need to analyze this carefully; it likely involves the project structure configuration or the module settings.

First, I need to recall how project structure works in IDEA. Normally a project contains multiple modules, and each module has its own source root directories, such as src/main/java. If a Java file is not under one of these designated source roots, IDEA will not recognize it as source code and naturally will not compile it. So the user's problem is probably that their Java file was placed outside the module's source root, which is why IDEA cannot recognize and compile it.

Next, I need to consider the possible causes
****************************************************************************************************
  • for parallel=16, I got bad results:
****************************************************************************************************
{'model': '/workspace/dev/hf_models/DeepSeek-R1', 'max_tokens': 128, 'stream': True, 'stream_options': {'include_usage': True}, 'temperature': 0.6, 'top_p': 0.95, 'messages': [{'role': 'user', 'content': 'idea,Java文件位于模块源根之外,因此不会被编译'}]}
----------------------------------------------------------------------------------------------------
, 2.4.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0
3.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0.0
4.0.0.0.0.0.0.
****************************************************************************************************
  • for parallel=32, I got bad results (~30% of requests return bad results):
****************************************************************************************************
{'model': '/workspace/dev/hf_models/DeepSeek-R1', 'max_tokens': 128, 'stream': True, 'stream_options': {'include_usage': True}, 'temperature': 0.6, 'top_p': 0.95, 'messages': [{'role': 'user', 'content': 'idea,Java文件位于模块源根之外,因此不会被编译'}]}
----------------------------------------------------------------------------------------------------
, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5,
****************************************************************************************************

Regarding the performance issue (for example, R1 AWQ), please allow me to consult you. I suspect that the performance we measured under high concurrency, before you fixed this problem, was inaccurate. Under high concurrency, since FA2 returns inf, the merge_attn_states kernel produces nan, which in turn makes all subsequent values nan; at that point vLLM's outputs are garbage. So I would like to ask: when vLLM detects a nan value, does it still perform inference for all model layers, or does it exit immediately and return random results? If it exits immediately and returns random results, then all of R1's high-concurrency performance results from before this problem was fixed were on the high side.
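
A toy illustration of that failure mode (not vLLM's actual merge code): a +inf LSE turns the rescaling factor into exp(inf - inf) = nan, and the nan then contaminates the merged output and everything downstream:

import torch

p_lse = torch.tensor([float("inf")])   # what FA2 can return for a context chunk
s_lse = torch.tensor([1.0])            # a normal suffix LSE
p_out = torch.ones(1, 4)
s_out = torch.ones(1, 4)

max_lse = torch.maximum(p_lse, s_lse)  # inf
p_scale = torch.exp(p_lse - max_lse)   # exp(inf - inf) = nan
s_scale = torch.exp(s_lse - max_lse)   # exp(-inf) = 0
merged = (p_out * p_scale.unsqueeze(-1) + s_out * s_scale.unsqueeze(-1)) / (
    p_scale + s_scale).unsqueeze(-1)
print(merged)  # tensor([[nan, nan, nan, nan]])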

@DefTruth (Contributor, Author) commented on Apr 2, 2025

@LucasWilkinson Therefore, if vLLM immediately returns random values when it encounters nan, instead of running inference through all the remaining model layers, then the R1 AWQ performance I tested previously is actually meaningless. In other words, in that case it is not these two lines of code that made AWQ slower; the originally measured performance was simply incorrect and overestimated.
I sincerely hope you can answer my questions. Many thanks~

@DefTruth DefTruth deleted the vipshop-fa2-inf-workaround branch April 6, 2025 12:23
Successfully merging this pull request may close these issues.

[Bug]: Performance downgrade 20% for AWQ + MLA + chunked-prefill