@@ -16,6 +16,7 @@ def _flash_decoding_fwd_kernel(
1616 mid_o , # [batch_size, head_num, kv_split_num, head_dim]
1717 mid_o_lse , # [batch_size, head_num, kv_split_num]
1818 kv_seq_len , # [batch_size]
19+ batch_size ,
1920 stride_qt ,
2021 stride_qh ,
2122 stride_qd ,
@@ -39,6 +40,8 @@ def _flash_decoding_fwd_kernel(
3940 HEAD_DIM : tl .constexpr ,
4041):
4142 cur_seq_idx = tl .program_id (0 )
43+ if cur_seq_idx >= batch_size :
44+ return
4245 cur_head_idx = tl .program_id (1 )
4346 block_start_kv = tl .program_id (2 ) # for splitting k/v
4447
@@ -132,6 +135,7 @@ def _flash_decoding_fwd_reduce_kernel(
132135 mid_o_lse , # [batch_size, head_num, kv_split_num]
133136 O , # [batch_size, num_heads, head_dim] or [batch_size, 1, num_heads, head_dim]
134137 kv_seq_len ,
138+ batch_size ,
135139 stride_mid_ot ,
136140 stride_mid_oh ,
137141 stride_mid_ob ,
@@ -147,6 +151,8 @@ def _flash_decoding_fwd_reduce_kernel(
147151 HEAD_DIM : tl .constexpr ,
148152):
149153 cur_seq_idx = tl .program_id (0 )
154+ if cur_seq_idx >= batch_size :
155+ return
150156 cur_head_idx = tl .program_id (1 )
151157
152158 cur_kv_seq_len = tl .load (kv_seq_len + cur_seq_idx )
@@ -251,6 +257,8 @@ def flash_decoding_attention(
251257 else mid_output_lse
252258 )
253259
260+ # NOTE: `triton.next_power_of_2` is applied to the grid sizes below in order to make use of Triton's cache mechanism.
261+ # To optimize further, revise the batching/scheduling so that each batch holds 2^n sequences (preferred).
254262 grid = (triton .next_power_of_2 (bsz ), num_heads , triton .cdiv (triton .next_power_of_2 (max_seq_len_in_batch ), BLOCK_KV ))
255263 _flash_decoding_fwd_kernel [grid ](
256264 q ,
@@ -260,6 +268,7 @@ def flash_decoding_attention(
260268 mid_output ,
261269 mid_output_lse ,
262270 kv_seq_len ,
271+ bsz ,
263272 q .stride (0 ),
264273 q .stride (1 ),
265274 q .stride (2 ),
@@ -285,12 +294,14 @@ def flash_decoding_attention(
285294
286295 output = torch .empty ((bsz , 1 , num_heads , head_dim ), dtype = q .dtype , device = q .device ) # already overlapped
287296
288- grid = (bsz , num_heads )
297+ grid = (triton .next_power_of_2 (bsz ), num_heads )
298+
289299 _flash_decoding_fwd_reduce_kernel [grid ](
290300 mid_output ,
291301 mid_output_lse ,
292302 output ,
293303 kv_seq_len ,
304+ bsz ,
294305 mid_output .stride (0 ),
295306 mid_output .stride (1 ),
296307 mid_output .stride (2 ),
0 commit comments