
Commit 9aaf14c

[misc] add forward context for attention (#9029)
1 parent 63e3993 commit 9aaf14c
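
The diff reproduced below covers only the test-side cleanup that rides along with this change; the forward context itself is added in the runtime files, which are not shown here. As a rough illustration of the pattern the title refers to, a forward context is a slot that is set once per model forward pass so that attention layers can look up shared metadata instead of having it threaded through every call. A minimal sketch of that idea, with all names (_forward_context, set_forward_context, get_forward_context) hypothetical rather than taken from this commit:

# Hypothetical sketch only -- not the actual vLLM implementation from this commit.
from contextlib import contextmanager
from typing import Any, Optional

_forward_context: Optional[Any] = None  # metadata shared for the current forward pass


@contextmanager
def set_forward_context(attn_metadata: Any):
    """Make attention metadata visible to every layer for one forward pass."""
    global _forward_context
    prev = _forward_context
    _forward_context = attn_metadata
    try:
        yield
    finally:
        _forward_context = prev  # restore on exit, even if the forward pass raises


def get_forward_context() -> Optional[Any]:
    """Called from inside attention layers to read the per-pass metadata."""
    return _forward_context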

File tree

8 files changed, +250 -334 lines

tests/kernels/test_flash_attn.py (+7 -49)
@@ -3,9 +3,9 @@
 import pytest
 import torch
 
-import vllm.attention.backends.flash_attn  # noqa: F401
-from tests.kernels.utils import opcheck
 from vllm.utils import seed_everything
+from vllm.vllm_flash_attn import (flash_attn_varlen_func,
+                                  flash_attn_with_kvcache)
 
 NUM_HEADS = [(4, 4), (8, 2), (16, 2)]
 HEAD_SIZES = [128, 256]
@@ -112,36 +112,17 @@ def test_flash_attn_with_paged_kv(
                                  (num_seqs, max_num_blocks_per_seq),
                                  dtype=torch.int32)
 
-    output = torch.ops.vllm.flash_attn_with_kvcache(
-        decode_query=query.unsqueeze(1),
-        key_cache=key_cache,
-        value_cache=value_cache,
+    output = flash_attn_with_kvcache(
+        q=query.unsqueeze(1),
+        k_cache=key_cache,
+        v_cache=value_cache,
         softmax_scale=scale,
         causal=True,
         block_table=block_tables,
         cache_seqlens=kv_lens_tensor,
         softcap=soft_cap if soft_cap is not None else 0,
     ).squeeze(1)
 
-    if num_blocks <= 2048:
-        test_utils = ["test_faketensor", "test_schema"]
-    else:
-        test_utils = ["test_faketensor"]
-
-    opcheck(torch.ops.vllm.flash_attn_with_kvcache,
-            args=tuple(),
-            kwargs=dict(
-                decode_query=query.unsqueeze(1),
-                key_cache=key_cache,
-                value_cache=value_cache,
-                softmax_scale=scale,
-                causal=True,
-                block_table=block_tables,
-                cache_seqlens=kv_lens_tensor,
-                softcap=soft_cap if soft_cap is not None else 0,
-            ),
-            test_utils=test_utils)
-
     ref_output = ref_paged_attn(
         query=query,
         key_cache=key_cache,
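
With the custom-op registration gone, the decode test above calls flash_attn_with_kvcache from vllm.vllm_flash_attn directly, and the keyword names change from decode_query/key_cache/value_cache to q/k_cache/v_cache. A condensed, standalone sketch of that call, with tensor shapes and dtypes chosen purely for illustration (a CUDA device and a flash-attn-supported head size are assumed):

import torch

from vllm.vllm_flash_attn import flash_attn_with_kvcache

# Illustrative sizes only: 2 decoding sequences, 8 query heads, 2 KV heads.
num_seqs, num_heads, num_kv_heads, head_size = 2, 8, 2, 128
num_blocks, block_size, max_blocks_per_seq = 32, 16, 4

query = torch.randn(num_seqs, num_heads, head_size,
                    dtype=torch.float16, device="cuda")
key_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_size,
                        dtype=torch.float16, device="cuda")
value_cache = torch.randn_like(key_cache)
block_tables = torch.randint(0, num_blocks,
                             (num_seqs, max_blocks_per_seq),
                             dtype=torch.int32, device="cuda")
kv_lens = torch.tensor([17, 33], dtype=torch.int32, device="cuda")

output = flash_attn_with_kvcache(
    q=query.unsqueeze(1),           # decode: a single query token per sequence
    k_cache=key_cache,
    v_cache=value_cache,
    softmax_scale=head_size**-0.5,
    causal=True,
    block_table=block_tables,
    cache_seqlens=kv_lens,
).squeeze(1)                        # back to (num_seqs, num_heads, head_size)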
@@ -213,7 +194,7 @@ def test_varlen_with_paged_kv(
                                  (num_seqs, max_num_blocks_per_seq),
                                  dtype=torch.int32)
 
-    output = torch.ops.vllm.flash_attn_varlen_func(
+    output = flash_attn_varlen_func(
         q=query,
         k=key_cache,
         v=value_cache,
@@ -228,29 +209,6 @@ def test_varlen_with_paged_kv(
         softcap=soft_cap if soft_cap is not None else 0,
     )
 
-    if num_blocks <= 2048:
-        test_utils = ["test_faketensor", "test_schema"]
-    else:
-        test_utils = ["test_faketensor"]
-
-    opcheck(torch.ops.vllm.flash_attn_varlen_func,
-            args=tuple(),
-            kwargs=dict(
-                q=query,
-                k=key_cache,
-                v=value_cache,
-                cu_seqlens_q=cu_query_lens,
-                cu_seqlens_k=cu_kv_lens,
-                max_seqlen_q=max_query_len,
-                max_seqlen_k=max_kv_len,
-                softmax_scale=scale,
-                causal=True,
-                window_size=window_size,
-                block_table=block_tables,
-                softcap=soft_cap if soft_cap is not None else 0,
-            ),
-            test_utils=test_utils)
-
     ref_output = ref_paged_attn(
         query=query,
         key_cache=key_cache,
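
The prefill/varlen test gets the same treatment: flash_attn_varlen_func is now imported and called directly, so the opcheck schema/fake-tensor validation that only applied to the registered torch.ops.vllm wrappers is dropped. A condensed sketch of the direct call, again with purely illustrative shapes (two packed sequences with query lengths 3 and 5 attending against a paged KV cache):

import torch

from vllm.vllm_flash_attn import flash_attn_varlen_func

num_heads, num_kv_heads, head_size = 8, 2, 128
num_blocks, block_size, max_blocks_per_seq = 32, 16, 4
query_lens, kv_lens = [3, 5], [20, 40]

query = torch.randn(sum(query_lens), num_heads, head_size,
                    dtype=torch.float16, device="cuda")
key_cache = torch.randn(num_blocks, block_size, num_kv_heads, head_size,
                        dtype=torch.float16, device="cuda")
value_cache = torch.randn_like(key_cache)
block_tables = torch.randint(0, num_blocks,
                             (len(query_lens), max_blocks_per_seq),
                             dtype=torch.int32, device="cuda")

# Cumulative prefixes mark where each sequence starts in the packed token dim.
cu_query_lens = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")
cu_kv_lens = torch.tensor([0, 20, 60], dtype=torch.int32, device="cuda")

output = flash_attn_varlen_func(
    q=query,
    k=key_cache,
    v=value_cache,
    cu_seqlens_q=cu_query_lens,
    cu_seqlens_k=cu_kv_lens,
    max_seqlen_q=max(query_lens),
    max_seqlen_k=max(kv_lens),
    softmax_scale=head_size**-0.5,
    causal=True,
    window_size=(-1, -1),           # no sliding-window restriction
    block_table=block_tables,
)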
