
Commit b30425b

rework
Signed-off-by: Jonas Kuebler <kuebj@amazon.com>
1 parent 296050a commit b30425b

File tree: 3 files changed (+27, -24 lines)


vllm/attention/backends/abstract.py

Lines changed: 6 additions & 0 deletions

@@ -34,6 +34,12 @@ class AttentionBackend(ABC):
     # makes sure the output tensor is allocated inside the cudagraph.
     accept_output_buffer: bool = False
 
+    # Whether this backend supports receiving pre-quantized query input.
+    # If True, the attention layer will handle query quantization instead
+    # of the backend, allowing torch.compile to fuse quantization with
+    # previous operations.
+    supports_quant_query_input: bool = False
+
     @staticmethod
     @abstractmethod
     def get_name() -> str:
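To make the new contract concrete, here is a minimal sketch (not part of the commit; the helper name is illustrative) of the gate this flag enables: the attention layer only takes over query quantization when the KV cache uses an fp8 dtype and the selected backend has opted in, which is exactly the condition added to layer.py below.

# Illustrative sketch only -- this helper does not exist in vLLM.
def layer_handles_query_quant(kv_cache_dtype: str, backend_cls) -> bool:
    # Mirrors the check added to Attention.__init__ in layer.py below:
    # fp8 KV cache AND a backend that declares supports_quant_query_input.
    return (kv_cache_dtype.startswith("fp8")
            and getattr(backend_cls, "supports_quant_query_input", False))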

vllm/attention/layer.py

Lines changed: 19 additions & 16 deletions

@@ -22,7 +22,10 @@
 from vllm.model_executor.layers.linear import UnquantizedLinearMethod
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape)
 from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.platforms import _Backend, current_platform
 from vllm.utils import GiB_bytes, direct_register_custom_op
@@ -247,6 +250,13 @@ def __init__(
                 "This may be caused by insufficient memory to allocate "
                 "kv cache.") from e
 
+        # for attn backends supporting query quantization
+        self.query_quant = None
+        if self.kv_cache_dtype.startswith(
+                "fp8") and self.attn_backend.supports_quant_query_input:
+            self.query_quant = QuantFP8(static=True,
+                                        group_shape=GroupShape.PER_TENSOR)
+
     def forward(
         self,
         query: torch.Tensor,
@@ -270,6 +280,15 @@ def forward(
             attn_metadata = get_forward_context().attn_metadata
             if attn_metadata.enable_kv_scales_calculation:
                 self.calc_kv_scales(query, key, value)
+
+        if self.query_quant is not None:
+            # quantizing with a simple torch operation enables
+            # torch.compile to fuse this into previous ops
+            # which reduces overheads during decoding.
+            # Otherwise queries are quantized using custom ops
+            # which causes decoding overheads
+            query, _ = self.query_quant.forward_native(query, self._q_scale)
+
         if self.use_output:
             output_shape = (output_shape
                             if output_shape is not None else query.shape)
@@ -278,22 +297,6 @@ def forward(
                                  device=query.device)
             hidden_size = output_shape[-1]
 
-            if envs.VLLM_FUSE_QUERY_QUANT and self.kv_cache_dtype != "auto":
-                # quantizing with a simple torch operation enables
-                # torch.compile to fuse this into previous ops
-                # which reduces overheads during decoding.
-                # Otherwise queries are quantized using custom ops
-                # which causes decoding overheads
-                assert self._q_scale.numel() == 1
-                if self.kv_cache_dtype in ["fp8", "fp8_e4m3"]:
-                    query = (query / self._q_scale).to(torch.float8_e4m3fn)
-                elif self.kv_cache_dtype == "fp8_e5m2":
-                    query = (query / self._q_scale).to(torch.float8_e5m2)
-                else:
-                    raise NotImplementedError(
-                        "VLLM_FUSE_QUERY_QUANT only supported for fp8_e4m3 "
-                        "and fp8_e5m2")
-
             # We skip reshaping query, key and value tensors for the MLA
             # backend since these tensors have different semantics and are
             # processed differently.
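For intuition, below is a rough standalone sketch of the static per-tensor FP8 quantization that QuantFP8.forward_native performs here, assuming the e4m3 dtype as in the removed inline path; the real implementation also returns the scale and may differ in details. Because these are plain torch ops, torch.compile can fuse them into the preceding projection, whereas an opaque custom quantization op cannot be fused and adds per-step overhead during decode.

import torch

def quantize_query_per_tensor(query: torch.Tensor,
                              q_scale: torch.Tensor) -> torch.Tensor:
    # Rough sketch (assumed e4m3), mirroring the inline code removed from
    # forward() above; not the actual QuantFP8 implementation.
    assert q_scale.numel() == 1  # one static scale for the whole tensor
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    # Plain torch ops: divide by the scale, clamp to the representable
    # range, then cast. torch.compile can fuse this with earlier ops.
    scaled = (query / q_scale).clamp(min=-fp8_max, max=fp8_max)
    return scaled.to(torch.float8_e4m3fn)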

vllm/v1/attention/backends/flash_attn.py

Lines changed: 2 additions & 8 deletions

@@ -7,7 +7,6 @@
 import numpy as np
 import torch
 
-from vllm import _custom_ops as ops
 from vllm import envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
@@ -38,6 +37,7 @@
 class FlashAttentionBackend(AttentionBackend):
 
     accept_output_buffer: bool = True
+    supports_quant_query_input: bool = True
 
     @classmethod
     def get_supported_dtypes(cls) -> list[torch.dtype]:
@@ -506,17 +506,11 @@ def forward(
         )
 
         if self.kv_cache_dtype.startswith("fp8"):
+            # queries are quantized in the attention layer
            dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
                self.kv_cache_dtype)
            key_cache = key_cache.view(dtype)
            value_cache = value_cache.view(dtype)
-            if not envs.VLLM_FUSE_QUERY_QUANT:
-                num_tokens, num_heads, head_size = query.shape
-                query, _ = ops.scaled_fp8_quant(
-                    query.reshape(
-                        (num_tokens, num_heads * head_size)).contiguous(),
-                    layer._q_scale)
-                query = query.reshape((num_tokens, num_heads, head_size))
 
         if not attn_metadata.use_cascade:
             cu_seqlens_q = attn_metadata.query_start_loc
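For contrast, a hedged sketch of the backend-side path this diff removes versus what the backend now assumes; the function names are illustrative, and the custom-op call is copied from the deleted lines above.

import torch

def query_for_backend_old(query: torch.Tensor,
                          q_scale: torch.Tensor) -> torch.Tensor:
    # Removed path (sketch): quantize inside the FlashAttention backend via
    # a custom op. The reshape/contiguous and the opaque op run on every
    # decode step and cannot be fused by torch.compile.
    from vllm import _custom_ops as ops
    num_tokens, num_heads, head_size = query.shape
    query_2d = query.reshape((num_tokens, num_heads * head_size)).contiguous()
    query_fp8, _ = ops.scaled_fp8_quant(query_2d, q_scale)
    return query_fp8.reshape((num_tokens, num_heads, head_size))

def query_for_backend_new(query: torch.Tensor) -> torch.Tensor:
    # New path (sketch): when the KV cache is fp8, the attention layer has
    # already quantized the query via QuantFP8, so the backend only views
    # the KV cache as the matching fp8 dtype and passes the query through.
    return query

The design choice is to let the layer own query quantization whenever a backend declares supports_quant_query_input, so the quantization stays visible to torch.compile instead of hiding behind a backend-local custom op gated by VLLM_FUSE_QUERY_QUANT.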
