Move query quantization to attention layer for Flashinfer & Triton. #26534
Conversation
This pull request has merge conflicts that must be resolved before it can be merged.
Code Review
This pull request refactors the query quantization logic for the Flashinfer and Triton attention backends, moving it from the backend implementation to the higher-level attention layer. This is a positive change for code structure and enables potential compiler fusions. While the changes for the Flashinfer backend appear correct, the removal of a critical assertion for the Triton backend is concerning. This assertion enforced that the query quantization scale must be 1.0, a limitation of the Triton kernel. Its removal could lead to silent correctness issues if not handled in the new quantization logic. I have added a critical review comment to highlight this potential issue.
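For context, here is a minimal sketch (illustrative only; function and parameter names are assumed, not vLLM's actual API) of the kind of guard the review is asking to preserve once quantization moves up to the attention layer:

```python
import torch

# Hypothetical guard: keep the Triton kernel's "q_scale must be 1.0" limitation
# enforced at the new quantization site so it is not silently dropped.
def quantize_query_for_backend(
    query: torch.Tensor, q_scale: float, backend_allows_scaled_query: bool
) -> torch.Tensor:
    if not backend_allows_scaled_query and q_scale != 1.0:
        raise ValueError("A non 1.0 q_scale is not currently supported.")
    # Scale and cast the query to FP8 for the quantized attention path.
    return (query / q_scale).to(torch.float8_e4m3fn)
```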
💡 Codex Review
Here are some automated review suggestions for this pull request.
Force-pushed from 0316471 to 5cc9707
Signed-off-by: adabeyta <aabeyta@redhat.com>
Force-pushed from 5c22f29 to bea52f6
This pull request has merge conflicts that must be resolved before it can be merged.
…input dynamic for FlashInfer Signed-off-by: adabeyta <aabeyta@redhat.com>
@@ -157,6 +144,11 @@ def trtllm_prefill_attn_kvfp8_dequant(
class FlashInferBackend(AttentionBackend):
    accept_output_buffer: bool = True

    @property
    def supports_quant_query_input(self) -> bool:
        return supports_trtllm_attention(
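For readers skimming the diff excerpt above, this is the pattern the hunk introduces: the backend advertises whether it can accept an already-quantized query. A simplified, self-contained sketch (class names other than supports_quant_query_input are assumed):

```python
# Simplified capability-flag sketch; only supports_quant_query_input mirrors the PR.
class BackendSketch:
    """Default: the backend does not accept a pre-quantized (FP8) query."""

    @property
    def supports_quant_query_input(self) -> bool:
        return False


class FlashInferBackendSketch(BackendSketch):
    """FlashInfer-style backend: reports support when its TRT-LLM attention
    path is usable, so the attention layer may quantize the query up front."""

    @property
    def supports_quant_query_input(self) -> bool:
        # The real code consults supports_trtllm_attention(...); return True
        # here purely for illustration.
        return True
```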
You may need to rebase or merge main and resolve the import issue
Signed-off-by: Adrian Abeyta <aabeyta@redhat.com>
vllm/attention/layer.py (Outdated)
    # which causes decoding overheads
    assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
    query, _ = self.query_quant(query, self._q_scale)
    if not hasattr(
I don't think this will work; attention metadata is not set during the profile run when we compile. Instead, we should have a more robust way of checking, likely by calling supports_quant_query_input on the AttentionImpl object
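A minimal sketch of the distinction being drawn (names are illustrative; this is not the PR's code): gate query quantization on a static property of the backend implementation rather than on per-step attention metadata, which is not populated during the profiling run under torch.compile:

```python
# Illustrative only: decide whether the layer should quantize the query.
def should_quantize_query(impl, attn_metadata) -> bool:
    # Fragile: attn_metadata is None during the compile/profile run, so any
    # hasattr()-style check on it cannot be relied upon there.
    #   return hasattr(attn_metadata, "q_data_type")
    # More robust: a static capability flag on the AttentionImpl object.
    return getattr(impl, "supports_quant_query_input", False)
```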
    query = query.reshape((num_tokens, num_heads, head_size))
        "A non 1.0 q_scale is not currently supported.")

    # Query quantization is now handled in the attention layer
No need for this comment; just remove it.
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Signed-off-by: Adrian Abeyta <aabeyta@redhat.com>
@adabeyta Any analysis on why we are seeing lower toks/sec with enhanced fusion? Even without a custom kernel, the fact that RoPE + quant can be JITed into a Triton kernel should give us slightly higher perf, correct?
Signed-off-by: adabeyta <aabeyta@redhat.com>
I see now we actually lost performance with this; we should make sure we gain and not lose performance.
Wow, those are some insane numbers... good work!
@pavanimajety Updated with new perf numbers. We're seeing better throughput now.
Great work, thanks for the update!
…llm-project#26534) Signed-off-by: adabeyta <aabeyta@redhat.com> Signed-off-by: Adrian Abeyta <aabeyta@redhat.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
…llm-project#26534) Signed-off-by: adabeyta <aabeyta@redhat.com> Signed-off-by: Adrian Abeyta <aabeyta@redhat.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Signed-off-by: Alberto Perdomo <aperdomo@redhat.com>
…llm-project#26534) Signed-off-by: adabeyta <aabeyta@redhat.com> Signed-off-by: Adrian Abeyta <aabeyta@redhat.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Signed-off-by: xuebwang-amd <xuebwang@amd.com>
…llm-project#26534) Signed-off-by: adabeyta <aabeyta@redhat.com> Signed-off-by: Adrian Abeyta <aabeyta@redhat.com> Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com> Signed-off-by: 0xrushi <6279035+0xrushi@users.noreply.github.com>
Purpose
Moves query quantization from the backend implementations into the attention layer for the Triton and FlashInfer backends; resolves feature request #25584.
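In rough terms (a simplified sketch with assumed names, not the actual diff), the refactor moves the FP8 query quantization out of each backend's forward and into the shared attention layer, next to the other pre-attention ops where the compiler can see and fuse it:

```python
import torch

def fp8_quantize(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Stand-in for vLLM's quant op; illustration only.
    return (x / scale).to(torch.float8_e4m3fn)

# Before (sketch): each backend quantized the query inside its own forward.
def backend_forward_before(query: torch.Tensor, q_scale: float) -> torch.Tensor:
    q = fp8_quantize(query, q_scale)
    # ... run the attention kernel on q ...
    return q

# After (sketch): the layer quantizes once, only when the backend reports
# support, and the backend receives an already-quantized query.
def layer_forward_after(query: torch.Tensor, q_scale: float, impl) -> torch.Tensor:
    if getattr(impl, "supports_quant_query_input", False):
        query = fp8_quantize(query, q_scale)
    # ... impl.forward(query, ...) ...
    return query
```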
Test Plan
Spin up server:
Flashinfer:
Triton:
Benchmark:
Accuracy
To ensure there is no accidental accuracy degradation, we also run the following for FlashInfer & Triton with kv_cache_dtype in {auto, fp8}, both on this PR and on mainline. We also run without enforce_eager=True for the FP8 variants.
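The exact evaluation commands are not reproduced here; as one way to run a GSM8K check against a served model, here is a sketch using lm-evaluation-harness's Python entry point (assumed tooling and placeholder arguments, not necessarily what was used for the numbers reported below):

```python
# Sketch only: evaluate GSM8K against an OpenAI-compatible vLLM server.
import lm_eval

results = lm_eval.simple_evaluate(
    model="local-completions",
    model_args=(
        "model=<served-model-name>,"            # placeholder
        "base_url=http://localhost:8000/v1/completions,"
        "tokenizer=<hf-tokenizer-name>"         # placeholder
    ),
    tasks=["gsm8k"],
    num_fewshot=5,
)
print(results["results"]["gsm8k"])
```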
Test Results
Accuracy on GSM8k
This PR shows no accuracy variation compared to mainline vLLM.