
Commit b18a2c0

adabeyta authored and ProExpertProg committed
Move query quantization to attention layer for Flashinfer & Triton. (vllm-project#26534)
Signed-off-by: adabeyta <aabeyta@redhat.com>
Signed-off-by: Adrian Abeyta <aabeyta@redhat.com>
Co-authored-by: Luka Govedič <ProExpertProg@users.noreply.github.com>
Signed-off-by: Alberto Perdomo <aperdomo@redhat.com>
1 parent ef49b65 commit b18a2c0

6 files changed: +43 -38 lines changed

tests/compile/test_fusion_attn.py

Lines changed: 3 additions & 1 deletion

@@ -421,7 +421,9 @@ def test_attention_quant_pattern(
     ]
     if any(attn_fusion_supported):
         # Check quantization ops in the graph before and after fusion
-        test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=True)
+        # Note: fully_replaced=False because query quant ops remain in graph.
+        # Only output quant ops are fused into attention.
+        test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=False)

         # access the underlying `AttnFusionPass` on the `LazyInitPass`
         assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)

vllm/attention/backends/abstract.py

Lines changed: 16 additions & 8 deletions

@@ -41,14 +41,6 @@ class AttentionBackend(ABC):
     # makes sure the output tensor is allocated inside the cudagraph.
     accept_output_buffer: bool = False

-    # Whether this backend supports receiving pre-quantized query input.
-    # If True, the attention layer will handle query quantization instead
-    # of the backend, allowing torch.compile to fuse quantization with
-    # previous operations.
-    # Needs to be worked through for all backends
-    # https://github.com/vllm-project/vllm/issues/25584
-    supports_quant_query_input: bool = False
-
     @staticmethod
     @abstractmethod
     def get_name() -> str:

@@ -199,6 +191,22 @@ def fused_output_quant_supported(self, quant_key: QuantKey):
         """
         return False

+    def supports_quant_query_input(self) -> bool:
+        """
+        Check if this attention implementation supports pre-quantized query input.
+
+        When True, the attention layer will quantize queries before passing them
+        to this backend, allowing torch.compile to fuse the quantization with
+        previous operations. This is typically supported when using FP8 KV cache
+        with compatible attention kernels (e.g., TRT-LLM).
+        TODO: add support to more backends:
+        https://github.com/vllm-project/vllm/issues/25584
+
+        Returns:
+            bool: True if the implementation can accept pre-quantized queries.
+        """
+        return False
+

 class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
     @abstractmethod
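
For illustration, a concrete implementation opts in by overriding this hook. The sketch below is not part of the commit; `ToyFp8Impl` and its constructor flag are hypothetical stand-ins for a real `AttentionImpl` subclass:

class ToyFp8Impl:
    # Hypothetical stand-in for an AttentionImpl subclass (illustration only).
    def __init__(self, kernel_accepts_fp8_query: bool) -> None:
        self._kernel_accepts_fp8_query = kernel_accepts_fp8_query

    def supports_quant_query_input(self) -> bool:
        # Opt in only when the underlying kernel can consume FP8 queries
        # directly; the attention layer then quantizes the query before
        # calling forward() instead of the backend doing it internally.
        return self._kernel_accepts_fp8_query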

vllm/attention/layer.py

Lines changed: 6 additions & 3 deletions

@@ -36,6 +36,7 @@
 from vllm.platforms import current_platform
 from vllm.utils import GiB_bytes, direct_register_custom_op

+FP8_DTYPE = current_platform.fp8_dtype()
 logger = init_logger(__name__)
 USE_XFORMERS_OPS = None

@@ -304,7 +305,7 @@ def __init__(
         self.query_quant = None
         if (
             self.kv_cache_dtype.startswith("fp8")
-            and self.attn_backend.supports_quant_query_input
+            and self.impl.supports_quant_query_input()
         ):
             self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)

@@ -329,7 +330,6 @@ def forward(
         """
         if self.calculate_kv_scales:
             torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)
-
         output_dtype = query.dtype
         if self.query_quant is not None:
             # quantizing with a simple torch operation enables

@@ -338,7 +338,10 @@ def forward(
             # Otherwise queries are quantized using custom ops
            # which causes decoding overheads
             assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
-            query, _ = self.query_quant(query, self._q_scale)
+
+            # check if query quantization is supported
+            if self.impl.supports_quant_query_input():
+                query, _ = self.query_quant(query, self._q_scale)

         if self.use_output:
             output_shape = output_shape if output_shape is not None else query.shape
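
The "simple torch operation" referred to above is a static, per-tensor FP8 quantization (QuantFP8 with GroupShape.PER_TENSOR). A minimal sketch of what that amounts to in plain torch ops, assuming torch.float8_e4m3fn is available; this approximates QuantFP8's behavior and is not its source:

import torch

def per_tensor_fp8_quant_sketch(query: torch.Tensor, q_scale: torch.Tensor) -> torch.Tensor:
    # Divide by the static scale, clamp to the FP8 range, and cast.
    info = torch.finfo(torch.float8_e4m3fn)
    return (query / q_scale).clamp(info.min, info.max).to(torch.float8_e4m3fn)

Because these are ordinary tensor ops rather than a custom quant kernel, torch.compile can fuse them with the preceding operations, which is the decoding overhead the in-layer comment refers to.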

vllm/v1/attention/backends/flash_attn.py

Lines changed: 3 additions & 1 deletion

@@ -49,7 +49,6 @@

 class FlashAttentionBackend(AttentionBackend):
     accept_output_buffer: bool = True
-    supports_quant_query_input: bool = True

     @classmethod
     def get_supported_dtypes(cls) -> list[torch.dtype]:

@@ -494,6 +493,9 @@ def __init__(
                 "heads in the layer"
             )

+    def supports_quant_query_input(self) -> bool:
+        return True
+
     def forward(
         self,
         layer: torch.nn.Module,

vllm/v1/attention/backends/flashinfer.py

Lines changed: 12 additions & 10 deletions

@@ -16,7 +16,6 @@
 from flashinfer.prefill import trtllm_batch_context_with_kv_cache
 from flashinfer.utils import FP4Tensor

-from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (
     AttentionBackend,
     AttentionImpl,

@@ -828,6 +827,12 @@ def fused_output_quant_supported(self, quant_key: QuantKey):
             and quant_key in (kFp8StaticTensorSym, kNvfp4Quant)
         )

+    def supports_quant_query_input(self) -> bool:
+        if flashinfer_disable_q_quantization():
+            return False
+
+        return self.support_trtllm_attn
+
     def forward(
         self,
         layer: torch.nn.Module,

@@ -859,6 +864,12 @@ def forward(
             # Profiling run.
             return output.fill_(0)

+        # Ensure query dtype matches the expected dtype from attention metadata
+        assert attn_metadata.q_data_type == query.dtype, (
+            f"Query dtype mismatch: expected {attn_metadata.q_data_type}, "
+            f"got {query.dtype}"
+        )
+
         if self.bmm1_scale is None:
             self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale

@@ -899,15 +910,6 @@ def forward(
         elif output.dtype == FP4_DTYPE:
             self.o_sf_scale = layer._o_scale_float

-        # Insert FP8 quant for query
-        if attn_metadata.q_data_type == FP8_DTYPE:
-            num_tokens, num_heads, head_size = query.shape
-            query, _ = ops.scaled_fp8_quant(
-                query.reshape((num_tokens, num_heads * head_size)).contiguous(),
-                layer._q_scale,
-            )
-            query = query.reshape((num_tokens, num_heads, head_size))
-
         # IMPORTANT!
         # NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
         # eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
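
Why dropping the in-backend flatten -> ops.scaled_fp8_quant -> reshape path is safe: with a single per-tensor scale, quantization is element-wise, so the reshape has no effect on the values. A self-contained sketch of that equivalence, using a plain-torch stand-in for the FP8 quant (assumes torch.float8_e4m3fn support; not vLLM's ops.scaled_fp8_quant):

import torch

def quant(x: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    # Per-tensor FP8 quant stand-in: scale, clamp, cast.
    info = torch.finfo(torch.float8_e4m3fn)
    return (x / scale).clamp(info.min, info.max).to(torch.float8_e4m3fn)

num_tokens, num_heads, head_size = 4, 8, 64
query = torch.randn(num_tokens, num_heads, head_size)
scale = torch.tensor(0.5)

# Old in-backend path: flatten, quantize, reshape back.
old_style = quant(query.reshape(num_tokens, num_heads * head_size), scale)
old_style = old_style.reshape(num_tokens, num_heads, head_size)
# New path: the layer quantizes the 3-D query directly.
new_style = quant(query, scale)

# Bitwise-identical FP8 payloads (compared as uint8, since float8 equality
# support varies across torch versions).
assert torch.equal(old_style.view(torch.uint8), new_style.view(torch.uint8))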

vllm/v1/attention/backends/triton_attn.py

Lines changed: 3 additions & 15 deletions

@@ -32,11 +32,6 @@
 )
 from vllm.v1.kv_cache_interface import AttentionSpec

-if current_platform.is_cuda_alike():
-    from vllm import _custom_ops as ops
-elif current_platform.is_xpu():
-    from vllm._ipex_ops import ipex_ops as ops
-
 logger = init_logger(__name__)

@@ -210,6 +205,9 @@ class TritonAttentionImpl(AttentionImpl):
     def fused_output_quant_supported(self, quant_key: QuantKey):
         return quant_key == kFp8StaticTensorSym

+    def supports_quant_query_input(self) -> bool:
+        return current_platform.is_cuda()
+
     def __init__(
         self,
         num_heads: int,

@@ -338,19 +336,9 @@ def forward(
             if key_cache.dtype != self.fp8_dtype:
                 key_cache = key_cache.view(self.fp8_dtype)
                 value_cache = value_cache.view(self.fp8_dtype)
-            num_tokens, num_heads, head_size = query.shape
             assert layer._q_scale_float == 1.0, (
                 "A non 1.0 q_scale is not currently supported."
             )
-            if current_platform.is_cuda():
-                # Skip Q quantization on ROCm and XPU, enable this on cuda
-                # only, since dequantizing back to f32 in the attention kernel
-                # is not supported.
-                query, _ = ops.scaled_fp8_quant(
-                    query.reshape((num_tokens, num_heads * head_size)).contiguous(),
-                    layer._q_scale,
-                )
-                query = query.reshape((num_tokens, num_heads, head_size))

         cu_seqlens_q = attn_metadata.query_start_loc
         seqused_k = attn_metadata.seq_lens
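
The removed branch's comment carries the rationale for the new hook: the Triton kernel on ROCm and XPU cannot dequantize an FP8 query, so query quantization is requested from the layer only on CUDA. A trivial sketch of that gate, with `is_cuda` standing in for `current_platform.is_cuda()`:

def triton_supports_quant_query_input(is_cuda: bool) -> bool:
    # FP8 queries are only accepted on CUDA; on ROCm/XPU the layer keeps
    # passing unquantized queries.
    return is_cuda

Note that the surviving assert layer._q_scale_float == 1.0 means the layer-side quantization for this backend is effectively an FP8 cast with a unit scale.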
