60 changes: 41 additions & 19 deletions tests/kernels/attention/test_attention_selector.py
@@ -22,7 +22,10 @@ def clear_cache():

# Define MLA and non-MLA backends separately
DEVICE_MLA_BACKENDS = {
"cuda": ["TRITON_MLA", "FLASHMLA", "FLASH_ATTN_MLA", "CUTLASS_MLA"],
"cuda": [
"TRITON_MLA", "FLASHMLA", "FLASHINFER_MLA", "FLASH_ATTN_MLA",
"CUTLASS_MLA"
],
"hip": ["TRITON_MLA", "ROCM_AITER_MLA"],
"cpu": [],
}
@@ -90,8 +93,8 @@ def test_env(

with patch("vllm.attention.selector.current_platform",
CpuPlatform()):
backend = get_attn_backend(16, torch.float16, torch.float16,
block_size, False)
backend = get_attn_backend(16, torch.float16, None, block_size,
False)
assert backend.get_name() == "TORCH_SDPA_VLLM_V1"

elif device == "hip":
@@ -109,7 +112,7 @@
with pytest.raises(ValueError) as exc_info:
get_attn_backend(16,
torch.float16,
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
@@ -120,7 +123,7 @@
with pytest.raises(ValueError) as exc_info:
get_attn_backend(16,
torch.float16,
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
@@ -130,7 +133,7 @@
# Valid backend-block_size combination
backend = get_attn_backend(16,
torch.float16,
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
@@ -139,7 +142,7 @@
else:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
@@ -153,6 +156,8 @@ def test_env(
# CUDA MLA backend logic:
# - CUTLASS_MLA: only supported with block_size == 128
# and Blackwell GPUs (SM 10.0), V1 only
# - FLASHINFER_MLA: only supported on Blackwell GPUs
# (SM 10.0+), V1 only
# - FLASHMLA: only supported with block_size == 64
# - FLASH_ATTN_MLA: V1 only
# - TRITON_MLA: fallback for other cases
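The constraints in the comment block above are what the assertions that follow encode. As a rough, hypothetical illustration only (the function and argument names here are invented for this sketch and are not vllm's actual selector code), the CUDA MLA priority could be expressed as:

```python
# Hypothetical sketch of the CUDA MLA selection rules described above;
# vllm's real selector logic lives in its platform code and may differ.
def pick_cuda_mla_backend(requested: str, block_size: int,
                          is_blackwell: bool, use_v1: bool) -> str:
    if (requested == "CUTLASS_MLA" and use_v1 and is_blackwell
            and block_size == 128):
        return "CUTLASS_MLA_VLLM_V1"
    if (requested == "FLASHINFER_MLA" and use_v1 and is_blackwell
            and block_size in (32, 64)):
        return "FLASHINFER_MLA"
    if requested == "FLASHMLA" and block_size == 64:
        return "FLASHMLA"
    if requested == "FLASH_ATTN_MLA" and use_v1:
        return "FLASH_ATTN_MLA"
    # TRITON_MLA is the fallback for all other cases
    return "TRITON_MLA"
```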
@@ -169,12 +174,31 @@
else:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
expected = "CUTLASS_MLA_VLLM_V1"
assert backend.get_name() == expected
elif name == "FLASHINFER_MLA":
if not use_v1:
# FlashInfer MLA only supported on V1 engine
pytest.skip(
"FlashInfer MLA only supported on V1 engine")
elif block_size not in [32, 64]:
Member:
QQ: Is there a reason we don't have the attention backend define what block size it supports?

Contributor Author:
It definitely should -- I can make that change in a separate PR to standardize the logic now that we have 3 different backends with a restriction like this. That would be cleaner than what we have now where it's hardcoded in a couple different places.

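One possible shape for the standardization mentioned in the reply, sketched here with illustrative names (a supported_block_sizes attribute and a validation helper; neither is claimed to be vllm's actual API):

```python
# Illustrative sketch: each backend advertises the block sizes it supports,
# so the selector and tests no longer hardcode the restriction in several
# places.
class FlashInferMLABackendSketch:
    supported_block_sizes = (32, 64)


class FlashMLABackendSketch:
    supported_block_sizes = (64,)


def validate_block_size(backend_cls, block_size: int) -> None:
    supported = getattr(backend_cls, "supported_block_sizes", None)
    if supported is not None and block_size not in supported:
        raise ValueError(
            f"{backend_cls.__name__} does not support "
            f"block_size={block_size}; supported sizes: {supported}")


# Example: validate_block_size(FlashInferMLABackendSketch, 16) raises ValueError.
```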
# FlashInfer MLA only supports block_size 32 or 64
pytest.skip(
"FlashInfer MLA only supports block_size 32 "
"or 64")
else:
backend = get_attn_backend(16,
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
expected = "FLASHINFER_MLA"
assert backend.get_name() == expected
elif name == "FLASHMLA":
if block_size != 64:
# FlashMLA only supports block_size == 64
@@ -189,7 +213,7 @@
else:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
@@ -204,7 +228,7 @@
else:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
@@ -214,7 +238,7 @@
# TRITON_MLA or other fallback
backend = get_attn_backend(16,
torch.float16,
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
@@ -224,7 +248,7 @@
elif name == "FLASHINFER":
backend = get_attn_backend(16,
torch.float16,
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
@@ -233,7 +257,7 @@
else:
backend = get_attn_backend(32,
torch.float16,
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
@@ -243,7 +267,7 @@
if use_v1:
backend = get_attn_backend(16,
torch.float16,
torch.float16,
None,
block_size,
False,
use_mla=use_mla)
@@ -269,15 +293,13 @@ def test_fp32_fallback(

with patch("vllm.attention.selector.current_platform",
CpuPlatform()):
backend = get_attn_backend(16, torch.float32, torch.float32,
16, False)
backend = get_attn_backend(16, torch.float32, None, 16, False)
assert backend.get_name() == "TORCH_SDPA_VLLM_V1"

elif device == "cuda":
with patch("vllm.attention.selector.current_platform",
CudaPlatform()):
backend = get_attn_backend(16, torch.float32, torch.float32,
16, False)
backend = get_attn_backend(16, torch.float32, None, 16, False)
assert (backend.get_name() == "FLEX_ATTENTION"
if use_v1 else "XFORMERS")

@@ -331,7 +353,7 @@ def test_flash_attn(monkeypatch: pytest.MonkeyPatch):
assert backend.get_name() != STR_FLASH_ATTN_VAL

# Attention-free models should bypass env and use PlaceholderAttention
backend = get_attn_backend(16, torch.float16, torch.float16, 16, True)
backend = get_attn_backend(16, torch.float16, None, 16, True)
assert backend.get_name() != STR_FLASH_ATTN_VAL


2 changes: 2 additions & 0 deletions tests/v1/attention/utils.py
@@ -141,6 +141,8 @@ def get_attention_backend(backend_name: _Backend):
"vllm.v1.attention.backends.mla.flashmla.FlashMLABackend",
_Backend.FLASH_ATTN_MLA:
"vllm.v1.attention.backends.mla.flashattn_mla.FlashAttnMLABackend",
_Backend.FLASHINFER_MLA:
"vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend",
_Backend.TRITON_MLA_VLLM_V1:
"vllm.v1.attention.backends.mla.triton_mla.TritonMLABackend",
}
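For reference, a dotted-path mapping like the one extended above is usually resolved by importing the module and looking up the class. A minimal sketch of that pattern (the resolve_backend_class helper is illustrative and not the actual code in tests/v1/attention/utils.py):

```python
import importlib


def resolve_backend_class(dotted_path: str):
    # Split "package.module.ClassName" into module path and class name,
    # import the module, and return the class object.
    module_path, class_name = dotted_path.rsplit(".", 1)
    module = importlib.import_module(module_path)
    return getattr(module, class_name)


# Usage (assumes a vllm build that includes the FlashInfer MLA backend):
# backend_cls = resolve_backend_class(
#     "vllm.v1.attention.backends.mla.flashinfer_mla.FlashInferMLABackend")
```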