
Commit fac11fa

jmkuebler authored and xuebwang-amd committed
[torch.compile] Make Query Quantization Fusable (vllm-project#24914)
Signed-off-by: Jonas Kuebler <kuebj@amazon.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent e1ef2a0 commit fac11fa

3 files changed: +32 −8 lines changed

vllm/attention/backends/abstract.py

Lines changed: 8 additions & 0 deletions
@@ -31,6 +31,14 @@ class AttentionBackend(ABC):
     # makes sure the output tensor is allocated inside the cudagraph.
     accept_output_buffer: bool = False
 
+    # Whether this backend supports receiving pre-quantized query input.
+    # If True, the attention layer will handle query quantization instead
+    # of the backend, allowing torch.compile to fuse quantization with
+    # previous operations.
+    # Needs to be worked through for all backends:
+    # https://github.com/vllm-project/vllm/issues/25584
+    supports_quant_query_input: bool = False
+
     @staticmethod
     @abstractmethod
     def get_name() -> str:
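For illustration only (not part of this commit), a minimal sketch of how a hypothetical backend would opt in to the new flag; FlashAttentionBackend does exactly this in the third file below.

from vllm.attention.backends.abstract import AttentionBackend


class MyFp8Backend(AttentionBackend):
    # Hypothetical backend, for illustration. Setting the flag tells the
    # attention layer to quantize the query itself with plain torch ops
    # (so torch.compile can fuse it), and the backend must therefore not
    # quantize the query again in its own forward pass.
    accept_output_buffer: bool = True
    supports_quant_query_input: bool = True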

vllm/attention/layer.py

Lines changed: 22 additions & 1 deletion
@@ -22,7 +22,10 @@
 from vllm.model_executor.layers.linear import UnquantizedLinearMethod
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
+from vllm.model_executor.layers.quantization.input_quant_fp8 import QuantFP8
 from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod
+from vllm.model_executor.layers.quantization.utils.quant_utils import (
+    GroupShape)
 from vllm.model_executor.models.vision import get_vit_attn_backend
 from vllm.platforms import _Backend, current_platform
 from vllm.utils import GiB_bytes, direct_register_custom_op
@@ -247,6 +250,13 @@ def __init__(
                 "This may be caused by insufficient memory to allocate "
                 "kv cache.") from e
 
+        # for attn backends supporting query quantization
+        self.query_quant = None
+        if self.kv_cache_dtype.startswith(
+                "fp8") and self.attn_backend.supports_quant_query_input:
+            self.query_quant = QuantFP8(static=True,
+                                        group_shape=GroupShape.PER_TENSOR)
+
     def forward(
         self,
         query: torch.Tensor,
@@ -270,11 +280,22 @@ def forward(
             attn_metadata = get_forward_context().attn_metadata
             if attn_metadata.enable_kv_scales_calculation:
                 self.calc_kv_scales(query, key, value)
+
+            output_dtype = query.dtype
+            if self.query_quant is not None:
+                # Quantizing with a simple torch operation enables
+                # torch.compile to fuse this into previous ops,
+                # which reduces overheads during decoding.
+                # Otherwise queries are quantized using custom ops,
+                # which causes decoding overheads.
+                assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
+                query, _ = self.query_quant(query, self._q_scale)
+
             if self.use_output:
                 output_shape = (output_shape
                                 if output_shape is not None else query.shape)
                 output = torch.zeros(output_shape,
-                                     dtype=query.dtype,
+                                     dtype=output_dtype,
                                      device=query.device)
                 hidden_size = output_shape[-1]
                 # We skip reshaping query, key and value tensors for the MLA
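To see why this helps, here is a rough standalone sketch of what static per-tensor FP8 quantization boils down to when written with plain torch ops (the function name and the e4m3 dtype are illustrative; the real implementation is QuantFP8 in vllm/model_executor/layers/quantization/input_quant_fp8.py). Because it is ordinary tensor math, torch.compile can fuse it into the ops that produced the query instead of launching a separate custom kernel, as the deleted FlashAttention code below did.

import torch


def per_tensor_fp8_quant(x: torch.Tensor, scale: torch.Tensor):
    # Static per-tensor quantization expressed as ordinary torch ops,
    # which torch.compile can fuse with whatever produced x.
    finfo = torch.finfo(torch.float8_e4m3fn)
    q = (x / scale).clamp(min=finfo.min, max=finfo.max)
    return q.to(torch.float8_e4m3fn), scale


fused = torch.compile(per_tensor_fp8_quant)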

vllm/v1/attention/backends/flash_attn.py

Lines changed: 2 additions & 7 deletions
@@ -7,7 +7,6 @@
 import numpy as np
 import torch
 
-from vllm import _custom_ops as ops
 from vllm import envs
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                               AttentionMetadata, AttentionType,
@@ -38,6 +37,7 @@
 class FlashAttentionBackend(AttentionBackend):
 
     accept_output_buffer: bool = True
+    supports_quant_query_input: bool = True
 
     @classmethod
     def get_supported_dtypes(cls) -> list[torch.dtype]:
@@ -506,16 +506,11 @@ def forward(
         )
 
         if self.kv_cache_dtype.startswith("fp8"):
+            # queries are quantized in the attention layer
             dtype = FlashAttentionBackend.get_fp8_dtype_for_flashattn(
                 self.kv_cache_dtype)
             key_cache = key_cache.view(dtype)
             value_cache = value_cache.view(dtype)
-            num_tokens, num_heads, head_size = query.shape
-            query, _ = ops.scaled_fp8_quant(
-                query.reshape(
-                    (num_tokens, num_heads * head_size)).contiguous(),
-                layer._q_scale)
-            query = query.reshape((num_tokens, num_heads, head_size))
 
         if not attn_metadata.use_cascade:
             cu_seqlens_q = attn_metadata.query_start_loc
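As a quick sanity check (a sketch, assuming a vllm build that includes this commit): the base class still defaults the flag to False, while FlashAttention now advertises it.

from vllm.attention.backends.abstract import AttentionBackend
from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend

print(AttentionBackend.supports_quant_query_input)       # False (default)
print(FlashAttentionBackend.supports_quant_query_input)  # True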
