4 changes: 3 additions & 1 deletion tests/compile/test_fusion_attn.py
@@ -421,7 +421,9 @@ def test_attention_quant_pattern(
]
if any(attn_fusion_supported):
# Check quantization ops in the graph before and after fusion
test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=True)
# Note: fully_replaced=False because query quant ops remain in graph.
# Only output quant ops are fused into attention.
test_backend.check_before_ops([QUANT_OPS[quant_key]], fully_replaced=False)

# access the underlying `AttnFusionPass` on the `LazyInitPass`
assert attn_pass.pass_.matched_count == sum(attn_fusion_supported)
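
A purely conceptual sketch of why `fully_replaced=False` is now the right expectation (toy node names, not the real vLLM pattern matcher): after fusion, the output-quant node is folded into the attention op, while the query-quant node emitted by the attention layer survives as a separate node in the graph.

```python
# Hedged, conceptual sketch only: illustrates the test expectation,
# not the actual fusion pass or graph representation.
def toy_graph_before():
    return ["quant(query)", "attention", "quant(output)"]


def toy_graph_after_fusion():
    # The output quant is fused into the attention op;
    # the query quant inserted by the attention layer remains.
    return ["quant(query)", "attention+output_quant"]


quant_nodes_left = sum("quant(" in node for node in toy_graph_after_fusion())
assert quant_nodes_left == 1  # hence fully_replaced=False in the test
```
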
24 changes: 16 additions & 8 deletions vllm/attention/backends/abstract.py
@@ -41,14 +41,6 @@ class AttentionBackend(ABC):
# makes sure the output tensor is allocated inside the cudagraph.
accept_output_buffer: bool = False

# Whether this backend supports receiving pre-quantized query input.
# If True, the attention layer will handle query quantization instead
# of the backend, allowing torch.compile to fuse quantization with
# previous operations.
# Needs to be worked through for all backends
# https://github.com/vllm-project/vllm/issues/25584
supports_quant_query_input: bool = False

@staticmethod
@abstractmethod
def get_name() -> str:
@@ -199,6 +191,22 @@ def fused_output_quant_supported(self, quant_key: QuantKey):
"""
return False

def supports_quant_query_input(self) -> bool:
"""
Check if this attention implementation supports pre-quantized query input.

When True, the attention layer will quantize queries before passing them
to this backend, allowing torch.compile to fuse the quantization with
previous operations. This is typically supported when using FP8 KV cache
with compatible attention kernels (e.g., TRT-LLM).
TODO: add support for more backends:
https://github.com/vllm-project/vllm/issues/25584

Returns:
bool: True if the implementation can accept pre-quantized queries.
"""
return False


class MLAAttentionImpl(AttentionImpl[T], Generic[T]):
@abstractmethod
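
As a hedged illustration of the new hook (the class below is hypothetical; only `AttentionImpl` and the method name come from this diff), a backend opts in by overriding the method, while the base-class default of `False` keeps unconverted backends on the old path:

```python
# Illustrative only: a hypothetical backend opting into layer-side query
# quantization via the new AttentionImpl hook added in this PR.
from vllm.attention.backends.abstract import AttentionImpl


class MyFP8FriendlyImpl(AttentionImpl):
    # ... __init__/forward and the rest of the AttentionImpl API omitted ...

    def supports_quant_query_input(self) -> bool:
        # Return True only when the kernel can consume an FP8 query
        # directly (e.g. FP8 KV cache with a TRT-LLM-style kernel).
        return True
```

The attention layer only consults this hook when the KV cache dtype is an fp8 variant, so returning True has no effect for unquantized caches.
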
9 changes: 6 additions & 3 deletions vllm/attention/layer.py
@@ -36,6 +36,7 @@
from vllm.platforms import current_platform
from vllm.utils import GiB_bytes, direct_register_custom_op

FP8_DTYPE = current_platform.fp8_dtype()
logger = init_logger(__name__)
USE_XFORMERS_OPS = None

@@ -304,7 +305,7 @@ def __init__(
self.query_quant = None
if (
self.kv_cache_dtype.startswith("fp8")
and self.attn_backend.supports_quant_query_input
and self.impl.supports_quant_query_input()
):
self.query_quant = QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR)

@@ -329,7 +330,6 @@ def forward(
"""
if self.calculate_kv_scales:
torch.ops.vllm.maybe_calc_kv_scales(query, key, value, self.layer_name)

output_dtype = query.dtype
if self.query_quant is not None:
# quantizing with a simple torch operation enables
@@ -338,7 +338,10 @@
# Otherwise queries are quantized using custom ops
# which causes decoding overheads
assert self.kv_cache_dtype in {"fp8", "fp8_e4m3"}
query, _ = self.query_quant(query, self._q_scale)

# check if query quantization is supported
if self.impl.supports_quant_query_input():
query, _ = self.query_quant(query, self._q_scale)

if self.use_output:
output_shape = output_shape if output_shape is not None else query.shape
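
A rough, self-contained sketch of the layer-side behaviour this diff introduces (plain torch stands in for `QuantFP8`; the helper name and shapes are made up): the query is quantized with simple torch ops, and only when the implementation opts in, so `torch.compile` can fuse the quantization with preceding operations instead of each backend calling a custom quant kernel.

```python
import torch

FP8_DTYPE = torch.float8_e4m3fn  # stand-in for current_platform.fp8_dtype()


def maybe_quantize_query(
    query: torch.Tensor,
    q_scale: torch.Tensor,
    kv_cache_dtype: str,
    impl_supports_quant_query: bool,
) -> torch.Tensor:
    """Simplified stand-in for the query-quant branch in Attention.forward."""
    if not (kv_cache_dtype.startswith("fp8") and impl_supports_quant_query):
        return query  # backend receives the original-dtype query
    # Static per-tensor scaling, analogous to
    # QuantFP8(static=True, group_shape=GroupShape.PER_TENSOR).
    finfo = torch.finfo(FP8_DTYPE)
    scaled = (query.float() / q_scale).clamp(finfo.min, finfo.max)
    return scaled.to(FP8_DTYPE)


# Example: with an FP8 KV cache and an opted-in impl, the backend sees FP8.
q = torch.randn(4, 8, 64, dtype=torch.bfloat16)
assert maybe_quantize_query(q, torch.tensor(1.0), "fp8_e4m3", True).dtype == FP8_DTYPE
assert maybe_quantize_query(q, torch.tensor(1.0), "auto", True).dtype == torch.bfloat16
```
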
4 changes: 3 additions & 1 deletion vllm/v1/attention/backends/flash_attn.py
@@ -49,7 +49,6 @@

class FlashAttentionBackend(AttentionBackend):
accept_output_buffer: bool = True
supports_quant_query_input: bool = True

@classmethod
def get_supported_dtypes(cls) -> list[torch.dtype]:
@@ -494,6 +493,9 @@ def __init__(
"heads in the layer"
)

def supports_quant_query_input(self) -> bool:
return True

def forward(
self,
layer: torch.nn.Module,
22 changes: 12 additions & 10 deletions vllm/v1/attention/backends/flashinfer.py
@@ -16,7 +16,6 @@
from flashinfer.prefill import trtllm_batch_context_with_kv_cache
from flashinfer.utils import FP4Tensor

from vllm import _custom_ops as ops
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
@@ -828,6 +827,12 @@ def fused_output_quant_supported(self, quant_key: QuantKey):
and quant_key in (kFp8StaticTensorSym, kNvfp4Quant)
)

def supports_quant_query_input(self) -> bool:
if flashinfer_disable_q_quantization():
return False

return self.support_trtllm_attn

def forward(
self,
layer: torch.nn.Module,
@@ -859,6 +864,12 @@ def forward(
# Profiling run.
return output.fill_(0)

# Ensure query dtype matches the expected dtype from attention metadata
assert attn_metadata.q_data_type == query.dtype, (
f"Query dtype mismatch: expected {attn_metadata.q_data_type}, "
f"got {query.dtype}"
)

if self.bmm1_scale is None:
self.bmm1_scale = layer._q_scale_float * layer._k_scale_float * self.scale

@@ -899,15 +910,6 @@ def forward(
elif output.dtype == FP4_DTYPE:
self.o_sf_scale = layer._o_scale_float

# Insert FP8 quant for query
if attn_metadata.q_data_type == FP8_DTYPE:
num_tokens, num_heads, head_size = query.shape
query, _ = ops.scaled_fp8_quant(
query.reshape((num_tokens, num_heads * head_size)).contiguous(),
layer._q_scale,
)
query = query.reshape((num_tokens, num_heads, head_size))

# IMPORTANT!
# NOTE(woosuk): With piece-wise CUDA graphs, this method is executed in
# eager-mode PyTorch. Thus, we need to be careful about any CPU overhead
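
The FlashInfer changes pair an opt-in with a guard; here is a hedged sketch of that contract as free functions (the real logic lives on the FlashInfer impl, and `flashinfer_disable_q_quantization` is only referenced here, not reimplemented):

```python
import torch


def supports_quant_query_input(support_trtllm_attn: bool,
                               q_quant_disabled: bool) -> bool:
    # Pre-quantized queries are only valid on the TRT-LLM attention path,
    # and only if query quantization has not been explicitly disabled.
    return support_trtllm_attn and not q_quant_disabled


def check_query_dtype(query: torch.Tensor, planned_dtype: torch.dtype) -> None:
    # Mirrors the new assertion in forward(): the layer-side quantization
    # must produce exactly the dtype the attention metadata planned for.
    assert query.dtype == planned_dtype, (
        f"Query dtype mismatch: expected {planned_dtype}, got {query.dtype}"
    )


check_query_dtype(torch.zeros(2, 4, 8, dtype=torch.float8_e4m3fn),
                  torch.float8_e4m3fn)
```
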
18 changes: 3 additions & 15 deletions vllm/v1/attention/backends/triton_attn.py
@@ -32,11 +32,6 @@
)
from vllm.v1.kv_cache_interface import AttentionSpec

if current_platform.is_cuda_alike():
from vllm import _custom_ops as ops
elif current_platform.is_xpu():
from vllm._ipex_ops import ipex_ops as ops

logger = init_logger(__name__)


@@ -210,6 +205,9 @@ class TritonAttentionImpl(AttentionImpl):
def fused_output_quant_supported(self, quant_key: QuantKey):
return quant_key == kFp8StaticTensorSym

def supports_quant_query_input(self) -> bool:
return current_platform.is_cuda()

def __init__(
self,
num_heads: int,
@@ -338,19 +336,9 @@ def forward(
if key_cache.dtype != self.fp8_dtype:
key_cache = key_cache.view(self.fp8_dtype)
value_cache = value_cache.view(self.fp8_dtype)
num_tokens, num_heads, head_size = query.shape
assert layer._q_scale_float == 1.0, (
"A non 1.0 q_scale is not currently supported."
)
if current_platform.is_cuda():
# Skip Q quantization on ROCm and XPU, enable this on cuda
# only, since dequantizing back to f32 in the attention kernel
# is not supported.
query, _ = ops.scaled_fp8_quant(
query.reshape((num_tokens, num_heads * head_size)).contiguous(),
layer._q_scale,
)
query = query.reshape((num_tokens, num_heads, head_size))

cu_seqlens_q = attn_metadata.query_start_loc
seqused_k = attn_metadata.seq_lens
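
For reference, a hedged reconstruction of what the deleted Triton/XPU branch used to do by hand: flatten the heads, quantize per-tensor, and restore the shape. That work now happens once in the attention layer, and the backend merely reports CUDA-only support because its kernel cannot dequantize an FP8 query on ROCm/XPU (plain torch replaces `ops.scaled_fp8_quant` in this sketch).

```python
import torch


def legacy_backend_side_query_quant(
    query: torch.Tensor,
    q_scale: torch.Tensor,
    fp8_dtype: torch.dtype = torch.float8_e4m3fn,
) -> torch.Tensor:
    """Approximation of the removed in-backend path (illustrative only)."""
    num_tokens, num_heads, head_size = query.shape
    flat = query.reshape(num_tokens, num_heads * head_size).contiguous()
    finfo = torch.finfo(fp8_dtype)
    quantized = (flat.float() / q_scale).clamp(finfo.min, finfo.max).to(fp8_dtype)
    return quantized.reshape(num_tokens, num_heads, head_size)
```
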