[V1][Bugfix] Standardize quantized kv cache rejection for attention backends #14221

Merged
4 changes: 4 additions & 0 deletions vllm/attention/backends/abstract.py
@@ -294,3 +294,7 @@ def forward(
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
raise NotImplementedError


def is_quantized_kv_cache(kv_cache_dtype: str) -> bool:
return kv_cache_dtype != "auto"
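
As a quick illustration of the new helper's semantics (a minimal sketch; the fp8 strings below are only examples of non-"auto" values):

```python
from vllm.attention.backends.abstract import is_quantized_kv_cache

# "auto" keeps the KV cache in the model's own dtype, so it is not quantized.
assert is_quantized_kv_cache("auto") is False
# Any other value (e.g. an fp8 variant) is treated as a quantized KV cache.
assert is_quantized_kv_cache("fp8") is True
```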
9 changes: 8 additions & 1 deletion vllm/attention/backends/flash_attn.py
@@ -8,11 +8,15 @@
import torch

from vllm import _custom_ops as ops
# yapf conflicts with isort for this block
# yapf: disable
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType)
AttentionType,
is_quantized_kv_cache)
# yapf: enable
from vllm.attention.backends.utils import (
PAD_SLOT_ID, CommonAttentionState, compute_slot_mapping,
compute_slot_mapping_start_idx, get_flash_attn_version,
@@ -626,6 +630,9 @@ def __init__(
self.sliding_window = ((sliding_window - 1,
0) if sliding_window is not None else (-1, -1))
self.kv_cache_dtype = kv_cache_dtype
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"FlashAttention with FP8 KV cache not yet supported")
if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap = 0
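
The effect of this check (and the matching ones added to the other backends below) is that an unsupported kv_cache_dtype is rejected when the impl is constructed, rather than partway through the first forward pass. A minimal, self-contained sketch of the pattern — ToyAttentionImpl is hypothetical, not vLLM code:

```python
from vllm.attention.backends.abstract import is_quantized_kv_cache


class ToyAttentionImpl:
    """Hypothetical backend impl; illustrates constructor-time rejection."""

    def __init__(self, kv_cache_dtype: str) -> None:
        if is_quantized_kv_cache(kv_cache_dtype):
            # Fail fast when the impl is built, not on the first decode step.
            raise NotImplementedError(
                "ToyAttention with FP8 KV cache not yet supported")
        self.kv_cache_dtype = kv_cache_dtype


ToyAttentionImpl("auto")          # accepted
try:
    ToyAttentionImpl("fp8_e4m3")  # rejected at construction
except NotImplementedError as exc:
    print(exc)
```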
9 changes: 6 additions & 3 deletions vllm/attention/backends/flashmla.py
@@ -6,7 +6,8 @@

import torch

from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.abstract import (AttentionType,
is_quantized_kv_cache)
from vllm.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl,
MLACommonMetadata,
@@ -207,6 +208,10 @@ def __init__(
"are not implemented for "
"FlashMLAImpl")

if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"FlashMLA with FP8 KV cache not yet supported")

def _forward_decode(
self,
q_nope: torch.Tensor,
@@ -215,8 +220,6 @@
attn_metadata: FlashMLAMetadata,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 FlashMLA not yet supported")

decode_meta = attn_metadata.decode_metadata
assert decode_meta is not None
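
Note the slight widening of the check here: the guard removed from _forward_decode only matched dtypes starting with "fp8", while is_quantized_kv_cache treats any value other than "auto" as quantized. For the dtype strings currently in use the two agree; the sketch below (dtype list illustrative) just makes the comparison explicit:

```python
from vllm.attention.backends.abstract import is_quantized_kv_cache

for dtype in ("auto", "fp8", "fp8_e4m3", "fp8_e5m2"):
    old_check = dtype.startswith("fp8")       # previous per-backend test
    new_check = is_quantized_kv_cache(dtype)  # shared helper
    print(f"{dtype:10} startswith_fp8={old_check} is_quantized={new_check}")
```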
7 changes: 6 additions & 1 deletion vllm/attention/backends/hpu_attn.py
@@ -15,7 +15,8 @@

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.hpu_paged_attn import (HPUPagedAttention,
HPUPagedAttentionMetadata)
@@ -158,6 +159,10 @@ def __init__(
"are not implemented for "
"HPUAttentionImpl")

if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"HPUAttention with FP8 KV cache not yet supported")

def forward(
self,
layer: AttentionLayer,
5 changes: 3 additions & 2 deletions vllm/attention/backends/ipex_attn.py
@@ -9,7 +9,8 @@
from vllm._ipex_ops import ipex_ops
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.paged_attn import (PagedAttention,
PagedAttentionMetadata)
@@ -145,7 +146,7 @@ def __init__(
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {supported_head_sizes}.")
if kv_cache_dtype != "auto":
if is_quantized_kv_cache(kv_cache_dtype):
raise NotImplementedError(
"IPEX backend does not support FP8 KV cache. "
"Please use xFormers backend instead.")
5 changes: 3 additions & 2 deletions vllm/attention/backends/pallas.py
@@ -8,7 +8,8 @@

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata, AttentionType)
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
from vllm.attention.backends.utils import CommonAttentionState


@@ -119,7 +120,7 @@ def __init__(
raise NotImplementedError("Alibi slopes is not supported.")
if sliding_window is not None:
raise NotImplementedError("Sliding window is not supported.")
if kv_cache_dtype != "auto":
if is_quantized_kv_cache(kv_cache_dtype):
raise NotImplementedError("FP8 KV cache dtype is not supported.")
if blocksparse_params is not None:
raise NotImplementedError("Blocksparse is not supported.")
8 changes: 6 additions & 2 deletions vllm/attention/backends/torch_sdpa.py
@@ -7,11 +7,15 @@
import torch
from torch.nn.functional import scaled_dot_product_attention

# yapf conflicts with isort for this block
# yapf: disable
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionLayer,
AttentionMetadata,
AttentionMetadataBuilder,
AttentionType)
AttentionType,
is_quantized_kv_cache)
# yapf: enable
from vllm.attention.backends.utils import CommonAttentionState
from vllm.attention.ops.ipex_attn import PagedAttention
from vllm.attention.ops.paged_attn import PagedAttentionMetadata
@@ -427,7 +431,7 @@ def __init__(
raise ValueError(
f"Head size {head_size} is not supported by PagedAttention. "
f"Supported head sizes are: {supported_head_sizes}.")
if kv_cache_dtype != "auto":
if is_quantized_kv_cache(kv_cache_dtype):
raise NotImplementedError(
"Torch SDPA backend does not support FP8 KV cache. "
"Please use xFormers backend instead.")
9 changes: 6 additions & 3 deletions vllm/attention/backends/triton_mla.py
@@ -4,7 +4,8 @@

import torch

from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.abstract import (AttentionType,
is_quantized_kv_cache)
from vllm.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl,
MLACommonMetadata)
@@ -58,6 +59,10 @@ def __init__(
"are not implemented for "
"TritonMLAImpl")

if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"TritonMLA with FP8 KV cache not yet supported")

def _forward_decode(
self,
q_nope: torch.Tensor,
@@ -66,8 +71,6 @@
attn_metadata: MLACommonMetadata,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 Triton MLA not yet supported")

decode_meta = attn_metadata.decode_metadata
assert decode_meta is not None
6 changes: 5 additions & 1 deletion vllm/v1/attention/backends/flash_attn.py
@@ -7,7 +7,8 @@
import torch

from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
AttentionMetadata, AttentionType)
AttentionMetadata, AttentionType,
is_quantized_kv_cache)
from vllm.attention.backends.utils import get_flash_attn_version
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
from vllm.logger import init_logger
@@ -180,6 +181,9 @@ def __init__(
else:
self.sliding_window = (sliding_window - 1, 0)
self.kv_cache_dtype = kv_cache_dtype
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"FlashAttention V1 with FP8 KV cache not yet supported")
if logits_soft_cap is None:
# In flash-attn, setting logits_soft_cap as 0 means no soft cap.
logits_soft_cap = 0
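
For context, this is roughly how a user would now hit the rejection on the V1 engine. A hedged sketch only: it assumes the V1 FlashAttention backend is selected on the target platform, the model name is just an example, and the exact exception surfaced to the caller may be wrapped by the engine:

```python
import os

from vllm import LLM

# Assumption: this build exposes the VLLM_USE_V1 switch for the V1 engine.
os.environ["VLLM_USE_V1"] = "1"

# With a quantized KV cache requested, engine construction is now expected to
# fail fast with "FlashAttention V1 with FP8 KV cache not yet supported"
# instead of erroring (or silently misbehaving) at decode time.
llm = LLM(model="facebook/opt-125m", kv_cache_dtype="fp8")
```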
10 changes: 6 additions & 4 deletions vllm/v1/attention/backends/mla/flashmla.py
@@ -5,7 +5,8 @@

import torch

from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.abstract import (AttentionType,
is_quantized_kv_cache)
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
get_mla_metadata,
is_flashmla_supported)
@@ -115,6 +116,10 @@ def __init__(
"are not implemented for "
"FlashMLAImpl")

if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"FlashMLA V1 with FP8 KV cache not yet supported")

def _forward_decode(
self,
q_nope: torch.Tensor,
@@ -125,9 +130,6 @@
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None

if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 FlashMLA not yet supported")

q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)

7 changes: 6 additions & 1 deletion vllm/v1/attention/backends/mla/triton_mla.py
@@ -4,7 +4,8 @@

import torch

from vllm.attention.backends.abstract import AttentionType
from vllm.attention.backends.abstract import (AttentionType,
is_quantized_kv_cache)
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
@@ -61,6 +62,10 @@ def __init__(
"are not implemented for "
"TritonMLAImpl")

if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"TritonMLA V1 with FP8 KV cache not yet supported")

def _forward_decode(
self,
q_nope: torch.Tensor,
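
Finally, a sketch of how the standardized predicate could be pinned down in a unit test (test name and parametrization are illustrative, not part of this PR):

```python
import pytest

from vllm.attention.backends.abstract import is_quantized_kv_cache


@pytest.mark.parametrize("kv_cache_dtype,expected", [
    ("auto", False),      # unquantized: KV cache follows the model dtype
    ("fp8", True),        # quantized variants are all rejected the same way
    ("fp8_e4m3", True),
    ("fp8_e5m2", True),
])
def test_is_quantized_kv_cache(kv_cache_dtype: str, expected: bool):
    assert is_quantized_kv_cache(kv_cache_dtype) is expected
```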