[NVIDIA] Support Flashinfer TRT-LLM Prefill Attention Kernel #22095
Conversation
This pull request has merge conflicts that must be resolved before it can be merged.
Force-pushed: 2a87f39 to 21aede2
Overall looks good to me. Thanks! @pavanimajety Please also help to review this.
Force-pushed: 21aede2 to 7237472
Is this still true? Please update the comment if more head group sizes are supported and change the logic for the head group ratio in use_trtllm_attention
Thanks. Fixed in the latest commit.
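For context, a minimal sketch of the kind of head-group-ratio gate discussed here, assuming a simplified use_trtllm_attention signature and an illustrative set of supported ratios (both are assumptions, not vLLM's actual code):

```python
# Hypothetical sketch of a head-group-ratio gate; signature and the supported
# ratio set are assumptions for illustration only.
SUPPORTED_HEAD_GROUP_RATIOS = (1, 2, 4, 8)  # assumed supported query/KV head ratios

def use_trtllm_attention(num_qo_heads: int, num_kv_heads: int) -> bool:
    # Reject configurations where query heads do not divide evenly into KV
    # head groups, or where the resulting ratio is not supported.
    if num_kv_heads <= 0 or num_qo_heads % num_kv_heads != 0:
        return False
    return (num_qo_heads // num_kv_heads) in SUPPORTED_HEAD_GROUP_RATIOS
```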
Have the cubins been updated to support both layouts? In that case, we may want to remove the default HND restriction placed for SM100
I think flashinfer still has this constraint. The unit test in flashinfer still tests HND only.
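A rough illustration of keeping the HND restriction on SM100; the helper name and layout strings are assumptions, not vLLM's actual API:

```python
# Illustrative sketch only: keep the HND KV-cache layout restriction on SM100
# while FlashInfer's TRT-LLM kernels are only validated with HND.
import torch

def resolve_kv_cache_layout(requested: str | None = None) -> str:
    major, _ = torch.cuda.get_device_capability()
    if major == 10:        # SM100 (Blackwell)
        return "HND"       # TRT-LLM cubins currently tested with HND only
    return requested or "NHD"
```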
We may have to return false from use_trtllm_attention when window_left is a non-default value as well.
Fixed in the latest commit.
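For reference, the suggested check amounts to something like the following (hypothetical standalone helper; in vLLM this logic belongs inside use_trtllm_attention):

```python
# window_left == -1 means "no sliding window"; any other value should disable
# the TRT-LLM prefill kernel so the caller falls back to the default path.
def sliding_window_ok_for_trtllm(window_left: int) -> bool:
    return window_left == -1
```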
Thank you for the PR, @elvischenv. Left some minor feedback comments.
@elvischenv Does this PR have full CUDA graph support?
@mgoin Can you please review too? Thanks!
Please update Blackwell Test in .buildkite/test-pipeline.yaml to include this
Thanks. Fixed in the latest commit.
Awesome work, this looks good to me. Will try to smoke test when I get access to B200 again
Force-pushed: 7237472 to 1918711
We have tested with full CUDA graph and it seems to work.
# currently prefill trtllm attention does not support fp8 kv cache
# trtllm may not support sliding window
prefill_use_trtllm = (self.global_hyperparameters.window_left == -1
                      and not cache_dtype.startswith("fp8")
Nit: I believe this is already checked in use_trtllm_attention
Do you mean not cache_dtype.startswith("fp8")? use_trtllm_attention can be overridden by VLLM_USE_TRTLLM_ATTENTION=1. With VLLM_USE_TRTLLM_ATTENTION=1, we still cannot use TRT-LLM with an FP8 KV cache, since there is no BF16-q FP8-kv kernel.
This is just a workaround (WAR) for now. Will update this after the FP8-q FP8-kv kernel is supported.
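A simplified sketch of the workaround described above, with assumed parameter names; the point is that the FP8 KV-cache check sits outside use_trtllm_attention, so the environment-variable override cannot bypass it:

```python
# Sketch of the workaround (parameter names are assumptions). Forcing the
# kernel with VLLM_USE_TRTLLM_ATTENTION=1 cannot select the unsupported
# BF16-q / FP8-kv prefill path because the FP8 check is applied separately.
def prefill_can_use_trtllm(window_left: int, cache_dtype: str,
                           trtllm_selected: bool) -> bool:
    # trtllm_selected stands in for the use_trtllm_attention(...) result,
    # which the environment variable may force to True.
    return (window_left == -1
            and not cache_dtype.startswith("fp8")
            and trtllm_selected)
```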
I would prefer that we clean these up after we have the Attn+FP8/FP4-Quant fusions. Things will be clearer when that part is done. Thanks!
@elvischenv @nvpohanh @pavanimajety do you have any update on support for the FP8-q FP8-kv kernel?
@frank-wei FP8-QKV and FP8/FP4-output are already supported. FP8-QKV and BF16/FP16-output will be supported after #23647 is merged.
Essential Elements of an Effective PR Description Checklist
… supported_models.md and examples for a new model.

Purpose
Previously, #19825 added support for the TRT-LLM attention kernel on the decode path; this PR aims to support the prefill path.
VLLM_USE_TRTLLM_ATTENTION to control …

Test Plan && Test Result
- tests/kernels/attention/test_flashinfer_trtllm_attention.py
- lm_eval
- benchmarks/kernels/benchmark_trtllm_prefill_attention.py

(Optional) Documentation Update
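For anyone trying this out, a hedged usage sketch: the model name is illustrative, and selecting the backend via VLLM_ATTENTION_BACKEND is an assumption; only VLLM_USE_TRTLLM_ATTENTION is named in this PR.

```python
# Hedged usage sketch for exercising the TRT-LLM prefill path on a Blackwell GPU.
import os

os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"  # assumed backend selector
os.environ["VLLM_USE_TRTLLM_ATTENTION"] = "1"        # force TRT-LLM attention kernels

from vllm import LLM, SamplingParams

llm = LLM(model="meta-llama/Llama-3.1-8B-Instruct")  # illustrative model
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```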