
[BugFix][Attention] Fix sliding window attention in V1 giving incorrect results #17574
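As the diff below shows, vLLM V1's FlashAttention metadata builder now collects every attention layer's sliding-window setting: when all layers agree on one window it is forwarded to the ahead-of-time (AOT) scheduler as window_size, when layers use mixed windows the AOT path is disabled, and (-1, -1) stays the default meaning no window. Previously no window was passed to the scheduler at all. The snippet below is a rough, self-contained illustration (plain PyTorch with a hypothetical helper name, not vLLM code) of why dropping the window changes attention outputs:

import torch

def toy_attention(q, k, v, window_left=None):
    # Single-head causal attention; window_left limits how far back each
    # query may look (None = unlimited), mirroring the "left" half of
    # FlashAttention's (left, right) window convention.
    scores = (q @ k.transpose(-1, -2)) / q.shape[-1] ** 0.5
    n = scores.shape[-1]
    idx = torch.arange(n)
    mask = idx[None, :] > idx[:, None]                       # future keys
    if window_left is not None:
        mask |= (idx[:, None] - idx[None, :]) > window_left  # too far back
    scores = scores.masked_fill(mask, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v

torch.manual_seed(0)
q, k, v = (torch.randn(8, 4) for _ in range(3))
full = toy_attention(q, k, v)                    # what a full-attention plan assumes
windowed = toy_attention(q, k, v, window_left=2)
print(torch.allclose(full, windowed))            # False: the window changes the result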

Merged
36 changes: 35 additions & 1 deletion vllm/v1/attention/backends/flash_attn.py
@@ -10,9 +10,11 @@
from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
                                              AttentionMetadata, AttentionType,
                                              is_quantized_kv_cache)
from vllm.attention.layer import Attention
from vllm.attention.ops.merge_attn_states import merge_attn_states
from vllm.attention.utils.fa_utils import (flash_attn_supports_fp8,
                                           get_flash_attn_version)
from vllm.config import VllmConfig, get_layers_from_vllm_config
from vllm.logger import init_logger
from vllm.platforms import current_platform
from vllm.utils import cdiv
@@ -273,20 +275,35 @@ def make_local_attention_virtual_batches(
        block_table_local


def _get_sliding_window_configs(
        vllm_config: VllmConfig) -> set[Optional[tuple[int, int]]]:
    """Get the set of all sliding window configs used in the model."""
    sliding_window_configs: set[Optional[tuple[int, int]]] = set()
    layers = get_layers_from_vllm_config(vllm_config, Attention)
    for layer in layers.values():
        assert isinstance(layer.impl, FlashAttentionImpl)
        sliding_window_configs.add(layer.impl.sliding_window)
    return sliding_window_configs
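# Example return values (hypothetical window sizes):
#   {None}              -> every layer is full attention
#   {(4095, 0)}         -> every layer shares one sliding window
#   {None, (4095, 0)}   -> interleaved full/sliding layers (mixed configs)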


class FlashAttentionMetadataBuilder:

    def __init__(self, runner: "GPUModelRunner"):
        model_config = runner.model_config

        self.runner = runner
        self.aot_schedule = (get_flash_attn_version() == 3)
        self.num_heads_q = model_config.get_num_attention_heads(
            runner.parallel_config)
        self.num_heads_kv = model_config.get_num_kv_heads(
            runner.parallel_config)
        self.headdim = model_config.get_head_size()
        self.page_size = self.runner.block_size

        self.aot_schedule = (get_flash_attn_version() == 3)
        # Sliding window size to be used with the AOT scheduler; populated
        # on the first build() call.
        self.aot_sliding_window: Optional[tuple[int, int]] = None

    def reorder_batch(self, input_batch: "InputBatch",
                      scheduler_output: "SchedulerOutput") -> bool:
        return False
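The window lookup is deferred to build() because the Attention layers do not exist yet when the builder's __init__ runs; the first build() call resolves the value once and caches it. Below is a minimal standalone sketch of that resolve-once logic under assumed names (AotWindowPolicy is not a vLLM class):

from typing import Optional

class AotWindowPolicy:
    # Toy stand-in for the builder state: decide the AOT window once, on
    # the first call, after the per-layer configs become available.
    def __init__(self, use_aot: bool) -> None:
        self.aot_schedule = use_aot
        self.aot_sliding_window: Optional[tuple[int, int]] = None

    def resolve(self, configs: set[Optional[tuple[int, int]]]) -> None:
        if self.aot_sliding_window is not None:
            return                           # already resolved earlier
        self.aot_sliding_window = (-1, -1)   # default: no sliding window
        if not self.aot_schedule:
            return
        if len(configs) == 1:
            only = next(iter(configs))
            if only is not None:
                self.aot_sliding_window = only   # uniform window: safe for AOT
        elif len(configs) > 1:
            self.aot_schedule = False            # mixed windows: skip AOT planning

policy = AotWindowPolicy(use_aot=True)
policy.resolve({(4095, 0)})
print(policy.aot_schedule, policy.aot_sliding_window)   # True (4095, 0)

Mirroring the diff, a model that mixes windowed and full-attention layers opts out of AOT scheduling entirely rather than guessing a single window.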
@@ -304,6 +321,22 @@ def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
        slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
            self.runner.device, non_blocking=True).long()

        if self.aot_sliding_window is None:
            self.aot_sliding_window = (-1, -1)
            # For the AOT scheduler we need the sliding window value to be
            # constant for all layers. We have to populate this on the first
            # build() call, once the layers have been constructed (so it
            # cannot be populated in __init__).
            if self.aot_schedule:
                sliding_window_configs = _get_sliding_window_configs(
                    self.runner.vllm_config)
                if len(sliding_window_configs) == 1:
                    sliding_window_config = sliding_window_configs.pop()
                    if sliding_window_config is not None:
                        self.aot_sliding_window = sliding_window_config
                elif len(sliding_window_configs) > 1:
                    self.aot_schedule = False

        def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                     max_seq_len, causal):
            if self.aot_schedule:
@@ -318,6 +351,7 @@ def schedule(batch_size, cu_query_lens, max_query_len, seqlens,
                    page_size=self.page_size,
                    cu_seqlens_q=cu_query_lens,
                    causal=causal,
                    window_size=self.aot_sliding_window,
                )
            return None

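The new window_size argument follows FlashAttention's (left, right) convention, where -1 on a side means unrestricted, so (-1, -1), the builder's default, is ordinary causal attention. A small sketch of that convention (illustration only; allowed() is a hypothetical helper, and under causal masking the right half is effectively unused):

def allowed(query_pos: int, key_pos: int, window: tuple[int, int]) -> bool:
    # Can this query position attend to this key position under a causal
    # (left, right) sliding window? -1 means "no limit" on that side.
    left, _right = window
    if key_pos > query_pos:
        return False                     # causal: never look at future keys
    if left >= 0 and query_pos - key_pos > left:
        return False                     # key lies outside the window
    return True

print(allowed(10, 2, (-1, -1)))   # True: full causal attention
print(allowed(10, 2, (3, 0)))     # False: key 2 is more than 3 tokens back

In the fixed build() path above, a uniform per-layer window therefore reaches the AOT scheduler as exactly this kind of (left, right) pair instead of being silently dropped.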