Skip to content

Add FlexAttention to V1 #16078

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 7, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 93 additions & 0 deletions tests/kernels/test_flex_attention.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# SPDX-License-Identifier: Apache-2.0
"""Integration tests for FlexAttention backend vs default backend"""

import random

import numpy as np
import pytest
import torch
from packaging import version

from vllm import LLM, SamplingParams

TORCH_VERSION = version.parse(torch.__version__)
MINIMUM_TORCH_VERSION = version.parse("2.7.0")


def set_seed(seed):
"""Set seeds for reproducibility"""
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(seed)


@pytest.mark.skipif(
not torch.cuda.is_available() or TORCH_VERSION < MINIMUM_TORCH_VERSION,
reason="CUDA not available or PyTorch version < 2.7",
)
def test_flex_attention_vs_default_backend(monkeypatch):
"""Test that FlexAttention produces the same outputs as the default backend.

This test compares the outputs from the FlexAttention backend with
the default backend, ensuring they are identical when using the same seed.
"""
model_name = "Qwen/Qwen2.5-1.5B-Instruct"
seed = 42
max_tokens = 32
prompts = [
"Hello, my name is",
"The president of the United States is",
"The capital of France is",
]

sampling_params = SamplingParams(temperature=0.0,
top_p=1.0,
seed=seed,
max_tokens=max_tokens)

# Run with flex attention
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
m.setenv("VLLM_ATTENTION_BACKEND", "FLEX_ATTENTION")
m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")

set_seed(seed)

llm_flex = LLM(
model_name,
tensor_parallel_size=1,
num_gpu_blocks_override=128,
enforce_eager=True,
)
output_flex = llm_flex.generate(prompts, sampling_params)

# Run with default backend
with monkeypatch.context() as m:
m.setenv("VLLM_USE_V1", "1")
m.setenv("VLLM_ENABLE_V1_MULTIPROCESSING", "0")
set_seed(seed)
llm_default = LLM(
model_name,
tensor_parallel_size=1,
num_gpu_blocks_override=128,
enforce_eager=True,
)
output_default = llm_default.generate(prompts, sampling_params)

# Compare outputs from both backends
for i, (flex_result,
default_result) in enumerate(zip(output_flex, output_default)):
prompt = prompts[i]
flex_text = flex_result.outputs[0].text
default_text = default_result.outputs[0].text

assert flex_text == default_text, (
f"FlexAttention output doesn't match default for: {prompt!r}\n"
f"FlexAttention: {flex_text!r}\n"
f"Default: {default_text!r}")


if __name__ == "__main__":
pytest.main([__file__])
1 change: 1 addition & 0 deletions vllm/engine/arg_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1409,6 +1409,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
"FLASHINFER_VLLM_V1",
"ROCM_AITER_MLA",
"TORCH_SDPA_VLLM_V1",
"FLEX_ATTENTION",
]
if (envs.is_set("VLLM_ATTENTION_BACKEND")
and envs.VLLM_ATTENTION_BACKEND not in V1_BACKENDS):
Expand Down
3 changes: 3 additions & 0 deletions vllm/platforms/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,9 @@ def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
if selected_backend == _Backend.FLASHINFER:
logger.info_once("Using FlashInfer backend on V1 engine.")
return "vllm.v1.attention.backends.flashinfer.FlashInferBackend"
if selected_backend == _Backend.FLEX_ATTENTION:
logger.info("Using FlexAttenion backend on V1 engine.")
return "vllm.v1.attention.backends.flex_attention.FlexAttentionBackend" # noqa: E501
if selected_backend == _Backend.TRITON_ATTN_VLLM_V1:
logger.info_once("Using Triton backend on V1 engine.")
return ("vllm.v1.attention.backends."
Expand Down
1 change: 1 addition & 0 deletions vllm/platforms/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ class _Backend(enum.Enum):
BLOCK_SPARSE_FLASH_ATTN = enum.auto()
DUAL_CHUNK_FLASH_ATTN = enum.auto()
NO_ATTENTION = enum.auto()
FLEX_ATTENTION = enum.auto()


class PlatformEnum(enum.Enum):
Expand Down
Loading