
Commit c9f4f2e

CyCle1024 authored and yao-fengchen committed
feat: change infer_ext ops function param order (#2)
1 parent b50dbe2 · commit c9f4f2e

File tree

1 file changed, +3 -3 lines changed


lmdeploy/pytorch/kernels/ascend/paged_attention_fwd.py

Lines changed: 3 additions & 3 deletions
@@ -26,7 +26,6 @@ def flash_context_attention(
     for i in range(batch):
         if torch.equal(q_seq_len[i], kv_seq_len[i]):
             ext_ops.context_attention(
-                attn_output,
                 query_states,
                 key_states,
                 value_states,
@@ -35,13 +34,13 @@ def flash_context_attention(
                 num_q_heads,
                 num_kv_heads,
                 context.attention_mask[i:i + 1],
+                attn_output=attn_output,
             )
         else:
             key_cache = key_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
             value_cache = value_cache.reshape(1, kv_cache_len,
                                               num_kv_heads * dim)
             ext_ops.paged_prefill_attention(
-                attn_output,
                 query_states,
                 key_cache,
                 value_cache,
@@ -53,14 +52,14 @@ def flash_context_attention(
                 num_q_heads,
                 num_kv_heads,
                 context.attention_mask[i:i + 1],
+                attn_output=attn_output,
             )


 def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
                           block_offsets, block_size):
     num_kv_heads, num_q_heads = k_cache.shape[1], q.shape[1]
     ext_ops.paged_decode_attention(
-        attn_output.view(q.shape),
         q,
         k_cache,
         v_cache,
@@ -69,6 +68,7 @@ def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
         kv_seq_len,
         num_q_heads,
         num_kv_heads,
+        attn_output=attn_output.view(q.shape),
     )
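Every call site in this commit follows the same pattern: the attn_output destination buffer moves from the first positional argument to a trailing keyword argument, so the required inputs come first and the output destination is named explicitly at each call. Below is a before/after sketch of the call shape only, using the paged_decode_attention site; the middle arguments hidden by the hunk are left as a comment, and all names are reused from the surrounding file rather than from the infer_ext source, which this diff does not show.

# Before this commit: the output buffer led the positional arguments.
ext_ops.paged_decode_attention(
    attn_output.view(q.shape),
    q,
    k_cache,
    v_cache,
    # ... arguments elided by the hunk ...
    kv_seq_len,
    num_q_heads,
    num_kv_heads,
)

# After this commit: the output buffer is a trailing keyword argument.
ext_ops.paged_decode_attention(
    q,
    k_cache,
    v_cache,
    # ... arguments elided by the hunk ...
    kv_seq_len,
    num_q_heads,
    num_kv_heads,
    attn_output=attn_output.view(q.shape),
)

The commit message does not state the motivation, but a plausible reading is that keyword-passing the output lets the infer_ext op signatures keep inputs positional and outputs optional or defaulted.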