 import pytest
 import torch

-import vllm.attention.backends.flash_attn  # noqa: F401
-from tests.kernels.utils import opcheck
 from vllm.utils import seed_everything
+from vllm.vllm_flash_attn import (flash_attn_varlen_func,
+                                  flash_attn_with_kvcache)

 NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
 HEAD_SIZES = [128, 256]
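
(For context, a hedged sketch of the migration pattern this diff applies; it is not part of the change itself. Call sites move off the registered custom op and onto the direct Python wrappers, which is why the `# noqa: F401` backend import and the `opcheck()` schema tests below can be dropped.)

```python
# Before: routed through the torch.ops custom-op layer; the backend import
# existed only for its registration side effect (hence `# noqa: F401`).
# out = torch.ops.vllm.flash_attn_with_kvcache(decode_query=..., ...)

# After: a plain function call into vLLM's bundled flash-attn build.
from vllm.vllm_flash_attn import flash_attn_with_kvcache
# out = flash_attn_with_kvcache(q=..., k_cache=..., v_cache=..., ...)
```
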
@@ -112,36 +112,17 @@ def test_flash_attn_with_paged_kv(
                                  (num_seqs, max_num_blocks_per_seq),
                                  dtype=torch.int32)

-    output = torch.ops.vllm.flash_attn_with_kvcache(
-        decode_query=query.unsqueeze(1),
-        key_cache=key_cache,
-        value_cache=value_cache,
+    output = flash_attn_with_kvcache(
+        q=query.unsqueeze(1),
+        k_cache=key_cache,
+        v_cache=value_cache,
         softmax_scale=scale,
         causal=True,
         block_table=block_tables,
         cache_seqlens=kv_lens_tensor,
         softcap=soft_cap if soft_cap is not None else 0,
     ).squeeze(1)

-    if num_blocks <= 2048:
-        test_utils = ["test_faketensor", "test_schema"]
-    else:
-        test_utils = ["test_faketensor"]
-
-    opcheck(torch.ops.vllm.flash_attn_with_kvcache,
-            args=tuple(),
-            kwargs=dict(
-                decode_query=query.unsqueeze(1),
-                key_cache=key_cache,
-                value_cache=value_cache,
-                softmax_scale=scale,
-                causal=True,
-                block_table=block_tables,
-                cache_seqlens=kv_lens_tensor,
-                softcap=soft_cap if soft_cap is not None else 0,
-            ),
-            test_utils=test_utils)
-
     ref_output = ref_paged_attn(
         query=query,
         key_cache=key_cache,
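
(To make the new decode-path call concrete, here is a hedged, self-contained sketch of the tensor shapes `flash_attn_with_kvcache` expects with a paged KV cache. All sizes, the CUDA device, and the fp16 dtype are illustrative assumptions, not values from the diff.)

```python
import torch
from vllm.vllm_flash_attn import flash_attn_with_kvcache

# Illustrative sizes only.
num_seqs, num_query_heads, num_kv_heads, head_size = 2, 4, 4, 128
num_blocks, block_size = 128, 16

query = torch.randn(num_seqs, num_query_heads, head_size,
                    dtype=torch.float16, device="cuda")
# Paged KV cache layout: (num_blocks, block_size, num_kv_heads, head_size).
key_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_size,
                        dtype=torch.float16, device="cuda")
value_cache = torch.randn_like(key_cache)
# Per-sequence KV lengths and the block table mapping sequences to pages.
kv_lens_tensor = torch.tensor([12, 7], dtype=torch.int32, device="cuda")
block_tables = torch.randint(0, num_blocks, (num_seqs, 8),
                             dtype=torch.int32, device="cuda")

# q needs a singleton query-length dim for single-token decode, matching
# the .unsqueeze(1)/.squeeze(1) pair in the test above.
output = flash_attn_with_kvcache(
    q=query.unsqueeze(1),
    k_cache=key_cache,
    v_cache=value_cache,
    softmax_scale=head_size**-0.5,
    causal=True,
    block_table=block_tables,
    cache_seqlens=kv_lens_tensor,
).squeeze(1)
```
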
@@ -213,7 +194,7 @@ def test_varlen_with_paged_kv(
                                  (num_seqs, max_num_blocks_per_seq),
                                  dtype=torch.int32)

-    output = torch.ops.vllm.flash_attn_varlen_func(
+    output = flash_attn_varlen_func(
         q=query,
         k=key_cache,
         v=value_cache,
@@ -228,29 +209,6 @@ def test_varlen_with_paged_kv(
         softcap=soft_cap if soft_cap is not None else 0,
     )

-    if num_blocks <= 2048:
-        test_utils = ["test_faketensor", "test_schema"]
-    else:
-        test_utils = ["test_faketensor"]
-
-    opcheck(torch.ops.vllm.flash_attn_varlen_func,
-            args=tuple(),
-            kwargs=dict(
-                q=query,
-                k=key_cache,
-                v=value_cache,
-                cu_seqlens_q=cu_query_lens,
-                cu_seqlens_k=cu_kv_lens,
-                max_seqlen_q=max_query_len,
-                max_seqlen_k=max_kv_len,
-                softmax_scale=scale,
-                causal=True,
-                window_size=window_size,
-                block_table=block_tables,
-                softcap=soft_cap if soft_cap is not None else 0,
-            ),
-            test_utils=test_utils)
-
     ref_output = ref_paged_attn(
         query=query,
         key_cache=key_cache,
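
(And for the varlen path, a hedged sketch of how the cumulative-sequence-length inputs that `flash_attn_varlen_func` consumes are typically built. The length values are invented; `query_lens` and `kv_lens` stand in for the per-sequence lengths the test derives elsewhere.)

```python
import torch

query_lens = [5, 3]   # per-sequence query lengths (example values)
kv_lens = [17, 9]     # per-sequence KV lengths (example values)

# cu_seqlens_* are int32 prefix sums with a leading zero, one entry per
# sequence boundary in the packed (total_tokens, ...) query tensor.
cu_query_lens = torch.tensor([0] + query_lens, dtype=torch.int32,
                             device="cuda").cumsum(dim=0, dtype=torch.int32)
cu_kv_lens = torch.tensor([0] + kv_lens, dtype=torch.int32,
                          device="cuda").cumsum(dim=0, dtype=torch.int32)
max_query_len = max(query_lens)
max_kv_len = max(kv_lens)

# These feed the call in the test above:
# flash_attn_varlen_func(q=..., k=..., v=...,
#                        cu_seqlens_q=cu_query_lens,
#                        cu_seqlens_k=cu_kv_lens,
#                        max_seqlen_q=max_query_len,
#                        max_seqlen_k=max_kv_len,
#                        block_table=..., ...)
```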