Commit af8359c

[hotfix] fix boundary check in batch (#5306)

1 parent c647e00

2 files changed: +18 -1 lines changed
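Context for the fix (inferred from the diff below, not part of the commit message): the kernels launch their grids with `triton.next_power_of_2` along the batch axis so that many batch sizes collapse onto a few grid shapes and Triton's caching can be reused. The cost is that whenever the batch size is not a power of two, surplus programs are launched; without a boundary check they would index past the real batch. This commit threads the true batch size into the kernels and adds an early return. A quick illustration of the padding (plain Python; only `triton.next_power_of_2` is assumed):

import triton

for bsz in (1, 3, 7, 8, 9, 100, 1000):
    print(f"batch {bsz:4d} -> padded grid axis {triton.next_power_of_2(bsz)}")
# batch    1 -> padded grid axis 1
# batch    3 -> padded grid axis 4
# batch    7 -> padded grid axis 8
# batch    8 -> padded grid axis 8
# batch    9 -> padded grid axis 16
# batch  100 -> padded grid axis 128
# batch 1000 -> padded grid axis 1024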

colossalai/kernel/triton/context_attn_unpad.py (6 additions, 0 deletions)

@@ -22,6 +22,7 @@ def _fwd_context_paged_attention_kernel(
     KCache,
     VCache,
     BLOCK_TABLES, # [num_seqs, max_blocks_per_sequence]
+    batch_size,
     stride_qt,
     stride_qh,
     stride_qd,
@@ -49,6 +50,8 @@ def _fwd_context_paged_attention_kernel(
     BLOCK_N: tl.constexpr,
 ):
     cur_seq_idx = tl.program_id(0)
+    if cur_seq_idx >= batch_size:
+        return
     cur_head_idx = tl.program_id(1)
     block_start_m = tl.program_id(2) # Br, max_input_len // Block_M
     cur_kv_head_idx = cur_head_idx // KV_GROUPS
@@ -217,6 +220,8 @@ def context_attention_unpadded(
     assert block_size in {16, 32, 64, 128}
     BLOCK_M = BLOCK_N = block_size

+    # NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton
+    # To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred)
     grid = (triton.next_power_of_2(num_seqs), num_heads, triton.cdiv(max_seq_len, BLOCK_M))

     _fwd_context_paged_attention_kernel[grid](
@@ -227,6 +232,7 @@ def context_attention_unpadded(
         k_cache,
         v_cache,
         block_tables,
+        num_seqs,
         q.stride(0),
         q.stride(1),
         q.stride(2),
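The guard added above is the standard pattern for over-provisioned grids. Below is a minimal self-contained sketch of the same pattern, not the ColossalAI kernel: the names `_scale_rows_kernel` and `scale_rows` are hypothetical, and the kernel body just doubles each row so the boundary check is the only point of interest.

import torch
import triton
import triton.language as tl


@triton.jit
def _scale_rows_kernel(X, Out, batch_size, n_cols, stride_row,
                       BLOCK_N: tl.constexpr):
    row = tl.program_id(0)
    # Boundary check: with the grid padded to a power of two, programs with
    # row >= batch_size have no real work and must exit before any load/store.
    if row >= batch_size:
        return
    offs = tl.arange(0, BLOCK_N)
    mask = offs < n_cols
    x = tl.load(X + row * stride_row + offs, mask=mask, other=0.0)
    tl.store(Out + row * stride_row + offs, x * 2.0, mask=mask)


def scale_rows(x: torch.Tensor) -> torch.Tensor:
    bsz, n = x.shape
    out = torch.empty_like(x)
    # Pad the batch axis as the fixed kernels do, so nearby batch sizes
    # share one grid shape.
    grid = (triton.next_power_of_2(bsz),)
    _scale_rows_kernel[grid](x, out, bsz, n, x.stride(0),
                             BLOCK_N=triton.next_power_of_2(n))
    return out

With `bsz = 5` the launcher starts 8 programs; programs 5 through 7 hit the guard and return without touching memory, which is exactly what the added `if cur_seq_idx >= batch_size: return` accomplishes in the attention kernel.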

colossalai/kernel/triton/flash_decoding.py (12 additions, 1 deletion)

@@ -16,6 +16,7 @@ def _flash_decoding_fwd_kernel(
     mid_o, # [batch_size, head_num, kv_split_num, head_dim]
     mid_o_lse, # [batch_size, head_num, kv_split_num]
     kv_seq_len, # [batch_size]
+    batch_size,
     stride_qt,
     stride_qh,
     stride_qd,
@@ -39,6 +40,8 @@ def _flash_decoding_fwd_kernel(
     HEAD_DIM: tl.constexpr,
 ):
     cur_seq_idx = tl.program_id(0)
+    if cur_seq_idx >= batch_size:
+        return
     cur_head_idx = tl.program_id(1)
     block_start_kv = tl.program_id(2) # for splitting k/v

@@ -132,6 +135,7 @@ def _flash_decoding_fwd_reduce_kernel(
     mid_o_lse, # [batch_size, head_num, kv_split_num]
     O, # [batch_size, num_heads, head_dim] or [batch_size, 1, num_heads, head_dim]
     kv_seq_len,
+    batch_size,
     stride_mid_ot,
     stride_mid_oh,
     stride_mid_ob,
@@ -147,6 +151,8 @@ def _flash_decoding_fwd_reduce_kernel(
     HEAD_DIM: tl.constexpr,
 ):
     cur_seq_idx = tl.program_id(0)
+    if cur_seq_idx >= batch_size:
+        return
     cur_head_idx = tl.program_id(1)

     cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx)
@@ -251,6 +257,8 @@ def flash_decoding_attention(
         else mid_output_lse
     )

+    # NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton
+    # To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred)
     grid = (triton.next_power_of_2(bsz), num_heads, triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV))
     _flash_decoding_fwd_kernel[grid](
         q,
@@ -260,6 +268,7 @@ def flash_decoding_attention(
         mid_output,
         mid_output_lse,
         kv_seq_len,
+        bsz,
         q.stride(0),
         q.stride(1),
         q.stride(2),
@@ -285,12 +294,14 @@ def flash_decoding_attention(

     output = torch.empty((bsz, 1, num_heads, head_dim), dtype=q.dtype, device=q.device) # already overlapped

-    grid = (bsz, num_heads)
+    grid = (triton.next_power_of_2(bsz), num_heads)
+
     _flash_decoding_fwd_reduce_kernel[grid](
         mid_output,
         mid_output_lse,
         output,
         kv_seq_len,
+        bsz,
         mid_output.stride(0),
         mid_output.stride(1),
         mid_output.stride(2),
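flash_decoding.py launches two kernels: a split-KV forward pass that writes partial results into `mid_output`, then a reduce pass; after this commit both grids are padded along the batch axis and both kernels carry the guard (previously the reduce grid was exact, so only the newly padded launches needed it). The sketch below mirrors that two-phase structure on a toy problem, per-row sums instead of attention; every name in it is illustrative, not from the repository.

import torch
import triton
import triton.language as tl


@triton.jit
def _partial_sum_kernel(X, Mid, batch_size, n_cols, stride_x, stride_mid,
                        BLOCK: tl.constexpr):
    row = tl.program_id(0)
    if row >= batch_size:  # surplus programs from the padded grid exit here
        return
    blk = tl.program_id(1)
    offs = blk * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_cols
    x = tl.load(X + row * stride_x + offs, mask=mask, other=0.0)
    tl.store(Mid + row * stride_mid + blk, tl.sum(x, axis=0))


@triton.jit
def _reduce_sum_kernel(Mid, Out, batch_size, num_blocks, stride_mid,
                       BLOCK_P: tl.constexpr):
    row = tl.program_id(0)
    if row >= batch_size:  # the reduce grid is padded too, so guard again
        return
    offs = tl.arange(0, BLOCK_P)
    mask = offs < num_blocks
    part = tl.load(Mid + row * stride_mid + offs, mask=mask, other=0.0)
    tl.store(Out + row, tl.sum(part, axis=0))


def row_sums(x: torch.Tensor, block: int = 128) -> torch.Tensor:
    bsz, n = x.shape
    num_blocks = triton.cdiv(n, block)
    mid = torch.empty((bsz, num_blocks), dtype=x.dtype, device=x.device)
    out = torch.empty((bsz,), dtype=x.dtype, device=x.device)
    # Both grids are padded along the batch axis, as in the fixed kernels,
    # and the true batch size is passed in for the boundary check.
    _partial_sum_kernel[(triton.next_power_of_2(bsz), num_blocks)](
        x, mid, bsz, n, x.stride(0), mid.stride(0), BLOCK=block
    )
    _reduce_sum_kernel[(triton.next_power_of_2(bsz),)](
        mid, out, bsz, num_blocks, mid.stride(0),
        BLOCK_P=triton.next_power_of_2(num_blocks)
    )
    return out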
