@@ -1,10 +1,16 @@
 from dataclasses import dataclass
 from typing import Any, Dict, List, Optional, Set, Tuple, Type
 
-import flashinfer
+try:
+    from flashinfer import BatchDecodeWithPagedKVCacheWrapper
+    from flashinfer.prefill import BatchPrefillWithPagedKVCacheWrapper
+    from vllm_flash_attn import flash_attn_varlen_func
+except ImportError:
+    flash_attn_varlen_func = None
+    BatchDecodeWithPagedKVCacheWrapper = None
+    BatchPrefillWithPagedKVCacheWrapper = None
+
 import torch
-from flashinfer import BatchDecodeWithPagedKVCacheWrapper
-from vllm_flash_attn import flash_attn_varlen_func
 
 from vllm import _custom_ops as ops
 from vllm.attention.backends.abstract import (AttentionBackend, AttentionImpl,
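For context, here is a minimal sketch (not part of the diff) of how a caller might construct the two FlashInfer wrappers exposed by the guarded imports. The workspace-buffer size is an assumption made purely for illustration; the "NHD" layout argument mirrors the wrapper construction removed from __post_init__ further down.

# Illustrative sketch only, not part of this PR. Assumes FlashInfer is
# installed; otherwise the guarded imports above leave these names as None.
import torch

if BatchPrefillWithPagedKVCacheWrapper is None:
    raise RuntimeError("FlashInfer is not available on this platform.")

# The workspace buffer is owned by the caller, not by FlashInferMetadata;
# 128 MiB is an assumed size, used here purely for illustration.
workspace_buffer = torch.empty(128 * 1024 * 1024,
                               dtype=torch.uint8,
                               device="cuda")
prefill_wrapper = BatchPrefillWithPagedKVCacheWrapper(workspace_buffer, "NHD")
decode_wrapper = BatchDecodeWithPagedKVCacheWrapper(workspace_buffer, "NHD")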
@@ -60,19 +66,16 @@ class FlashInferMetadata(AttentionMetadata):
     # requests only.
     max_prefill_seq_len: int
 
-    use_cuda_graph: bool = False
+    use_cuda_graph: bool = True
 
+    prefill_wrapper: Optional[BatchPrefillWithPagedKVCacheWrapper] = None
     decode_wrapper: Optional[BatchDecodeWithPagedKVCacheWrapper] = None
 
-    # Metadata for the prefill stage since we still
-    # use flash attention for prefill.
+    # Metadata for the prefill stage
     seq_start_loc: Optional[torch.Tensor] = None
+    query_start_loc: Optional[torch.Tensor] = None
     block_tables: Optional[torch.Tensor] = None
 
-    # Metadata for the decode stage
-    # Workspace buffer required by the kernel, the buffer should not
-    # be allocated/deallocated by the FlashInferMetadata object.
-    workspace_buffer: Optional[torch.Tensor] = None
     # An example for paged_kv_indices, paged_kv_indptr:
     # request 1, page indices [0, 5, 8]
     # request 2, page indices [1, 6, 7]
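To make the example in the comment above concrete, the two requests flatten into CSR-style tensors as sketched below (not part of the diff); the last-page lengths are made-up values.

# Illustrative sketch of the paged KV-cache indexing, following the example
# in the comment above; the concrete values are for illustration only.
import torch

# request 1 uses pages [0, 5, 8]; request 2 uses pages [1, 6, 7]
paged_kv_indices = torch.tensor([0, 5, 8, 1, 6, 7], dtype=torch.int32)
# CSR-style offsets: request i owns indices[indptr[i]:indptr[i + 1]]
paged_kv_indptr = torch.tensor([0, 3, 6], dtype=torch.int32)
# number of valid slots in each request's final page (assumed values)
paged_kv_last_page_len = torch.tensor([7, 3], dtype=torch.int32)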
@@ -98,6 +101,7 @@ class FlashInferMetadata(AttentionMetadata):
     page_size: Optional[int] = None
     # The data type of the paged kv cache
     data_type: torch.dtype = None
+    device: torch.device = torch.device("cuda")
 
     def __post_init__(self):
         # Refer to
@@ -109,13 +113,35 @@ def __post_init__(self):
                 f"Only {supported_head_sizes} are supported for head_dim,",
                 f"received {self.head_dim}.")
 
-        # When using flashinfer, we are also creating the FlashInferMetadata,
-        # which will also call post_init by default, here we want to skip the
-        # post_init if it's the prefill phase.
-        if self.num_prefills == 0:
-            assert self.num_decode_tokens > 0
-            self.decode_wrapper = flashinfer.BatchDecodeWithPagedKVCacheWrapper(
-                self.workspace_buffer, "NHD")
+    def begin_forward(self):
+        if self.num_prefill_tokens > 0:
+            if self.paged_kv_indices is None:
+                return
+
+            assert self.prefill_wrapper is not None
+            assert self.paged_kv_indices is not None
+            assert self.paged_kv_indptr is not None
+            assert self.paged_kv_last_page_len is not None
+            self.paged_kv_indices = self.paged_kv_indices.to(self.device)
+            self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
+            self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
+                self.device)
+            self.prefill_wrapper.begin_forward(
+                self.query_start_loc, self.paged_kv_indptr,
+                self.paged_kv_indices, self.paged_kv_last_page_len,
+                self.num_qo_heads, self.num_kv_heads, self.head_dim,
+                self.page_size)
+        else:
+            if not self.use_cuda_graph:
+                assert self.paged_kv_indices is not None
+                assert self.paged_kv_indptr is not None
+                assert self.paged_kv_last_page_len is not None
+                self.paged_kv_indices = self.paged_kv_indices.to(self.device)
+                self.paged_kv_indptr = self.paged_kv_indptr.to(self.device)
+                self.paged_kv_last_page_len = self.paged_kv_last_page_len.to(
+                    self.device)
+
+            assert self.decode_wrapper is not None
             self.decode_wrapper.begin_forward(
                 self.paged_kv_indptr,
                 self.paged_kv_indices,
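The new begin_forward is meant to be called once per scheduling step, after the metadata has been assembled on each worker. A hypothetical driver-side sequence, with placeholder names, could look like the sketch below; the wrappers are attached by the caller because they are skipped during broadcast (see the asdict_zerocopy change further down).

# Hypothetical usage sketch, not part of this PR; attn_metadata,
# prefill_wrapper and decode_wrapper are placeholders supplied by the caller.
def prepare_flashinfer_step(attn_metadata, prefill_wrapper, decode_wrapper):
    # Re-attach the wrappers, which are excluded from the broadcast payload.
    attn_metadata.prefill_wrapper = prefill_wrapper
    attn_metadata.decode_wrapper = decode_wrapper
    # Let FlashInfer plan its kernels for this batch before attention runs.
    attn_metadata.begin_forward()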
@@ -133,8 +159,9 @@ def asdict_zerocopy(self,
                         ) -> Dict[str, Any]:
         if skip_fields is None:
             skip_fields = set()
-        # We need to skip the decode_wrapper field since it cannot be
+        # We need to skip the prefill/decode_wrapper field since it cannot be
         # broadcasted with nccl when TP is enabled.
+        skip_fields.add('prefill_wrapper')
         skip_fields.add('decode_wrapper')
         return super().asdict_zerocopy(skip_fields)
 
@@ -168,6 +195,7 @@ def __init__(
         alibi_slopes: Optional[List[float]],
         sliding_window: Optional[int],
         kv_cache_dtype: str,
+        blocksparse_params: Optional[Dict[str, Any]] = None,
     ) -> None:
         self.num_heads = num_heads
         self.head_size = head_size
@@ -217,10 +245,14 @@ def forward(
                 self.kv_cache_dtype,
             )
 
+        query = query.contiguous(
+        )  # Flashinfer requires query to be contiguous
         if prefill_meta := attn_metadata.prefill_metadata:
-            # Prompt run.
-            assert prefill_meta.block_tables is not None
-            if kv_cache is None or prefill_meta.block_tables.numel() == 0:
+            # We will use flash attention for prefill
+            # when kv_cache is not provided.
+            # This happens when vllm runs the profiling to
+            # determine the number of blocks.
+            if kv_cache is None:
                 output = flash_attn_varlen_func(
                     q=query,
                     k=key,
@@ -235,13 +267,14 @@ def forward(
                     alibi_slopes=self.alibi_slopes,
                 )
             else:
-                raise NotImplementedError(
-                    "Prefix caching is not supported with flashinfer yet.")
+                assert prefill_meta is not None
+                assert prefill_meta.prefill_wrapper is not None
+                output = prefill_meta.prefill_wrapper.forward(query,
+                                                              kv_cache,
+                                                              causal=True)
         else:
             assert attn_metadata.decode_metadata is not None
             assert attn_metadata.decode_metadata.decode_wrapper is not None
-            query = query.contiguous(
-            )  # Flashinfer requires query to be contiguous
             output = attn_metadata.decode_metadata.decode_wrapper.forward(
                 query,
                 kv_cache,