[Flashinfer] Support Flashinfer TRTLLM FP8-qkv BF16/FP16-out Attention Kernel #23647
Conversation
Btw, shall we try gpt-oss on GB200? AIME with high reasoning effort is quite useful for checking accuracy.
Looks good, let's wait for FlashInfer version to land
@ProExpertProg Thanks for the review. FlashInfer 0.3.0 has been updated on main in #24086.
Still suffering from the … issue. I got all the tests to pass locally when I built the full wheel from source using the Incremental Compilation Workflow, but not when I install vLLM from source using …. @ProExpertProg, do you have any suggestions for resolving this? It is weird that other PRs can pass these tests on CI. Thanks!
What happens if you build with …?
I could pass all the tests with this PR: …
EDIT: I had to upgrade FlashInfer. This PR broke …
@elvischenv On devices where FlashInfer uses the FA2 backend (e.g. SM120), an FP8 query is not supported. Defaulting to an FP8 query prevents me from enabling the FP8 KV cache on SM120. I see two possible solutions for this: …

Anyway, I'm more than happy to work with you on either solution. I have confirmed locally on my machine that simply setting …
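For illustration, here is a minimal sketch in plain PyTorch of the kind of gating this would need: quantize the query to FP8 only when the TRTLLM kernel will actually consume it, so devices on the FA2 backend (e.g. SM120) keep a BF16/FP16 query. The function name and the per-tensor scaling are assumptions for illustration, not the actual vLLM/FlashInfer code.

```python
import torch


def maybe_quantize_query(query: torch.Tensor, kv_cache_dtype: str,
                         use_trtllm_attention: bool) -> torch.Tensor:
    """Quantize the query to FP8 only when the TRTLLM kernel will consume it.

    On devices where FlashInfer falls back to the FA2 backend (e.g. SM120),
    use_trtllm_attention would be False and the query stays in BF16/FP16.
    """
    if kv_cache_dtype == "fp8" and use_trtllm_attention:
        # Simple per-tensor scale for illustration; the real kernel path may
        # use a different scheme, and the scale would have to be passed along
        # to the attention kernel as well.
        scale = query.abs().amax().clamp(min=1e-12) / torch.finfo(torch.float8_e4m3fn).max
        return (query / scale).to(torch.float8_e4m3fn)
    return query
```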
Purpose
Support Flashinfer TRTLLM FP8-qkv BF16/FP16-out Attention Kernel.
After this PR, FlashInfer + `kv_cache_dtype=fp8` will always quantize the query to FP8 and use the TRTLLM attention kernel (supports FP8-qkv BF16/FP16/FP8/NVFP4-out). Note: this requires FlashInfer 0.3.0 to land.
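For context, a minimal offline-inference sketch of the configuration this PR targets. It assumes vLLM's `LLM` API, its `kv_cache_dtype` argument, and the `VLLM_ATTENTION_BACKEND` environment variable behave as on current main; the model name is taken from the lm_eval runs below.

```python
import os

# Select the FlashInfer attention backend before vLLM is imported.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams

# With kv_cache_dtype="fp8", FlashInfer quantizes the query to FP8 and
# dispatches to the TRTLLM attention kernel with BF16/FP16 output.
llm = LLM(model="nvidia/Llama-3.3-70B-Instruct-FP8", kv_cache_dtype="fp8")
outputs = llm.generate(["The capital of France is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```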
Test Plan && Test Result
- Kernel functional: `tests/kernels/attention/test_flashinfer_trtllm_attention.py`
- Kernel performance:
  - Decode: `benchmarks/kernels/benchmark_trtllm_decode_attention.py`
  - Prefill: `benchmarks/kernels/benchmark_trtllm_prefill_attention.py`
- `kv_cache_dtype=fp8` + unfused/fused unit test: `tests/compile/test_fusion_attn.py::test_attention_quant_pattern` (invocation sketch after the results below)
- lm_eval
main:
nvidia/Llama-3.3-70B-Instruct-FP8
nvidia/Llama-3.3-70B-Instruct-FP4
PR:
nvidia/Llama-3.3-70B-Instruct-FP8
nvidia/Llama-3.3-70B-Instruct-FP4
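As referenced in the test plan, the kernel and fusion unit tests can be driven from Python via pytest's entry point (equivalent to the `pytest` CLI), assuming a source checkout that contains these paths.

```python
import pytest

# Kernel functional test for the TRTLLM FP8-qkv attention path.
pytest.main(["-q", "tests/kernels/attention/test_flashinfer_trtllm_attention.py"])

# kv_cache_dtype=fp8 + unfused/fused attention-quant fusion unit test.
pytest.main(["-q", "tests/compile/test_fusion_attn.py::test_attention_quant_pattern"])
```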
Essential Elements of an Effective PR Description Checklist
`supported_models.md` and `examples` for a new model.