 from vllm.model_executor.layers.linear import UnquantizedLinearMethod
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape)
 from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.platforms import _Backend, current_platform
 from vllm.utils import GiB_bytes, direct_register_custom_op
@@ -247,6 +250,13 @@ def __init__(
247250 "This may be caused by insufficient memory to allocate "
248251 "kv cache." ) from e
249252
253+ # for attn backends supporting query quantization
254+ self .query_quant = None
255+ if self .kv_cache_dtype .startswith (
256+ "fp8" ) and self .attn_backend .supports_quant_query_input :
257+ self .query_quant = QuantFP8 (static = True ,
258+ group_shape = GroupShape .PER_TENSOR )
259+
250260 def forward (
251261 self ,
252262 query : torch .Tensor ,
@@ -270,11 +280,22 @@ def forward(
             attn_metadata = get_forward_context().attn_metadata
             if attn_metadata.enable_kv_scales_calculation:
                 self.calc_kv_scales(query, key, value)
+
+        output_dtype = query.dtype
+        if self.query_quant is not None:
+            # quantizing with a simple torch operation enables
+            # torch.compile to fuse this into previous ops
+            # which reduces overheads during decoding.
+            # Otherwise queries are quantized using custom ops
+            # which causes decoding overheads
+            assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
+            query, _ = self.query_quant(query, self._q_scale)
+
         if self.use_output:
             output_shape = (output_shape
                             if output_shape is not None else query.shape)
             output = torch.zeros(output_shape,
-                                 dtype=query.dtype,
+                                 dtype=output_dtype,
                                  device=query.device)
             hidden_size = output_shape[-1]
             # We skip reshaping query, key and value tensors for the MLA
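Note on the added query quantization path (an illustrative sketch, not part of the change itself): `QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)` applies static per-tensor FP8 quantization to the query using the precomputed `self._q_scale`. A minimal approximation of what that call amounts to, written with plain torch ops and a hypothetical helper name; the actual QuantFP8 op may differ in rounding, clamping, and hardware-specific dtype handling:

import torch

def quantize_query_per_tensor(query: torch.Tensor,
                              q_scale: torch.Tensor) -> torch.Tensor:
    # Illustrative sketch only, not the real QuantFP8 implementation.
    finfo = torch.finfo(torch.float8_e4m3fn)
    # Divide by the precomputed static scale, clamp to the representable
    # FP8 range, then cast. Because these are plain torch ops,
    # torch.compile can fuse the quantization into the preceding
    # computation instead of launching a separate custom op.
    q = (query / q_scale).clamp(min=finfo.min, max=finfo.max)
    return q.to(torch.float8_e4m3fn)

Note also that `output_dtype` is captured before the quantization step, so the output buffer allocated under `self.use_output` keeps the original activation dtype even though `query` itself becomes an FP8 tensor.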