Commit 3df516f

Isotr0py authored and charlifu committed
[V1] Add sliding window support to Flex Attention backend (vllm-project#24089)
Signed-off-by: Isotr0py <mozf@mail2.sysu.edu.cn>
Signed-off-by: charlifu <charlifu@amd.com>
1 parent 719ab17 commit 3df516f

File tree

2 files changed (+229, -69 lines)


tests/v1/attention/test_attention_backends.py

Lines changed: 157 additions & 51 deletions
@@ -1,15 +1,20 @@
 # SPDX-License-Identifier: Apache-2.0
 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
 """Tests for v1 attention backends without GPUModelRunner dependency."""
+from functools import partial
+from typing import Optional, Union
 
 import pytest
 import torch
+from torch.nn.attention.flex_attention import create_block_mask, flex_attention
 
 from tests.v1.attention.utils import (BatchSpec, _Backend,
                                       create_common_attn_metadata,
                                       create_standard_kv_cache_spec,
                                       create_vllm_config,
                                       get_attention_backend)
+from vllm.config import ModelConfig
+from vllm.platforms import current_platform
 from vllm.utils import STR_DTYPE_TO_TORCH_DTYPE, cdiv, is_torch_equal_or_newer
 from vllm.v1.attention.backends.utils import (CommonAttentionMetadata,
                                               set_kv_cache_layout)
@@ -183,13 +188,19 @@ def __init__(self, device: torch.device):
         self._v_scale_float = 1.0
 
 
-def run_attention_backend(backend: _Backend, kv_cache_spec: FullAttentionSpec,
-                          layer_names: list[str], vllm_config,
-                          device: torch.device,
-                          common_attn_metadata: CommonAttentionMetadata,
-                          query: torch.Tensor, key: torch.Tensor,
-                          value: torch.Tensor,
-                          kv_cache: torch.Tensor) -> torch.Tensor:
+def run_attention_backend(
+    backend: _Backend,
+    kv_cache_spec: FullAttentionSpec,
+    layer_names: list[str],
+    vllm_config,
+    device: torch.device,
+    common_attn_metadata: CommonAttentionMetadata,
+    query: torch.Tensor,
+    key: torch.Tensor,
+    value: torch.Tensor,
+    kv_cache: torch.Tensor,
+    sliding_window: Optional[int] = None,
+) -> torch.Tensor:
     """Run attention computation using the specified backend's AttentionImpl."""
 
     # Handle special case for FLEX_ATTENTION_SLOW
@@ -253,7 +264,7 @@ def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls):
             scale=scale,
             num_kv_heads=num_kv_heads,
             alibi_slopes=None,
-            sliding_window=None,
+            sliding_window=sliding_window,
             kv_cache_dtype="auto",
         )
 
@@ -275,13 +286,16 @@ def mock_get_per_layer_parameters(vllm_config, layer_names, impl_cls):
     return output
 
 
-@pytest.mark.parametrize("batch_spec_name", [
-    "small_decode", "small_prefill", "mixed_small", "medium_decode",
-    "medium_prefill", "mixed_medium", "large_decode", "large_prefill",
-    "single_decode", "single_prefill"
-])
-@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
-def test_backend_correctness(batch_spec_name: str, model: str):
+def _test_backend_correctness(
+    batch_spec: BatchSpec,
+    model: str,
+    backend_to_test: list[Union[_Backend, str]],
+    mask_mod,
+    *,
+    block_size: int = 16,
+    atol: float = 1e-2,
+    rtol: float = 1e-2,
+):
     """
     Test that all backends produce similar outputs to a reference implementation
     using torch.nn.functional.scaled_dot_product_attention.
@@ -297,9 +311,10 @@ def test_backend_correctness(batch_spec_name: str, model: str):
        simulated paged KV cache.
     5. Comparing the vLLM backend's output to the ground-truth SDPA output.
     """
-    batch_spec = BATCH_SPECS[batch_spec_name]
+    current_platform.seed_everything(42)
     vllm_config = create_vllm_config(model_name=model,
                                      max_model_len=max(batch_spec.seq_lens),
+                                     block_size=block_size,
                                      num_gpu_blocks=8192)
     device = torch.device("cuda:0")
 
@@ -314,6 +329,7 @@ def test_backend_correctness(batch_spec_name: str, model: str):
     num_kv_heads = vllm_config.model_config.get_num_kv_heads(
         vllm_config.parallel_config)
     head_size = vllm_config.model_config.get_head_size()
+    sliding_window = vllm_config.model_config.get_sliding_window()
     dtype = _convert_dtype_to_torch(vllm_config.model_config.dtype)
     block_size = vllm_config.cache_config.block_size
     scale = 1.0 / (head_size**0.5)
@@ -361,22 +377,21 @@ def test_backend_correctness(batch_spec_name: str, model: str):
         # Create causal mask: query token i attends to positions 0 to
         # (context_len + i)
         kv_len = s_len
-        offset = context_len
-        attn_mask = torch.full((q_len, kv_len),
-                               float('-inf'),
-                               device=device,
-                               dtype=dtype)
-        for i in range(q_len):
-            attn_mask[i, :offset + i + 1] = 0.0
-
-        sdpa_out_i = torch.nn.functional.scaled_dot_product_attention(
-            q_sdpa_in,
-            k_sdpa_in,
-            v_sdpa_in,
-            attn_mask=attn_mask,
-            scale=scale,
-            enable_gqa=True)
-        # Convert back to (L, H, D)
+
+        final_mask_mod = partial(mask_mod, context_len=context_len)
+        block_mask = create_block_mask(final_mask_mod,
+                                       B=None,
+                                       H=None,
+                                       Q_LEN=q_len,
+                                       KV_LEN=kv_len,
+                                       device=device)
+        sdpa_out_i = flex_attention(q_sdpa_in,
+                                    k_sdpa_in,
+                                    v_sdpa_in,
+                                    block_mask=block_mask,
+                                    scale=scale,
+                                    enable_gqa=True)
+
         all_sdpa_outputs.append(sdpa_out_i.transpose(1, 2).squeeze(0))
 
     # Inputs for vLLM backends are just the new tokens
@@ -412,7 +427,7 @@ def test_backend_correctness(batch_spec_name: str, model: str):
     # 4. Run vLLM backends and compare
     # Note: flex_attention has known Triton kernel compatibility issues
     # with test infrastructures
-    for backend_name in BACKENDS_TO_TEST:
+    for backend_name in backend_to_test:
         # FlashAttentionm + FlexAttention:
         #   [2, num_blocks, block_size, num_kv_heads, head_size]
         # FlashInfer:
@@ -427,12 +442,19 @@ def test_backend_correctness(batch_spec_name: str, model: str):
                 2, 3).contiguous().transpose(2, 3)
             set_kv_cache_layout("HND")
 
-        backend_output = run_attention_backend(backend_name, kv_cache_spec,
-                                               ["placeholder"], vllm_config,
-                                               device, common_attn_metadata,
-                                               query_vllm, key_vllm,
-                                               value_vllm,
-                                               kv_cache_for_backend)
+        backend_output = run_attention_backend(
+            backend_name,
+            kv_cache_spec,
+            ["placeholder"],
+            vllm_config,
+            device,
+            common_attn_metadata,
+            query_vllm,
+            key_vllm,
+            value_vllm,
+            kv_cache_for_backend,
+            sliding_window=sliding_window,
+        )
 
         # Check shape and dtype consistency
         assert backend_output.shape == sdpa_output.shape, (
@@ -446,18 +468,102 @@ def test_backend_correctness(batch_spec_name: str, model: str):
             f"[{backend_name}] produced non-finite values")
 
         # Check numerical similarity
-        rtol = 1e-2
-        atol = 5e-3
-
-        max_diff = torch.max(torch.abs(backend_output - sdpa_output)).item()
-        max_rel_diff = torch.max(
-            torch.abs(backend_output - sdpa_output) /
-            torch.abs(sdpa_output)).item()
-        all_close = torch.allclose(backend_output,
+        def error_msg(msg: str, backend_name: str):
+            return (f"[{backend_name}] output differs from SDPA baseline. "
+                    f"{msg}")
+
+        torch.testing.assert_close(backend_output,
                                    sdpa_output,
                                    rtol=rtol,
-                                   atol=atol)
+                                   atol=atol,
+                                   msg=partial(error_msg,
+                                               backend_name=backend_name))
 
-        assert all_close, (
-            f"[{backend_name}] output differs from SDPA baseline. "
-            f"Max diff: {max_diff:.6f}, max rel diff: {max_rel_diff:.6f})")
+
+@pytest.mark.parametrize("batch_spec_name", [
+    "small_decode", "small_prefill", "mixed_small", "medium_decode",
+    "medium_prefill", "mixed_medium", "large_decode", "large_prefill",
+    "single_decode", "single_prefill"
+])
+@pytest.mark.parametrize("model", ["meta-llama/Meta-Llama-3-8B"])
+def test_causal_backend_correctness(batch_spec_name: str, model: str):
+    """Test backend's correctness with causal attention."""
+
+    def causal_mask_mod(
+        b: torch.Tensor,
+        h: torch.Tensor,
+        q_idx: torch.Tensor,
+        kv_idx: torch.Tensor,
+        *,
+        context_len: int,
+    ):
+        return (q_idx + context_len) >= kv_idx
+
+    batch_spec = BATCH_SPECS[batch_spec_name]
+    LARGE_BLOCK_BACKENDS = ([_Backend.FLEX_ATTENTION]
+                            if is_torch_equal_or_newer("2.9.0.dev0") else [])
+    SMALL_BLOCK_BACKENDS = [
+        x for x in BACKENDS_TO_TEST if x not in LARGE_BLOCK_BACKENDS
+    ]
+    _test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS,
+                              causal_mask_mod)
+
+    # Fast FlexAttention needs to run with block_size=128
+    if LARGE_BLOCK_BACKENDS:
+        _test_backend_correctness(batch_spec,
+                                  model,
+                                  LARGE_BLOCK_BACKENDS,
+                                  causal_mask_mod,
+                                  block_size=128)
+
+
+SLIDING_WINDOW_BACKENDS_TO_TEST = [
+    _Backend.FLASH_ATTN_VLLM_V1, _Backend.FLEX_ATTENTION,
+    _Backend.TRITON_ATTN_VLLM_V1, "FLEX_ATTENTION_SLOW"
+]
+
+
+@pytest.mark.parametrize("batch_spec_name", [
+    "small_decode", "small_prefill", "mixed_medium", "large_decode",
+    "large_prefill"
+])
+@pytest.mark.parametrize("model", ["microsoft/Phi-tiny-MoE-instruct"])
+def test_sliding_window_backend_correctness(batch_spec_name: str, model: str):
+    """Test backend's correctness with sliding window attention."""
+
+    def sliding_window_mask_mod(
+        b: torch.Tensor,
+        h: torch.Tensor,
+        q_idx: torch.Tensor,
+        kv_idx: torch.Tensor,
+        *,
+        context_len: int,
+        sliding_window: int,
+    ):
+        causal_mask = q_idx + context_len >= kv_idx
+        window_mask = q_idx + context_len - kv_idx < sliding_window
+        return causal_mask & window_mask
+
+    batch_spec = BATCH_SPECS[batch_spec_name]
+    model_config = ModelConfig(model=model,
+                               max_model_len=max(batch_spec.seq_lens))
+    sliding_window = model_config.get_sliding_window()
+    sliding_window_mask_mod_fn = partial(sliding_window_mask_mod,
+                                         sliding_window=sliding_window)
+
+    LARGE_BLOCK_BACKENDS = ([_Backend.FLEX_ATTENTION]
+                            if is_torch_equal_or_newer("2.9.0.dev0") else [])
+    SMALL_BLOCK_BACKENDS = [
+        x for x in SLIDING_WINDOW_BACKENDS_TO_TEST
+        if x not in LARGE_BLOCK_BACKENDS
+    ]
+    _test_backend_correctness(batch_spec, model, SMALL_BLOCK_BACKENDS,
                              sliding_window_mask_mod_fn)
+
+    # Fast FlexAttention needs to run with block_size=128
+    if LARGE_BLOCK_BACKENDS:
+        _test_backend_correctness(batch_spec,
+                                  model,
+                                  LARGE_BLOCK_BACKENDS,
+                                  sliding_window_mask_mod_fn,
+                                  block_size=128)
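For readers unfamiliar with FlexAttention, the reference pattern the new tests rely on is easy to reproduce in isolation: sliding-window causal attention is written as a mask_mod predicate (a query token may attend only to earlier positions that fall within the last sliding_window positions), compiled into a BlockMask with create_block_mask, and evaluated with flex_attention. The snippet below is a minimal sketch of that pattern, not code from this commit; it assumes PyTorch 2.5 or newer with a CUDA device, and the tensor shapes, context_len, and sliding_window values are illustrative only.

from functools import partial

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention


def sliding_window_mask_mod(b, h, q_idx, kv_idx, *, context_len, sliding_window):
    # Query token q_idx sits at absolute position context_len + q_idx; it may
    # attend causally, but only to keys inside the trailing window.
    causal = q_idx + context_len >= kv_idx
    in_window = q_idx + context_len - kv_idx < sliding_window
    return causal & in_window


device = "cuda"                 # assumed; CPU support only in newer PyTorch
q_len, kv_len = 128, 256        # multiples of the default 128 block size
context_len = kv_len - q_len    # tokens already present in the KV cache
sliding_window = 64
head_dim = 64

q = torch.randn(1, 4, q_len, head_dim, device=device)   # (B, Hq, Q_LEN, D)
k = torch.randn(1, 2, kv_len, head_dim, device=device)  # (B, Hkv, KV_LEN, D)
v = torch.randn(1, 2, kv_len, head_dim, device=device)

mask_mod = partial(sliding_window_mask_mod,
                   context_len=context_len,
                   sliding_window=sliding_window)
block_mask = create_block_mask(mask_mod, B=None, H=None,
                               Q_LEN=q_len, KV_LEN=kv_len, device=device)
out = flex_attention(q, k, v, block_mask=block_mask,
                     scale=head_dim**-0.5, enable_gqa=True)  # (1, 4, q_len, D)

Against a reference of this shape, the new test_sliding_window_backend_correctness exercises _Backend.FLASH_ATTN_VLLM_V1, _Backend.FLEX_ATTENTION, _Backend.TRITON_ATTN_VLLM_V1, and the "FLEX_ATTENTION_SLOW" path; it can be selected with, for example, pytest tests/v1/attention/test_attention_backends.py -k sliding_window.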
