 from vllm.model_executor.layers.linear import UnquantizedLinearMethod
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape)
 from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.platforms import _Backend, current_platform
 from vllm.utils import GiB_bytes, direct_register_custom_op
@@ -247,6 +250,13 @@ def __init__(
                 "This may be caused by insufficient memory to allocate "
                 "kv cache.") from e

+        # for attn backends supporting query quantization
+        self.query_quant = None
+        if self.kv_cache_dtype.startswith(
+                "fp8") and self.attn_backend.supports_quant_query_input:
+            self.query_quant = QuantFP8(static=True,
+                                        group_shape=GroupShape.PER_TENSOR)
+
     def forward(
         self,
         query: torch.Tensor,
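For context, a minimal usage sketch of the new quantizer path, assuming `QuantFP8.forward_native(x, scale)` returns a `(quantized, scale)` tuple as the call in the forward hunk below suggests; the tensor shape, scale value, and device here are illustrative, not taken from the PR:

```python
import torch

from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
from vllm.model_executor.layers.quantization.utils.quant_utils import GroupShape

# Static per-tensor quantizer: a single pre-computed scale for the whole
# query tensor, mirroring how the layer pairs it with self._q_scale.
query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)

query = torch.randn(4, 32, 128, dtype=torch.bfloat16, device="cuda")  # illustrative shape
q_scale = torch.tensor(0.02, dtype=torch.float32, device="cuda")      # illustrative scale

# forward_native is built from plain torch ops, so torch.compile can fuse it
# into the preceding projection; the returned scale is discarded, as in the
# layer's forward().
query_fp8, _ = query_quant.forward_native(query, q_scale)
```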
@@ -270,6 +280,15 @@ def forward(
         attn_metadata = get_forward_context().attn_metadata
         if attn_metadata.enable_kv_scales_calculation:
             self.calc_kv_scales(query, key, value)
+
+        if self.query_quant is not None:
+            # quantizing with a simple torch operation enables
+            # torch.compile to fuse this into previous ops
+            # which reduces overheads during decoding.
+            # Otherwise queries are quantized using custom ops
+            # which causes decoding overheads
+            query, _ = self.query_quant.forward_native(query, self._q_scale)
+
         if self.use_output:
             output_shape = (output_shape
                             if output_shape is not None else query.shape)
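The fusion argument in the new comment is that per-tensor FP8 quantization is just an elementwise divide and cast. Below is a hedged sketch of that arithmetic; the env-gated block removed in the next hunk did the same thing inline, and the clamp to the fp8 range is an assumption about what the native QuantFP8 path does, not something stated in this diff:

```python
import torch

def quantize_query_per_tensor(query: torch.Tensor,
                              q_scale: torch.Tensor) -> torch.Tensor:
    """Rough equivalent of the fused native path for an e4m3 KV cache.

    Plain torch ops only, so torch.compile can fold this into the ops that
    produced `query` instead of launching a separate custom quantization op.
    """
    finfo = torch.finfo(torch.float8_e4m3fn)
    scaled = (query / q_scale).clamp(min=finfo.min, max=finfo.max)
    return scaled.to(torch.float8_e4m3fn)
```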
@@ -278,22 +297,6 @@ def forward(
                                  device=query.device)
             hidden_size = output_shape[-1]

-            if envs.VLLM_FUSE_QUERY_QUANT and self.kv_cache_dtype != "auto":
-                # quantizing with a simple torch operation enables
-                # torch.compile to fuse this into previous ops
-                # which reduces overheads during decoding.
-                # Otherwise queries are quantized using custom ops
-                # which causes decoding overheads
-                assert self._q_scale.numel() == 1
-                if self.kv_cache_dtype in ["fp8", "fp8_e4m3"]:
-                    query = (query / self._q_scale).to(torch.float8_e4m3fn)
-                elif self.kv_cache_dtype == "fp8_e5m2":
-                    query = (query / self._q_scale).to(torch.float8_e5m2)
-                else:
-                    raise NotImplementedError(
-                        "VLLM_FUSE_QUERY_QUANT only supported for fp8_e4m3 "
-                        "and fp8_e5m2")
-
             # We skip reshaping query, key and value tensors for the MLA
             # backend since these tensors have different semantics and are
             # processed differently.