
Commit 558db80

[V1][Kernel] Refactor the prefix_prefill kernel so that the caller no longer has to pass in the context lengths (#13095)
1 parent e109e59 commit 558db80

File tree

6 files changed: +12 -31 lines changed

tests/kernels/test_prefix_prefill.py

Lines changed: 2 additions & 6 deletions

@@ -100,7 +100,7 @@ def test_contexted_kv_attention(
                                      BS, max_block_per_request)
     b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
     b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
-    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
+    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
                                             dtype=torch.long),
                                dim=0)
     max_input_len = MAX_SEQ_LEN
@@ -154,7 +154,6 @@ def test_contexted_kv_attention(
                           block_table,
                           b_start_loc,
                           b_seq_len,
-                          b_ctx_len,
                           max_input_len,
                           k_scale,
                           v_scale,
@@ -171,7 +170,6 @@ def test_contexted_kv_attention(
                           block_table,
                           b_start_loc,
                           b_seq_len,
-                          b_ctx_len,
                           max_input_len,
                           k_scale,
                           v_scale,
@@ -333,7 +331,7 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
                                      BS, max_block_per_request)
     b_seq_len = torch.tensor(seq_lens, dtype=torch.long)
     b_ctx_len = torch.tensor(ctx_lens, dtype=torch.long)
-    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1],
+    b_start_loc = torch.cumsum(torch.tensor([0] + query_lens,
                                             dtype=torch.long),
                                dim=0)
     max_input_len = MAX_SEQ_LEN
@@ -387,7 +385,6 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
                           block_table,
                           b_start_loc,
                           b_seq_len,
-                          b_ctx_len,
                           max_input_len,
                           k_scale,
                           v_scale,
@@ -404,7 +401,6 @@ def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
                           block_table,
                           b_start_loc,
                           b_seq_len,
-                          b_ctx_len,
                           max_input_len,
                           k_scale,
                           v_scale,
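The test now builds b_start_loc by cumulative-summing over the full query_lens list, so it has batch_size + 1 entries; this matches the `assert batch + 1 == len(b_start_loc)` added to context_attention_fwd further below. A minimal sketch of the shape difference, using a hypothetical query_lens:

import torch

query_lens = [4, 1, 3]  # hypothetical per-request query lengths

old_start_loc = torch.cumsum(torch.tensor([0] + query_lens[:-1], dtype=torch.long), dim=0)
new_start_loc = torch.cumsum(torch.tensor([0] + query_lens, dtype=torch.long), dim=0)

print(old_start_loc)  # tensor([0, 4, 5])    -- batch_size entries: start offsets only
print(new_start_loc)  # tensor([0, 4, 5, 8]) -- batch_size + 1 entries: starts plus total
# The extra trailing offset lets the kernel recover each request's query length as
# new_start_loc[i + 1] - new_start_loc[i], so b_ctx_len no longer has to be passed in.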

vllm/attention/backends/rocm_flash_attn.py

Lines changed: 0 additions & 1 deletion

@@ -753,7 +753,6 @@ def forward(
                     prefill_meta.block_tables,
                     prefill_meta.query_start_loc,
                     prefill_meta.seq_lens_tensor,
-                    prefill_meta.context_lens_tensor,
                     prefill_meta.max_query_len,
                     self.alibi_slopes,
                     self.sliding_window[0],

vllm/attention/backends/xformers.py

Lines changed: 0 additions & 1 deletion

@@ -580,7 +580,6 @@ def forward(
                     prefill_meta.block_tables,
                     prefill_meta.query_start_loc,
                     prefill_meta.seq_lens_tensor,
-                    prefill_meta.context_lens_tensor,
                     prefill_meta.max_query_len,
                     self.alibi_slopes,
                     self.sliding_window,

vllm/attention/ops/paged_attn.py

Lines changed: 1 addition & 3 deletions

@@ -202,7 +202,6 @@ def forward_prefix(
         block_tables: torch.Tensor,
         query_start_loc: torch.Tensor,
         seq_lens_tensor: torch.Tensor,
-        context_lens: torch.Tensor,
         max_query_len: int,
         alibi_slopes: Optional[torch.Tensor],
         sliding_window: Optional[int],
@@ -220,9 +219,8 @@ def forward_prefix(
             value_cache,
             block_tables,
             # query_start_loc is (batch_size + 1,)
-            query_start_loc[:-1],
+            query_start_loc,
             seq_lens_tensor,
-            context_lens,
             max_query_len,
             k_scale,
             v_scale,
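With context_lens dropped, the contract forward_prefix relies on is the shape noted in the comment above: query_start_loc carries one more entry than there are sequences, which the kernel wrapper also checks. A minimal sketch of that invariant with hypothetical tensors:

import torch

query_start_loc = torch.tensor([0, 4, 5, 8])  # (batch_size + 1,) cumulative query offsets
seq_lens_tensor = torch.tensor([10, 6, 3])    # (batch_size,) total tokens per sequence

# Mirrors the `assert batch + 1 == len(b_start_loc)` inside context_attention_fwd.
assert query_start_loc.shape[0] == seq_lens_tensor.shape[0] + 1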

vllm/attention/ops/prefix_prefill.py

Lines changed: 9 additions & 8 deletions

@@ -31,7 +31,6 @@ def _fwd_kernel(
         v_scale,
         B_Start_Loc,
         B_Seqlen,
-        B_Ctxlen,
         block_size,
         x,
         Out,
@@ -72,10 +71,12 @@ def _fwd_kernel(

     cur_kv_head = cur_head // num_queries_per_kv

-    cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
     cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
     cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
-    cur_batch_query_len = cur_batch_seq_len - cur_batch_ctx_len
+    cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
+    cur_batch_query_len = (cur_batch_in_all_stop_index -
+                           cur_batch_in_all_start_index)
+    cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len

     # start position inside of the query
     # generally, N goes over kv, while M goes over query_len
@@ -466,7 +467,6 @@ def _fwd_kernel_alibi(
         v_scale,
         B_Start_Loc,
         B_Seqlen,
-        B_Ctxlen,
         Alibi_slopes,
         block_size,
         x,
@@ -511,9 +511,12 @@ def _fwd_kernel_alibi(
     # cur_batch_seq_len: the length of prompts
     # cur_batch_ctx_len: the length of prefix
     # cur_batch_in_all_start_index: the start id of the dim=0
-    cur_batch_ctx_len = tl.load(B_Ctxlen + cur_batch)
     cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
     cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch)
+    cur_batch_in_all_stop_index = tl.load(B_Start_Loc + cur_batch + 1)
+    cur_batch_query_len = (cur_batch_in_all_stop_index -
+                           cur_batch_in_all_start_index)
+    cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len

     block_start_loc = BLOCK_M * start_m

@@ -713,7 +716,6 @@ def context_attention_fwd(q,
                           b_loc,
                           b_start_loc,
                           b_seq_len,
-                          b_ctx_len,
                           max_input_len,
                           k_scale: torch.Tensor,
                           v_scale: torch.Tensor,
@@ -765,6 +767,7 @@ def context_attention_fwd(q,
     batch, head = b_seq_len.shape[0], q.shape[1]
     num_queries_per_kv = q.shape[1] // k.shape[1]

+    assert batch + 1 == len(b_start_loc)
     grid = (batch, head, triton.cdiv(max_input_len, BLOCK))  # batch, head,

     # 0 means "disable"
@@ -784,7 +787,6 @@ def context_attention_fwd(q,
         v_scale,
         b_start_loc,
         b_seq_len,
-        b_ctx_len,
         alibi_slopes,
         v_cache.shape[3],
         k_cache.shape[4],
@@ -838,7 +840,6 @@ def context_attention_fwd(q,
         v_scale,
         b_start_loc,
         b_seq_len,
-        b_ctx_len,
         v_cache.shape[3],
         k_cache.shape[4],
         o,
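For readability, here is a plain-Python sketch of the per-request bookkeeping the Triton kernels now perform in place of reading B_Ctxlen; the tensors are hypothetical host-side stand-ins for B_Seqlen and B_Start_Loc:

import torch

b_seq_len = torch.tensor([10, 6, 3])      # total tokens per request (prefix + query)
b_start_loc = torch.tensor([0, 4, 5, 8])  # (batch + 1,) cumulative query offsets

for cur_batch in range(b_seq_len.shape[0]):
    cur_batch_seq_len = b_seq_len[cur_batch]
    cur_batch_in_all_start_index = b_start_loc[cur_batch]
    cur_batch_in_all_stop_index = b_start_loc[cur_batch + 1]
    cur_batch_query_len = (cur_batch_in_all_stop_index -
                           cur_batch_in_all_start_index)
    # Previously loaded directly from B_Ctxlen; now derived inside the kernel.
    cur_batch_ctx_len = cur_batch_seq_len - cur_batch_query_len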

vllm/v1/attention/backends/rocm_attn.py

Lines changed: 0 additions & 12 deletions

@@ -150,17 +150,6 @@ def forward(
             layer._v_scale,
         )

-        # TODO(sage): Refactor the context_attention_fwd kernel so that this
-        # overhead can be removed
-        context_lens = torch.empty_like(attn_metadata.seq_lens)
-        batch_size = len(attn_metadata.query_start_loc) - 1
-        assert len(context_lens) == batch_size
-        for i in range(batch_size):
-            query_start = attn_metadata.query_start_loc[i]
-            query_end = attn_metadata.query_start_loc[i + 1]
-            context_lens[i] = attn_metadata.seq_lens[i] - (query_end -
-                                                           query_start)
-
         # Compute attention and update output up to `num_actual_tokens`.
         context_attention_fwd(q=query[:num_actual_tokens],
                               k=key[:num_actual_tokens],
@@ -172,7 +161,6 @@ def forward(
                               b_loc=attn_metadata.block_table,
                               b_start_loc=attn_metadata.query_start_loc,
                               b_seq_len=attn_metadata.seq_lens,
-                              b_ctx_len=context_lens,
                               max_input_len=attn_metadata.max_query_len,
                               k_scale=layer._k_scale,
                               v_scale=layer._v_scale,
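The deleted block was the per-step host-side overhead called out in the TODO: a Python loop rebuilding context lengths from query_start_loc and seq_lens on every forward pass. For reference, a vectorized equivalent of what it computed is sketched below with hypothetical stand-ins for the metadata fields; after this change the kernel derives these values itself, so the caller needs neither version:

import torch

# Hypothetical stand-ins for attn_metadata.query_start_loc / attn_metadata.seq_lens.
query_start_loc = torch.tensor([0, 4, 5, 8])
seq_lens = torch.tensor([10, 6, 3])

# Same result as the removed loop: seq_len minus this step's query length.
context_lens = seq_lens - (query_start_loc[1:] - query_start_loc[:-1])
print(context_lens)  # tensor([6, 5, 0])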
