
Commit a257583

LiuXiaoxuanPKU authored and bong-furiosa committed
[Kernel] Flashinfer for prefill & decode, with Cudagraph support for decode (vllm-project#4628)
Co-authored-by: LiuXiaoxuanPKU <llilyliupku@gmail.com>, bong-furiosa <bongwon.jang@furiosa.ai>
1 parent 1ad999c commit a257583

File tree

7 files changed: +313 −117 lines

.buildkite/test-pipeline.yaml

Lines changed: 3 additions & 0 deletions
@@ -211,3 +211,6 @@ steps:
   - pytest -v -s distributed/test_custom_all_reduce.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
   - TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=mp pytest -v -s distributed/test_basic_distributed_correctness.py
+  - pip install https://github.com/flashinfer-ai/flashinfer/releases/download/v0.0.5/flashinfer-0.0.5+cu121torch2.3-cp310-cp310-linux_x86_64.whl
+  - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=facebook/opt-125m DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
+  - VLLM_ATTENTION_BACKEND=FLASHINFER TEST_DIST_MODEL=meta-llama/Meta-Llama-3-8B DISTRIBUTED_EXECUTOR_BACKEND=ray pytest -v -s distributed/test_basic_distributed_correctness.py
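
The new CI commands select the backend purely through the VLLM_ATTENTION_BACKEND environment variable after installing the FlashInfer v0.0.5 wheel. A minimal local sketch of the same setup (illustrative only; the model name, prompt, and sampling settings are placeholders, and the wheel above is assumed to be installed):

import os

# The attention backend is chosen when the engine is constructed, so the
# environment variable must be set before building the LLM.
os.environ["VLLM_ATTENTION_BACKEND"] = "FLASHINFER"

from vllm import LLM, SamplingParams

# enforce_eager is left off: with this commit, decode can run under CUDA graphs.
llm = LLM(model="facebook/opt-125m", enforce_eager=False)
outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)

Because this commit adds CUDA graph support for decode, the earlier requirement to pass --enforce-eager with the FlashInfer backend no longer applies, which is also why the test changes below drop that workaround.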

requirements-test.txt

Lines changed: 1 addition & 1 deletion
@@ -19,4 +19,4 @@ sentence-transformers # required for embedding
 aiohttp
 
 # quantization
-bitsandbytes==0.42.0
+bitsandbytes==0.42.0

tests/basic_correctness/test_basic_correctness.py

Lines changed: 0 additions & 6 deletions
@@ -2,7 +2,6 @@
 
 Run `pytest tests/basic_correctness/test_basic_correctness.py`.
 """
-import os
 import weakref
 
 import pytest
@@ -13,7 +12,6 @@
     "facebook/opt-125m",
     "meta-llama/Llama-2-7b-hf",
 ]
-VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"
 
 
 def test_vllm_gc_ed():
@@ -39,10 +37,6 @@ def test_models(
     max_tokens: int,
     enforce_eager: bool,
 ) -> None:
-    backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
-    if backend_by_env_var == "FLASHINFER" and enforce_eager is False:
-        pytest.skip("Skipping non-eager test for FlashInferBackend.")
-
     with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
 

tests/distributed/test_basic_distributed_correctness.py

Lines changed: 0 additions & 5 deletions
@@ -21,7 +21,6 @@
     os.environ["TEST_DIST_MODEL"],
 ]
 DISTRIBUTED_EXECUTOR_BACKEND = "DISTRIBUTED_EXECUTOR_BACKEND"
-VLLM_ATTENTION_BACKEND = "VLLM_ATTENTION_BACKEND"
 
 
 @pytest.mark.skipif(torch.cuda.device_count() < 2,
@@ -39,16 +38,12 @@ def test_models(
 ) -> None:
     distributed_executor_backend = os.getenv(DISTRIBUTED_EXECUTOR_BACKEND)
 
-    backend_by_env_var = os.getenv(VLLM_ATTENTION_BACKEND)
-    enforce_eager = backend_by_env_var == "FLASHINFER"
-
     with hf_runner(model, dtype=dtype) as hf_model:
         hf_outputs = hf_model.generate_greedy(example_prompts, max_tokens)
 
     with vllm_runner(model,
                      dtype=dtype,
                      tensor_parallel_size=2,
-                     enforce_eager=enforce_eager,
                      distributed_executor_backend=distributed_executor_backend
                      ) as vllm_model:
         vllm_outputs = vllm_model.generate_greedy(example_prompts, max_tokens)

vllm/attention/backends/flashinfer.py

Lines changed: 58 additions & 25 deletions
@@ -1,10 +1,16 @@
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Set, Tuple, Type
 
-import flashinfer
+try:
+    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
+    from vllm_flash_attn import flash_attn_varlen_func
+except ImportError:
+    flash_attn_varlen_func = None
+    BatchDecodeWithPagedKVCacheWrapper = None
+    BatchPrefillWithPagedKVCacheWrapper = None
+
 import torch
-from flashinfer import BatchDecodeWithPagedKVCacheWrapper
-from vllm_flash_attn import flash_attn_varlen_func
 
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
@@ -60,19 +66,16 @@ class FlashInferMetadata(AttentionMetadata):
     # requests only.
     max_prefill_seq_len: int
 
-    use_cuda_graph: bool = False
+    use_cuda_graph: bool = True
 
+    prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
     decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
 
-    # Metadata for the prefill stage since we still
-    # use flash attention for prefill.
+    # Metadata for the prefill stage
     seq_start_loc: Optional[torch.Tensor] = None
+    query_start_loc: Optional[torch.Tensor] = None
     block_tables: Optional[torch.Tensor] = None
 
-    # Metadata for the decode stage
-    # Workspace buffer required by the kernel, the buffer should not
-    # be allocated/deacollated by the FalshInfermetadata object.
-    workspace_buffer: Optional[torch.Tensor] = None
     # An example for paged_kv_indices, paged_kv_indptr:
     # request 1, page indices [0, 5, 8]
     # request 2, page indices [1, 6, 7]
@@ -98,6 +101,7 @@ class FlashInferMetadata(AttentionMetadata):
     page_size: Optional[int] = None
     # The data type of the paged kv cache
     data_type: torch.dtype = None
+    device: torch.device = torch.device("cuda")
 
     def __post_init__(self):
         # Refer to
@@ -109,13 +113,35 @@ def __post_init__(self):
                 f"Only {supported_head_sizes} are supported for head_dim,",
                 f"received {self.head_dim}.")
 
-        # When using flashinfer, we are also creating the FlashInferMetadata,
-        # which will also call post_init by default, here we want to skip the
-        # post_init if it's the prefill phase.
-        if self.num_prefills == 0:
-            assert self.num_decode_tokens > 0
-            self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
-                self.workspace_buffer, "NHD")
+    def begin_forward(self):
+        if self.num_prefill_tokens > 0:
+            if self.paged_kv_indices is None:
+                return
+
+            assert self.prefill_wrapper is not None
+            assert self.paged_kv_indices is not None
+            assert self.paged_kv_indptr is not None
+            assert self.paged_kv_last_page_len is not None
+            self.paged_kv_indices = self.paged_kv_indices.to(self.device)
+            self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
+            self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
+                self.device)
+            self.prefill_wrapper.begin_forward(
+                self.query_start_loc, self.paged_kv_indptr,
+                self.paged_kv_indices, self.paged_kv_last_page_len,
+                self.num_qo_heads, self.num_kv_heads, self.head_dim,
+                self.page_size)
+        else:
+            if not self.use_cuda_graph:
+                assert self.paged_kv_indices is not None
+                assert self.paged_kv_indptr is not None
+                assert self.paged_kv_last_page_len is not None
+                self.paged_kv_indices = self.paged_kv_indices.to(self.device)
+                self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
+                self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
+                    self.device)
+
+            assert self.decode_wrapper is not None
             self.decode_wrapper.begin_forward(
                 self.paged_kv_indptr,
                 self.paged_kv_indices,
@@ -133,8 +159,9 @@ def asdict_zerocopy(self,
                         ) -> Dict[str, Any]:
         if skip_fields is None:
             skip_fields = set()
-        # We need to skip the decode_wrapper field since it cannot be
+        # We need to skip the prefill/decode_wrapper field since it cannot be
         # broadcasted with nccl when TP is enabled.
+        skip_fields.add('prefill_wrapper')
         skip_fields.add('decode_wrapper')
         return super().asdict_zerocopy(skip_fields)
 
@@ -168,6 +195,7 @@ def __init__(
         alibi_slopes: Optional[List[float]],
         sliding_window: Optional[int],
         kv_cache_dtype: str,
+        blocksparse_params: Optional[Dict[str, Any]] = None,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -217,10 +245,14 @@ def forward(
                 self.kv_cache_dtype,
             )
 
+        query = query.contiguous(
+        )  # Flashinfer requires query to be contiguous
         if prefill_meta := attn_metadata.prefill_metadata:
-            # Prompt run.
-            assert prefill_meta.block_tables is not None
-            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
+            # We will use flash attention for prefill
+            # when kv_cache is not provided.
+            # This happens when vllm runs the profiling to
+            # determine the number of blocks.
+            if kv_cache is None:
                 output = flash_attn_varlen_func(
                     q=query,
                     k=key,
@@ -235,13 +267,14 @@ def forward(
                     alibi_slopes=self.alibi_slopes,
                 )
             else:
-                raise NotImplementedError(
-                    "Prefix caching is not supported with flashinfer yet.")
+                assert prefill_meta is not None
+                assert prefill_meta.prefill_wrapper is not None
+                output = prefill_meta.prefill_wrapper.forward(query,
+                                                              kv_cache,
+                                                              causal=True)
         else:
             assert attn_metadata.decode_metadata is not None
             assert attn_metadata.decode_metadata.decode_wrapper is not None
-            query = query.contiguous(
-            )  # Flashinfer requires query to be contiguous
             output = attn_metadata.decode_metadata.decode_wrapper.forward(
                 query,
                 kv_cache,
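
The main restructuring in this file is that FlashInferMetadata no longer builds the decode wrapper inside __post_init__; the prefill and decode wrappers are held as optional fields (skipped during NCCL broadcast), and an explicit begin_forward() moves the paged-KV tensors to the device and plans the kernel once per batch, right before the forward pass. The sketch below illustrates that deferred-planning lifecycle with toy stand-in names (PlannedAttention, indptr/indices, planned); it is not the FlashInfer or vLLM API:

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class PlannedAttention:
    """Toy stand-in for per-batch attention metadata with a deferred plan step."""
    indptr: Optional[torch.Tensor] = None
    indices: Optional[torch.Tensor] = None
    device: torch.device = torch.device("cpu")
    planned: bool = False

    def __post_init__(self):
        # Cheap validation only; no kernel planning here, so the object can be
        # rebuilt (e.g. after a tensor-parallel broadcast) without side effects.
        assert (self.indptr is None) == (self.indices is None)

    def begin_forward(self):
        # Expensive, stateful planning happens exactly once per batch,
        # right before the forward pass (mirrors the new begin_forward()).
        assert self.indptr is not None and self.indices is not None
        self.indptr = self.indptr.to(self.device)
        self.indices = self.indices.to(self.device)
        self.planned = True

    def forward(self, q: torch.Tensor) -> torch.Tensor:
        assert self.planned, "call begin_forward() before forward()"
        return q  # placeholder for the real paged-KV attention kernel


meta = PlannedAttention(indptr=torch.tensor([0, 2]), indices=torch.tensor([3, 7]))
meta.begin_forward()
out = meta.forward(torch.randn(2, 8))

Keeping __post_init__ free of side effects is what the removed comment was working around, and it appears to be what lets the decode path be planned separately and captured by CUDA graphs while prefill uses its own wrapper.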

vllm/attention/selector.py

Lines changed: 3 additions & 2 deletions
@@ -77,8 +77,9 @@ def get_attn_backend(
         return IpexAttnBackend
     elif backend == _Backend.FLASHINFER:
         logger.info("Using Flashinfer backend.")
-        logger.warning("Eager mode is required for the Flashinfer backend. "
-                       "Please make sure --enforce-eager is set.")
+        logger.warning(("Flashinfer will be stuck on llma-2-7b,"
+                        " please avoid using Flashinfer as the"
+                        "backend when running on llma-2-7b."))
         from vllm.attention.backends.flashinfer import FlashInferBackend
         return FlashInferBackend
     elif backend == _Backend.PALLAS:
