
Conversation

@elvischenv (Contributor) commented Sep 10, 2025

Purpose

#23647 always quantizes the query when the KV cache dtype is set to FP8 and always selects the TRTLLM attention kernel. However, there are many cases that TRTLLM attention does not support. This PR refactors the TRTLLM attention kernel selection so that the query is only quantized when the TRTLLM kernel will actually be used.
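
For illustration, a minimal sketch of the intended selection behavior (not the merged implementation; `should_quantize_query` and `supports_trtllm_attention` are hypothetical names standing in for the real checks such as SM version, head size, and sinks support):

```python
from typing import Optional


def should_quantize_query(
    kv_cache_dtype: str,
    use_trtllm: Optional[bool],       # None = auto, True/False = forced via env var
    supports_trtllm_attention: bool,  # hypothetical capability check
) -> bool:
    """Only quantize the query when the TRTLLM kernel will actually be used."""
    if use_trtllm is False:
        # TRTLLM attention explicitly disabled: keep the query unquantized.
        return False
    if use_trtllm is None and not supports_trtllm_attention:
        # Unsupported case: fall back to another backend, no FP8 query.
        return False
    # TRTLLM kernel will be used: quantize the query only for an FP8 KV cache.
    return kv_cache_dtype.startswith("fp8")
```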

Test Plan && Test Result

SM100, kv_cache_dtype=fp8

Using TRTLLM attention (query is quantized).
vllm ({'pretrained': 'nvidia/Llama-3.3-70B-Instruct-FP8', 'kv_cache_dtype': 'fp8', 'tensor_parallel_size': 1, 'max_model_len': 2048, 'trust_remote_code': True}), gen_kwargs: (temperature=0.0), limit: 500.0, num_fewshot: 5, batch_size: 200
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.938|±  |0.0108|
|     |       |strict-match    |     5|exact_match|↑  |0.906|±  |0.0131|

SM100, kv_cache_dtype=fp8, VLLM_USE_TRTLLM_ATTENTION=0

VLLM_USE_TRTLLM_ATTENTION is set to 0
vllm ({'pretrained': 'nvidia/Llama-3.3-70B-Instruct-FP8', 'kv_cache_dtype': 'fp8', 'tensor_parallel_size': 1, 'max_model_len': 2048, 'trust_remote_code': True}), gen_kwargs: (temperature=0.0), limit: 500.0, num_fewshot: 5, batch_size: 200
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.930|±  |0.0114|
|     |       |strict-match    |     5|exact_match|↑  |0.904|±  |0.0132|

SM120, kv_cache_dtype=fp8, VLLM_ATTENTION_BACKEND=FLASHINFER

(EngineCore_DP0 pid=1252) INFO 09-10 18:44:05 [cuda.py:285] Using FlashInfer backend on V1 engine.
vllm ({'pretrained': 'nvidia/Llama-3.3-70B-Instruct-FP8', 'kv_cache_dtype': 'fp8', 'tensor_parallel_size': 1, 'max_model_len': 2048, 'trust_remote_code': True}), gen_kwargs: (temperature=0.0), limit: 500.0, num_fewshot: 5, batch_size: 200
|Tasks|Version|     Filter     |n-shot|  Metric   |   |Value|   |Stderr|
|-----|------:|----------------|-----:|-----------|---|----:|---|-----:|
|gsm8k|      3|flexible-extract|     5|exact_match|↑  |0.948|±  |0.0099|
|     |       |strict-match    |     5|exact_match|↑  |0.904|±  |0.0132|
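
The report headers above come from lm_eval with the vllm model type. Below is a sketch of how one of these runs could be reproduced through the lm-evaluation-harness Python API; the exact invocation is an assumption inferred from the report header, and the env-var toggles in the comment correspond to the second and third configurations.

```python
import lm_eval

# For the second and third runs above, set VLLM_USE_TRTLLM_ATTENTION=0 or
# VLLM_ATTENTION_BACKEND=FLASHINFER in the environment before this call.
results = lm_eval.simple_evaluate(
    model="vllm",
    model_args={
        "pretrained": "nvidia/Llama-3.3-70B-Instruct-FP8",
        "kv_cache_dtype": "fp8",
        "tensor_parallel_size": 1,
        "max_model_len": 2048,
        "trust_remote_code": True,
    },
    tasks=["gsm8k"],
    num_fewshot=5,
    limit=500,
    batch_size=200,
    gen_kwargs="temperature=0.0",
)
print(results["results"]["gsm8k"])
```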

Essential Elements of an Effective PR Description Checklist
  • The purpose of the PR, such as "Fix some issue (link existing issues this PR will resolve)".
  • The test plan, such as providing test command.
  • The test results, such as pasting the results comparison before and after, or e2e results
  • (Optional) The necessary documentation update, such as updating supported_models.md and examples for a new model.
  • (Optional) Release notes update. If your change is user facing, please update the release notes draft in the Google Doc.

@mergify mergify bot added the v1 label Sep 10, 2025
@elvischenv elvischenv force-pushed the elvischenv/refactor-trtllm-attn-kernel-selection branch from 3edef5b to 22676bc on September 10, 2025 19:13
@elvischenv elvischenv marked this pull request as ready for review September 10, 2025 19:13
@elvischenv elvischenv force-pushed the elvischenv/refactor-trtllm-attn-kernel-selection branch 2 times, most recently from cc549c2 to 33c052a on September 10, 2025 20:14
mergify bot commented Sep 11, 2025

This pull request has merge conflicts that must be resolved before it can be merged. Please rebase the PR, @elvischenv.

https://docs.github.com/en/pull-requests/collaborating-with-pull-requests/working-with-forks/syncing-a-fork

@mergify mergify bot added the needs-rebase label Sep 11, 2025
@elvischenv elvischenv force-pushed the elvischenv/refactor-trtllm-attn-kernel-selection branch from 33c052a to d59fe83 on September 11, 2025 01:51
@mergify mergify bot removed the needs-rebase label Sep 11, 2025
@elvischenv elvischenv force-pushed the elvischenv/refactor-trtllm-attn-kernel-selection branch from d59fe83 to a3055cb on September 11, 2025 02:06
@elvischenv (Contributor Author) commented:

cc @ProExpertProg @mgoin @pavanimajety for review, thanks!

@elvischenv elvischenv force-pushed the elvischenv/refactor-trtllm-attn-kernel-selection branch 3 times, most recently from 1c1f8d9 to 5a695fa on September 11, 2025 23:10
@elvischenv elvischenv force-pushed the elvischenv/refactor-trtllm-attn-kernel-selection branch from 5a695fa to f4a1426 on September 15, 2025 08:00
@ProExpertProg (Collaborator) left a comment

Looks good overall, just one nit

@ProExpertProg (Collaborator) left a comment

Sorry, two more notes. Does this mean that when has_sinks is enabled, attention+quant fusion won't work?



@functools.cache
def force_use_trtllm_attention() -> Optional[bool]:
Collaborator comment on the snippet above:

You shouldn't read envs in cached functions, please separate into force_use_trtllm_attention (uncached) and _force_use_trtllm_attention (cached, takes env var as input).
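
A minimal sketch of the suggested split, assuming VLLM_USE_TRTLLM_ATTENTION follows a 0/1 convention (the function names come from the review comment; the bodies are illustrative, not the merged code):

```python
import functools
import os
from typing import Optional


def force_use_trtllm_attention() -> Optional[bool]:
    """Uncached: read the env var on every call, then delegate to the cached helper."""
    return _force_use_trtllm_attention(os.environ.get("VLLM_USE_TRTLLM_ATTENTION"))


@functools.cache
def _force_use_trtllm_attention(env_value: Optional[str]) -> Optional[bool]:
    """Cached on the env value: None means auto-select, otherwise force on/off."""
    if env_value is None:
        return None
    return env_value == "1"
```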

@elvischenv (Contributor Author) replied:

Addressed the comments above.

> Does this mean that when has_sinks is enabled, attention+quant fusion won't work?

  • If kv=auto, everything works fine; it will use the TRTLLM BF16-qkv BF16-out kernel.
  • If kv=fp8, by default it will always quantize the query and use TRTLLM FP8-qkv.
    • In this case, we found some accuracy issues in the FP8-qkv BF16-out sinks kernel. That has been fixed in TRTLLM upstream and needs to be propagated to FlashInfer and vLLM. For now we just raise an error and suggest users use kv=auto.
    • There is a WAR for using the BF16-q FP8-kv BF16-out kernel by setting VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION, introduced by [flashinfer] [kernel] support for fp8 kv cache for trtllm prefill attention #24197.

Back to attention+quant fusion: AFAIR, only gpt-oss needs attention sinks, and we haven't started quantizing the attention output for it. If we want to enable attention+quant fusion for gpt-oss, we have to at least

  1. Quantize the gpt-oss model.
  2. Ensure the FP8-qkv FP8/NVFP4-out attention sinks kernels work and have good accuracy on gpt-oss.

So for now, we need to fix the FP8-qkv BF16-out sinks kernel first, and then verify the FP8-qkv FP8/NVFP4-out kernels if we need them.
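
For readers skimming the thread, a hypothetical summary of the sinks kernel choices described in the comment above (illustrative only, not vLLM code):

```python
def trtllm_sinks_kernel(kv_cache_dtype: str, disable_q_quant: bool) -> str:
    """Map the cases from the comment above to the kernel that gets used."""
    if kv_cache_dtype == "auto":
        # Everything works: BF16 query/KV in, BF16 out.
        return "BF16-qkv BF16-out"
    if kv_cache_dtype == "fp8" and disable_q_quant:
        # WAR via VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION (#24197).
        return "BF16-q FP8-kv BF16-out"
    # Default fp8 path would be FP8-qkv, but the FP8-qkv BF16-out sinks kernel
    # currently has accuracy issues, so an error is raised instead.
    raise ValueError("FP8-qkv sinks kernel has known accuracy issues; use kv_cache_dtype=auto")
```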

@elvischenv elvischenv force-pushed the elvischenv/refactor-trtllm-attn-kernel-selection branch from f4a1426 to 0f48027 on September 15, 2025 17:52
@elvischenv elvischenv force-pushed the elvischenv/refactor-trtllm-attn-kernel-selection branch from 0f48027 to 8ec62d7 on September 16, 2025 12:20
@elvischenv elvischenv requested a review from mgoin as a code owner September 16, 2025 12:20
@mgoin mgoin added the bug (Something isn't working) and ready (ONLY add when PR is ready to merge/full CI is needed) labels Sep 17, 2025
@mgoin mgoin enabled auto-merge (squash) September 17, 2025 21:41
@vllm-bot vllm-bot merged commit e67a79d into vllm-project:main Sep 17, 2025
43 of 45 checks passed
@elvischenv elvischenv deleted the elvischenv/refactor-trtllm-attn-kernel-selection branch September 18, 2025 01:13
debroy-rh pushed a commit to debroy-rh/vllm that referenced this pull request Sep 19, 2025
…llm-project#24600)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
…llm-project#24600)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
charlifu pushed a commit to ROCm/vllm that referenced this pull request Sep 25, 2025
…llm-project#24600)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: charlifu <charlifu@amd.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
…llm-project#24600)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
choprahetarth pushed a commit to Tandemn-Labs/vllm that referenced this pull request Oct 11, 2025
…llm-project#24600)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025
…llm-project#24600)

Signed-off-by: elvischenv <219235043+elvischenv@users.noreply.github.com>
Co-authored-by: Michael Goin <mgoin64@gmail.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>

Labels

bug (Something isn't working) · ready (ONLY add when PR is ready to merge/full CI is needed) · v1

5 participants