 from flashinfer.prefill import trtllm_batch_context_with_kv_cache
 from flashinfer.utils import FP4Tensor

-from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,
@@ -828,6 +827,12 @@ def fused_output_quant_supported(self, quant_key: QuantKey):
             and quant_key in (kFp8StaticTensorSym, kNvfp4Quant)
         )

+    def supports_quant_query_input(self) -> bool:
+        if flashinfer_disable_q_quantization():
+            return False
+
+        return self.support_trtllm_attn
+
     def forward(
         self,
         layer: torch.nn.Module,
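
For context, the new hook's first check is an operator-level escape hatch. A minimal sketch of forcing the unquantized query path, assuming `flashinfer_disable_q_quantization()` is backed by an environment variable following vLLM's usual `VLLM_*` naming (the exact name is an assumption; verify against `vllm/envs.py`):

```python
import os

# Assumed knob, named per vLLM's VLLM_* convention (not confirmed by this
# diff): when set, supports_quant_query_input() returns False and the
# backend receives the query in its original dtype.
os.environ["VLLM_FLASHINFER_DISABLE_Q_QUANTIZATION"] = "1"
```
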
@@ -859,6 +864,12 @@ def forward(
             # Profiling run.
             return output.fill_(0)

+        # Ensure query dtype matches the expected dtype from attention metadata
+        assert attn_metadata.q_data_type == query.dtype, (
+            f"Query dtype mismatch: expected {attn_metadata.q_data_type}, "
+            f"got {query.dtype}"
+        )
+
         if self.bmm1_scale is None:
             self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale

@@ -899,15 +910,6 @@ def forward(
         elif output.dtype == FP4_DTYPE:
             self.o_sf_scale = layer._o_scale_float

-        # Insert FP8 quant for query
-        if attn_metadata.q_data_type == FP8_DTYPE:
-            num_tokens, num_heads, head_size = query.shape
-            query, _ = ops.scaled_fp8_quant(
-                query.reshape((num_tokens, num_heads * head_size)).contiguous(),
-                layer._q_scale,
-            )
-            query = query.reshape((num_tokens, num_heads, head_size))
-
         # IMPORTANT!
         # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
         # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
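
Taken together, these hunks hoist FP8 query quantization out of `forward()`: callers first ask the backend via `supports_quant_query_input()`, quantize the query themselves, and the new assert turns a missed pre-quantization into a loud dtype error instead of silent per-call work inside the hot path the NOTE above warns about. A sketch of the caller side, reusing the exact logic of the removed block (the `quantize_query_for_backend` helper and the `impl`/`layer` arguments are illustrative, not part of this diff):

```python
import torch

from vllm import _custom_ops as ops


def quantize_query_for_backend(impl, layer, query: torch.Tensor) -> torch.Tensor:
    """Sketch: perform on the caller side the quantization this diff
    removes from forward(), gated by the backend's new capability hook."""
    if not impl.supports_quant_query_input():
        # Backend wants the original dtype (TRT-LLM attention unsupported,
        # or q-quantization disabled).
        return query

    # Same reshape -> scaled_fp8_quant -> reshape sequence as the removed
    # block, using the layer's static query scale.
    num_tokens, num_heads, head_size = query.shape
    query_fp8, _ = ops.scaled_fp8_quant(
        query.reshape((num_tokens, num_heads * head_size)).contiguous(),
        layer._q_scale,
    )
    return query_fp8.reshape((num_tokens, num_heads, head_size))
```

Moving the quantization upstream presumably also lets it be fused with preceding ops rather than re-run eagerly inside every attention call, in line with the CPU-overhead NOTE above.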