@@ -26,7 +26,6 @@ def flash_context_attention(
     for i in range(batch):
         if torch.equal(q_seq_len[i], kv_seq_len[i]):
             ext_ops.context_attention(
-                attn_output,
                 query_states,
                 key_states,
                 value_states,
@@ -35,13 +34,13 @@ def flash_context_attention(
                 num_q_heads,
                 num_kv_heads,
                 context.attention_mask[i:i + 1],
+                attn_output=attn_output,
             )
         else:
             key_cache = key_cache.reshape(1, kv_cache_len, num_kv_heads * dim)
             value_cache = value_cache.reshape(1, kv_cache_len,
                                               num_kv_heads * dim)
             ext_ops.paged_prefill_attention(
-                attn_output,
                 query_states,
                 key_cache,
                 value_cache,
@@ -53,14 +52,14 @@ def flash_context_attention(
                 num_q_heads,
                 num_kv_heads,
                 context.attention_mask[i:i + 1],
+                attn_output=attn_output,
             )
 
 
 def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
                           block_offsets, block_size):
     num_kv_heads, num_q_heads = k_cache.shape[1], q.shape[1]
     ext_ops.paged_decode_attention(
-        attn_output.view(q.shape),
         q,
         k_cache,
         v_cache,
@@ -69,6 +68,7 @@ def paged_token_attention(q, k_cache, v_cache, attn_output, kv_seq_len,
         kv_seq_len,
         num_q_heads,
         num_kv_heads,
+        attn_output=attn_output.view(q.shape),
     )
 
 
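The change keeps the data arguments in place and moves the preallocated output buffer from the first positional slot to a trailing attn_output= keyword in each ext_ops call. A minimal sketch of that calling pattern, using a hypothetical decode_attention stand-in rather than the real ext_ops signature:

import torch


def decode_attention(q, k_cache, v_cache, *, attn_output):
    # Hypothetical stand-in for an out-parameter kernel: the result is
    # written into the caller-provided buffer instead of being returned.
    # The real ext_ops.paged_decode_attention computes attention; here we
    # only fill the buffer with a placeholder value.
    attn_output.copy_(torch.zeros_like(q))


q = torch.randn(4, 8, 64)             # [tokens, heads, head_dim]
attn_output = torch.empty(4, 8 * 64)  # flat output buffer owned by the caller

# The output buffer is passed by keyword, so the tensor arguments keep
# their positions and the write destination is explicit at the call site.
decode_attention(q, torch.empty(0), torch.empty(0),
                 attn_output=attn_output.view(q.shape))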