Skip to content

[Kernel] Raise verbose error and consolidate num_heads/num_kv_heads divisibility check #19339

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions tests/kernels/attention/test_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from tests.kernels.allclose_default import get_default_atol, get_default_rtol
from tests.kernels.utils import opcheck
from vllm import _custom_ops as ops
from vllm.attention.layer import Attention, MultiHeadAttention
from vllm.platforms import current_platform
from vllm.utils import get_max_shared_memory_bytes

Expand Down Expand Up @@ -506,3 +507,18 @@ def test_multi_query_kv_attention_with_alibi(
device,
use_alibi=True,
)


@pytest.mark.parametrize("attention_cls", [Attention, MultiHeadAttention])
def test_num_heads_not_divisble_by_num_kv_heads(attention_cls: type) -> None:
head_size = 64
scale = float(1.0 / (head_size**0.5))
num_heads = 16
num_kv_heads = 5
with pytest.raises(AssertionError):
_ = attention_cls(
num_heads=num_heads,
head_size=head_size,
scale=scale,
num_kv_heads=num_kv_heads,
)
4 changes: 1 addition & 3 deletions vllm/attention/backends/blocksparse_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ def __post_init__(self):
assert self.block_size > 0
assert self.local_blocks >= 0
assert self.vert_stride >= 1
assert self.num_heads % self.num_kv_heads == 0

tp_size = get_tensor_model_parallel_world_size()
tp_rank = get_tensor_model_parallel_rank()
Expand Down Expand Up @@ -329,9 +328,8 @@ def __init__(
self.head_size = head_size
self.scale = float(scale)
self.alibi_slopes = alibi_slopes
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.num_kv_heads = num_kv_heads

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

self.local_blocks = self.blocksparse_params.local_blocks
Expand Down
1 change: 0 additions & 1 deletion vllm/attention/backends/dual_chunk_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,6 @@ def __init__(
if sliding_window is not None else (-1, -1))
self.kv_cache_dtype = kv_cache_dtype

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if sliding_window is not None:
# NOTE(woosuk): flash-attn's sliding window does not work with
Expand Down
1 change: 0 additions & 1 deletion vllm/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,7 +654,6 @@ def __init__(
logits_soft_cap = 0
self.logits_soft_cap = logits_soft_cap

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
Expand Down
1 change: 0 additions & 1 deletion vllm/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -957,7 +957,6 @@ def __init__(
self.kv_cache_dtype = kv_cache_dtype
self.logits_soft_cap = logits_soft_cap

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

if attn_type != AttentionType.DECODER:
Expand Down
1 change: 0 additions & 1 deletion vllm/attention/backends/hpu_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ def __init__(
alibi_slopes_tensor = torch.tensor(alibi_slopes,
dtype=torch.bfloat16)
self.alibi_slopes = alibi_slopes_tensor
assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

if self.prefill_impl == 'fsdpa':
Expand Down
1 change: 0 additions & 1 deletion vllm/attention/backends/ipex_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,6 @@ def __init__(
self.sliding_window = sliding_window
self.kv_cache_dtype = kv_cache_dtype

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.need_mask = (self.sliding_window is not None)
if logits_soft_cap is None:
Expand Down
3 changes: 1 addition & 2 deletions vllm/attention/backends/pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,8 @@ def __init__(
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads
self.num_kv_heads = num_kv_heads

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.logits_soft_cap = logits_soft_cap
if head_size % 128 != 0:
Expand Down
1 change: 0 additions & 1 deletion vllm/attention/backends/rocm_flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,7 +527,6 @@ def __init__(
if sliding_window is not None else (-1, -1))
self.kv_cache_dtype = kv_cache_dtype

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

self.paged_attn_module = _get_paged_attn_module()
Expand Down
1 change: 0 additions & 1 deletion vllm/attention/backends/torch_sdpa.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,6 @@ def __init__(
self.sliding_window = sliding_window
self.kv_cache_dtype = kv_cache_dtype

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
self.need_mask = (self.alibi_slopes is not None
or self.sliding_window is not None)
Expand Down
1 change: 0 additions & 1 deletion vllm/attention/backends/xformers.py
Original file line number Diff line number Diff line change
Expand Up @@ -415,7 +415,6 @@ def __init__(
self.sliding_window = sliding_window
self.kv_cache_dtype = kv_cache_dtype

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

supported_head_sizes = PagedAttention.get_supported_head_sizes()
Expand Down
7 changes: 6 additions & 1 deletion vllm/attention/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ def __init__(
calculate_kv_scales = False
if num_kv_heads is None:
num_kv_heads = num_heads
assert num_heads % num_kv_heads == 0, \
f"num_heads ({num_heads}) is not " \
f"divisible by num_kv_heads ({num_kv_heads})"

# The default k/v_scale is set to 1.0. This is ignored
# when kv-cache is not fp8, and should be used with
Expand Down Expand Up @@ -291,7 +294,9 @@ def __init__(
self.scale = scale
self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

assert self.num_heads % self.num_kv_heads == 0
assert self.num_heads % self.num_kv_heads == 0, \
f"num_heads ({self.num_heads}) is not " \
f"divisible by num_kv_heads ({self.num_kv_heads})"
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

dtype = torch.get_default_dtype()
Expand Down
1 change: 0 additions & 1 deletion vllm/v1/attention/backends/flash_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -538,7 +538,6 @@ def __init__(
self.logits_soft_cap = logits_soft_cap
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

support_head_sizes = FlashAttentionBackend.get_supported_head_sizes()
Expand Down
1 change: 0 additions & 1 deletion vllm/v1/attention/backends/flashinfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,6 @@ def __init__(
self.logits_soft_cap = logits_soft_cap
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

if attn_type != AttentionType.DECODER:
Expand Down
1 change: 0 additions & 1 deletion vllm/v1/attention/backends/flex_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,7 +380,6 @@ def __init__(
raise NotImplementedError(
"FlexAttention does not support logits soft cap yet.")

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

if kv_sharing_target_layer_name is not None:
Expand Down
1 change: 0 additions & 1 deletion vllm/v1/attention/backends/pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ def __init__(
self.logits_soft_cap = logits_soft_cap
self.kv_sharing_target_layer_name = kv_sharing_target_layer_name

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads
if head_size % 128 != 0:
raise NotImplementedError("Head size must be a multiple of 128.")
Expand Down
1 change: 0 additions & 1 deletion vllm/v1/attention/backends/triton_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,6 @@ def __init__(

self.use_irope = use_irope

assert self.num_heads % self.num_kv_heads == 0
self.num_queries_per_kv = self.num_heads // self.num_kv_heads

support_head_sizes = TritonAttentionBackend.get_supported_head_sizes()
Expand Down