
Commit 0a6ea7d

fix calling of triton kernel in modeling

1 parent 0ec90af, commit 0a6ea7d

3 files changed: +31 additions, -8 deletions

colossalai/inference/modeling/models/llama.py (10 additions & 2 deletions)

@@ -6,7 +6,7 @@
 
 from colossalai.inference.modeling.layers.attention import PagedAttention
 from colossalai.inference.struct import BatchInfo
-from colossalai.kernel.triton import context_attention_unpadded, copy_kv_to_blocked_cache, flash_decoding_fwd
+from colossalai.kernel.triton import context_attention_unpadded, copy_kv_to_blocked_cache, flash_decoding_attention
 from colossalai.logging import get_dist_logger
 
 from flash_attn.bert_padding import index_first_axis, pad_input  # noqa
@@ -209,7 +209,15 @@ def llama_attn_forward(
         if HAS_TRITON:
             copy_kv_to_blocked_cache(key_states, k_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
             copy_kv_to_blocked_cache(value_states, v_cache, kv_lengths=sequence_lengths, block_tables=block_tables)
-            attn_output = flash_decoding_fwd(query_states, k_cache, v_cache, sequence_lengths, block_tables, block_size)
+            # TODO: dummy transpose and squeeze are added on the in/out tensors of the triton flash decoding
+            # kernel to maintain compatibility. This part, as well as the handling of query_states and
+            # attn_output, should be revised: earlier in `llama_attn_forward` there is still a redundant
+            # transpose, and the in/out layouts of the torch- and triton-version decoding kernels are not consistent.
+            query_states = query_states.transpose(1, 2)
+            attn_output = flash_decoding_attention(
+                query_states, k_cache, v_cache, sequence_lengths, block_tables, block_size
+            )
+            attn_output = attn_output.squeeze(1)
         else:
             attn_output = PagedAttention.pad_decoding_forward(
                 query_states, key_states, value_states, k_cache, v_cache, sequence_lengths, block_tables, attention_mask
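
To make the shape bookkeeping behind the dummy transpose/squeeze easier to follow, here is a minimal PyTorch sketch. The concrete sizes, and the assumption that query_states arrives as [bsz, num_heads, q_len, head_dim] with q_len == 1 during decoding, are illustrative only; they are inferred from the transpose(1, 2)/squeeze(1) calls above and from the [bsz, 1, num_heads, head_dim] annotation in the test file, not stated explicitly in this commit.

import torch

bsz, num_heads, head_dim = 4, 32, 128                     # illustrative sizes (assumed)
query_states = torch.randn(bsz, num_heads, 1, head_dim)   # torch-style layout, q_len == 1 at decoding
query_states = query_states.transpose(1, 2)               # -> [bsz, 1, num_heads, head_dim] for the triton kernel
attn_output = torch.randn(bsz, 1, num_heads, head_dim)    # stand-in for the triton kernel output
attn_output = attn_output.squeeze(1)                      # -> [bsz, num_heads, head_dim]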

colossalai/kernel/triton/flash_decoding.py (19 additions & 4 deletions)

@@ -188,11 +188,11 @@ def flash_decoding_attention(
     v_cache: torch.Tensor,
     kv_seq_len: torch.Tensor,
     block_tables: torch.Tensor,
-    max_seq_len_in_batch: int,
-    mid_output: torch.Tensor,
-    mid_output_lse: torch.Tensor,
     block_size: int,
-    sm_scale: int,
+    max_seq_len_in_batch: int = None,
+    mid_output: torch.Tensor = None,
+    mid_output_lse: torch.Tensor = None,
+    sm_scale: int = None,
     kv_group_num: int = 1,
 ):
     """
@@ -236,6 +236,21 @@ def flash_decoding_attention(
     assert block_size in {16, 32, 64, 128}
     BLOCK_KV = block_size
 
+    sm_scale = 1.0 / (head_dim**0.5) if sm_scale is None else sm_scale
+    max_seq_len_in_batch = kv_seq_len.max().item() if max_seq_len_in_batch is None else max_seq_len_in_batch
+    # For compatibility (TODO revise modeling in future)
+    kv_max_split_num = (max_seq_len_in_batch + BLOCK_KV - 1) // BLOCK_KV
+    mid_output = (
+        torch.zeros(size=(bsz, num_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device)
+        if mid_output is None
+        else mid_output
+    )
+    mid_output_lse = (
+        torch.zeros(size=(bsz, num_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
+        if mid_output_lse is None
+        else mid_output_lse
+    )
+
     grid = (triton.next_power_of_2(bsz), num_heads, triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV))
     _flash_decoding_fwd_kernel[grid](
         q,
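
As a usage sketch of the relaxed signature: a caller may now pass only the required arguments and let the wrapper derive the softmax scale, the batch-wide maximum sequence length, and the two scratch buffers, which is equivalent to precomputing them by hand. The names bsz, num_heads, and head_dim below are placeholders inferred from the buffer shapes above, not part of the public API.

# Call relying on the new defaults
out = flash_decoding_attention(q, k_cache, v_cache, kv_seq_len, block_tables, block_size)

# Equivalent explicit form, mirroring the default computation added above
sm_scale = 1.0 / (head_dim**0.5)
max_seq_len_in_batch = kv_seq_len.max().item()
kv_max_split_num = (max_seq_len_in_batch + block_size - 1) // block_size
mid_output = torch.zeros(bsz, num_heads, kv_max_split_num, head_dim, dtype=torch.float32, device=q.device)
mid_output_lse = torch.zeros(bsz, num_heads, kv_max_split_num, dtype=torch.float32, device=q.device)
out = flash_decoding_attention(
    q, k_cache, v_cache, kv_seq_len, block_tables, block_size,
    max_seq_len_in_batch, mid_output, mid_output_lse, sm_scale,
)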

tests/test_infer_ops/triton/test_decoding_attn.py (2 additions & 2 deletions)

@@ -93,10 +93,10 @@ def test_flash_decoding(
         v_cache,
         context_lengths,
         block_tables,
+        block_size,
         max_seq_len_in_b,
         mid_output,
         mid_output_lse,
-        block_size=block_size,
         sm_scale=sm_scale,
         kv_group_num=kv_group_num,
     )  # [bsz, 1, num_heads, head_dim]
@@ -221,10 +221,10 @@ def bench_kernel(
         v_cache,
         kv_lengths,
         block_tables,
+        block_size,
         max_seq_len_in_b,
         mid_output,
         mid_output_lse,
-        block_size=block_size,
         sm_scale=sm_scale,
         kv_group_num=kv_group_num,
     )  # [bsz, 1, num_heads, head_dim]
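
In both call sites the only change is that block_size moves from a trailing keyword argument to its new positional slot right after block_tables, matching the reordered signature; sm_scale and kv_group_num stay as keyword arguments. A condensed before/after sketch follows; the leading query and cache arguments are not shown in the hunks, so q, k_cache, and v_cache here are assumed names.

# Before: block_size passed as a keyword after the scratch buffers
flash_decoding_attention(q, k_cache, v_cache, kv_lengths, block_tables,
                         max_seq_len_in_b, mid_output, mid_output_lse,
                         block_size=block_size, sm_scale=sm_scale, kv_group_num=kv_group_num)

# After: block_size is positional, directly after block_tables
flash_decoding_attention(q, k_cache, v_cache, kv_lengths, block_tables, block_size,
                         max_seq_len_in_b, mid_output, mid_output_lse,
                         sm_scale=sm_scale, kv_group_num=kv_group_num)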
