Enable interleaved sliding window attention models for Transformers backend #18494

Merged
Enable hybrid attention models for Transformers backend
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
hmellor committed May 21, 2025
commit 7f63963c1d3e4a0b8040598f185e1af83ee4ad21
18 changes: 10 additions & 8 deletions vllm/config.py
@@ -536,13 +536,16 @@ def __post_init__(self) -> None:
            self.model, hf_token=self.hf_token, revision=self.revision)
        self.dtype = _get_and_verify_dtype(self.hf_config, self.dtype)

-        interleaved_attn_models = ["gemma2", "gemma3_text", "cohere2"]
+        # Workaround for Gemma 2 which uses interleaved sliding window
+        # attention, but it's not specified in its config.
+        if self.hf_text_config.model_type == "gemma2":
+            self.hf_text_config.sliding_window_pattern = 2
+
        sliding_window = getattr(self.hf_text_config, "sliding_window", None)
-        has_interleaved_attention = (sliding_window is not None) and (
-            isinstance(sliding_window, list) or
Member
Oh, just remembered that this line is for Mistral models. cc @patrickvonplaten do any of your models still use sliding_window as a list?

Member Author
Using a script to download the config of every model in mistralai would suggest no:

| Repository                           | Sliding Window | Sliding Window Pattern |
|--------------------------------------|----------------|------------------------|
| mistralai/Mistral-7B-v0.1            | 4096           | None                   |
| mistralai/Mistral-7B-Instruct-v0.1   | 4096           | None                   |
| mistralai/Ministral-8B-Instruct-2410 | 32768          | None                   |
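A rough sketch of what such a check could look like (not the exact script referenced above; it assumes `huggingface_hub` and `transformers` are installed and simply skips repos whose config cannot be loaded):

```python
# Hypothetical reconstruction of the check described above, not the author's script.
from huggingface_hub import list_models
from transformers import AutoConfig

for model in list_models(author="mistralai"):
    try:
        # Instantiate the config class rather than reading config.json directly.
        config = AutoConfig.from_pretrained(model.id)
    except Exception:
        continue  # skip repos without a loadable (or accessible) config
    print(
        model.id,
        getattr(config, "sliding_window", None),
        getattr(config, "sliding_window_pattern", None),
    )
```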

Member Author
Plus, according to the docstrings provided in Mistral and Mixtral, setting sliding_window as list[int] is not supported anyway.

Member
Alright, then it should be fine to merge. Thanks for looking into this!

Member
Oh, just realized that the sliding_window is actually set inside params.json... opening a PR to handle the case where it's a list

Member Author
Wouldn't that have been identified from my script? It checked the instantiated config classes rather than examining config.json directly

-            (self.hf_text_config.model_type in interleaved_attn_models))
+        sliding_window_pattern = getattr(self.hf_text_config,
+                                         "sliding_window_pattern", None)

-        if (not self.disable_sliding_window and has_interleaved_attention):
+        if not (self.disable_sliding_window or sliding_window_pattern is None):
            if (backend :=
                    envs.VLLM_ATTENTION_BACKEND) in ("XFORMERS", "FLASHINFER"):
                sliding_window_len_min = get_min_sliding_window(
@@ -1040,8 +1043,7 @@ def verify_with_parallel_config(
        if self.use_async_output_proc:
            self.use_async_output_proc = False

-    def get_hf_config_sliding_window(
-            self) -> Union[Optional[int], list[Optional[int]]]:
+    def get_hf_config_sliding_window(self) -> Optional[int]:
        """Get the sliding window size, or None if disabled."""

        # Some models, like Qwen2 and Qwen1.5, use `use_sliding_window` in
@@ -1052,7 +1054,7 @@ def get_hf_config_sliding_window(
            return None
        return getattr(self.hf_text_config, "sliding_window", None)

-    def get_sliding_window(self) -> Optional[Union[int, list[Optional[int]]]]:
+    def get_sliding_window(self) -> Optional[int]:
        """Get the sliding window size, or None if disabled.
        """
        # If user disables sliding window, return None.
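In short, the config-level check now keys off the presence of `sliding_window_pattern` instead of the hard-coded `interleaved_attn_models` list. A small illustration of that distinction (hypothetical config objects; the helper name is invented for the example):

```python
# Illustrative only: how the new check separates interleaved sliding window
# configs from plain sliding window configs. SimpleNamespace stands in for an
# HF text config object.
from types import SimpleNamespace


def uses_interleaved_attention(hf_text_config) -> bool:
    # Mirrors the logic above: interleaved attention is keyed off
    # `sliding_window_pattern` rather than a list of model types.
    return getattr(hf_text_config, "sliding_window_pattern", None) is not None


gemma2_like = SimpleNamespace(sliding_window=4096, sliding_window_pattern=2)
mistral_like = SimpleNamespace(sliding_window=4096)  # no pattern attribute

print(uses_interleaved_attention(gemma2_like))   # True  -> interleaved handling
print(uses_interleaved_attention(mistral_like))  # False -> ordinary sliding window
```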
57 changes: 51 additions & 6 deletions vllm/model_executor/models/transformers.py
@@ -16,6 +16,7 @@
"""Wrapper around `transformers` models"""
import re
from collections.abc import Iterable
+from contextlib import nullcontext
from typing import Literal, Optional, Union

import torch
@@ -110,6 +111,33 @@ def replace_linear_class(
    )


+class ConfigOverride:
+    """Context manager to temporarily override config attributes."""
+
+    def __init__(self, config: PretrainedConfig, **kwargs):
+        self.config = config
+        self.kwargs = kwargs
+        self.kwargs_original = {}
+        self.kwargs_delete = set()
+
+    def __enter__(self):
+        """Override config attributes."""
+        for key, value in self.kwargs.items():
+            if not hasattr(self.config, key):
+                self.kwargs_delete.add(key)
+            self.kwargs_original[key] = getattr(self.config, key, None)
+            setattr(self.config, key, value)
+        return self.config
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        """Restore original config attributes."""
+        for key, value in self.kwargs_original.items():
+            if key in self.kwargs_delete:
+                delattr(self.config, key)
+            else:
+                setattr(self.config, key, value)
+
+
class TransformersModel(nn.Module):

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
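A quick usage sketch of the `ConfigOverride` context manager above (illustrative only; it assumes the class from this hunk is in scope and uses a bare `PretrainedConfig` with made-up values):

```python
# Illustrative only: temporarily re-attach `sliding_window` while constructing
# a module, then clean it up again afterwards.
from transformers import PretrainedConfig

config = PretrainedConfig()
config.interleaved_sliding_window = 4096  # vLLM-style attribute (example value)

with ConfigOverride(config, sliding_window=config.interleaved_sliding_window):
    # Inside the block a Transformers constructor would see `sliding_window`.
    assert config.sliding_window == 4096

# The attribute did not exist beforehand, so it is deleted again on exit.
assert not hasattr(config, "sliding_window")
```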
@@ -135,8 +163,17 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        self.pp_rank = self.pp_group.rank_in_group
        self.tp_size = get_tensor_model_parallel_world_size()

+        # vLLM handles interleaved sliding window attention by creating a new
+        # interleaved_sliding_window attribute and deleting the sliding_window
+        # attribute. This breaks the constructors in Transformers so we
+        # temporarily add the attribute back to construct the model.
+        config_override = nullcontext()
+        if hasattr(config, "interleaved_sliding_window"):
+            config_override = ConfigOverride(
+                config, sliding_window=config.interleaved_sliding_window)
+
        # Use meta device to delay allocating GPU tensors
-        with torch.device("meta"):
+        with torch.device("meta"), config_override:
            # FIXME(Isotr0py): We need to refactor this part in the future to
            # avoid registering an extra model layer, otherwise we will need a
            # weights mapper to rename weights.
@@ -262,9 +299,17 @@ def create_attention_instances(self) -> dict[int, Attention]:
        num_kv_heads = self.model_config.get_num_kv_heads(self.parallel_config)
        start, end = get_pp_indices(self.config.num_hidden_layers,
                                    self.pp_rank, self.pp_size)
-        return {
-            i:
-            Attention(
+
+        attention_instances = {}
+        for i in range(start, end):
+            # Handle interleaved sliding window attention
+            sliding_window = None
+            if (hasattr(self.config, "interleaved_sliding_window")
+                    and hasattr(self.config, "sliding_window_pattern")
+                    and ((i + 1) % self.config.sliding_window_pattern > 0)):
+                sliding_window = self.config.interleaved_sliding_window
+
+            attention_instances[i] = Attention(
                num_heads=num_heads,
                head_size=head_size,
                # NOTE: We use Llama scale as default, if it's set by
@@ -273,9 +318,9 @@ def create_attention_instances(self) -> dict[int, Attention]:
                num_kv_heads=num_kv_heads,
                cache_config=self.cache_config,
                quant_config=self.quant_config,
+                per_layer_sliding_window=sliding_window,
                prefix=f"{i}.attn")
-            for i in range(start, end)
-        }
+        return attention_instances

    def init_buffers(self, module: nn.Module):
        """
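To see what the `(i + 1) % sliding_window_pattern > 0` rule above produces per layer, here is a small standalone illustration (made-up window size and layer count; not part of the diff):

```python
# Illustrative only: which layers end up with a per-layer sliding window.
interleaved_sliding_window = 4096  # window for the "local" layers (example value)
sliding_window_pattern = 2         # e.g. the Gemma 2 workaround in config.py sets 2
num_hidden_layers = 6              # made-up layer count

for i in range(num_hidden_layers):
    # Every `sliding_window_pattern`-th layer gets no per-layer window
    # (full attention in this setup); the rest use the interleaved window.
    if (i + 1) % sliding_window_pattern > 0:
        window = interleaved_sliding_window
    else:
        window = None
    print(f"layer {i}: per_layer_sliding_window={window}")

# layer 0: per_layer_sliding_window=4096
# layer 1: per_layer_sliding_window=None
# layer 2: per_layer_sliding_window=4096
# ...
```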