# Triton 2.1.0
@triton.jit
def _flash_decoding_fwd_kernel(
-    Q,  # [batch_size, head_num, head_dim]
+    Q,  # [batch_size, head_num, q_len(1), head_dim]
    KCache,  # [num_blocks, num_kv_heads, head_dim, block_size]
    VCache,  # [num_blocks, num_kv_heads, head_dim, block_size]
    block_tables,  # [batch_size, max_blocks_per_sequence]
    mid_o,  # [batch_size, head_num, kv_split_num, head_dim]
    mid_o_lse,  # [batch_size, head_num, kv_split_num]
-    context_lengths,  # [batch_size]
+    kv_seq_len,  # [batch_size]
    stride_qt,
    stride_qh,
+    stride_ql,
    stride_qd,
    stride_cacheb,
    stride_cacheh,
@@ -51,7 +52,7 @@ def _flash_decoding_fwd_kernel(
    tl.static_assert(BLOCK_KV == BLOCK_SIZE)

    # get the current (kv) sequence length from provided context lengths tensor
-    cur_kv_seq_len = tl.load(context_lengths + cur_seq_idx)
+    cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx)

    offsets_q = cur_seq_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd
    q = tl.load(Q + offsets_q)
@@ -65,7 +66,6 @@ def _flash_decoding_fwd_kernel(
    cur_block_id = tl.load(block_table_ptr + cur_bt_start_idx * stride_btb)

    if block_start_kv * BLOCK_KV >= cur_kv_seq_len:
-        # TODO might want to remove if-else block?
        return

    cur_occupied_size = tl.where(
@@ -132,7 +132,7 @@ def _flash_decoding_fwd_reduce_kernel(
    mid_o,  # [batch_size, head_num, kv_split_num, head_dim]
    mid_o_lse,  # [batch_size, head_num, kv_split_num]
    O,  # [batch_size, num_heads, head_dim] or [batch_size, 1, num_heads, head_dim]
-    context_lengths,
+    kv_seq_len,
    stride_mid_ot,
    stride_mid_oh,
    stride_mid_ob,
@@ -141,6 +141,7 @@ def _flash_decoding_fwd_reduce_kernel(
    stride_o_lseh,
    stride_o_lseb,
    stride_ob,
+    stride_ol,
    stride_oh,
    stride_od,
    BLOCK_KV: tl.constexpr,
@@ -149,7 +150,7 @@ def _flash_decoding_fwd_reduce_kernel(
    cur_seq_idx = tl.program_id(0)
    cur_head_idx = tl.program_id(1)

-    cur_kv_seq_len = tl.load(context_lengths + cur_seq_idx)
+    cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx)
    offsets_dmodel = tl.arange(0, HEAD_DIM)

    # NOTE currently the block size BLOCK_KV splitting kv is relatively small as we have
@@ -181,97 +182,126 @@ def _flash_decoding_fwd_reduce_kernel(

# Decoding Stage
# Used with blocked KV Cache (PagedAttention)
-def flash_decoding_fwd(
-    q: torch.Tensor,  # [bsz(e.g.num_tokens), 1, num_heads, head_dim]
-    k_cache: torch.Tensor,  # [num_blocks, num_kv_heads, head_dim, block_size]
-    v_cache: torch.Tensor,  # [num_blocks, num_kv_heads, head_dim, block_size]
-    context_lengths: torch.Tensor,  # [batch_size]
-    block_tables: torch.Tensor,  # [batch_size, max_blocks_per_sequence]
+def flash_decoding_attention(
+    q: torch.Tensor,
+    k_cache: torch.Tensor,
+    v_cache: torch.Tensor,
+    kv_seq_len: torch.Tensor,
+    block_tables: torch.Tensor,
    block_size: int,
-    num_kv_group: int = 1,
+    max_seq_len_in_batch: int = None,
+    mid_output: torch.Tensor = None,
+    mid_output_lse: torch.Tensor = None,
+    sm_scale: int = None,
+    kv_group_num: int = 1,
):
-    bsz, _, num_heads, head_dim = q.shape
198+ """
199+ Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage.
200+
201+ Args:
202+ q (torch.Tensor): [bsz, num_heads, q_len(1), head_dim]
203+ k_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size]
204+ v_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size]
205+ kv_seq_len (torch.Tensor): [batch_size]
206+ records the (kv) sequence lengths incorporating past kv sequence lengths.
207+ block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
208+ max_seq_len_in_batch (int): Maximum sequence length in the batch.
209+ mid_output (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num, head_dim]
210+ Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`.
211+ mid_output_lse (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num]
212+ Log-sum-exp of intermediate output. `max_bsz` should be greater than or equal to `bsz`.
213+ block_size (int): Size of each block in the blocked key/value cache.
214+ num_kv_group (int, optional): Number of key/value groups. Defaults to 1.
215+
216+ Returns:
217+ Output tensor with shape [bsz, num_heads, q_len, head_dim]
218+ """
219+ bsz , num_heads , _ , head_dim = q .shape

    assert head_dim in {32, 64, 128, 256}
-    assert context_lengths.shape[0] == block_tables.shape[0] == bsz, (
+    assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, (
        f"Got incompatible batch size (number of seqs):\n"
-        f" Conext lengths bsz {context_lengths.shape[0]}, Block tables bsz {block_tables.shape[0]}, "
+        f" KV seq lengths bsz {kv_seq_len.shape[0]}, Block tables bsz {block_tables.shape[0]}, "
        f"batch size {bsz}"
    )
    assert k_cache.size(-1) == v_cache.size(-1) == block_size, (
        f"Got incompatible block size on kv caches:\n"
        f" assigned block_size {block_size}, k_cache block_size {k_cache.size(-1)}, "
        f"v_cache block_size {v_cache.size(-1)}"
    )
-    # NOTE `context_lengths` records the (kv) sequence lengths incorporating past kv sequence lengths.
-    bsz = context_lengths.size(0)  # e.g. the number of seqs
-    max_seq_len = context_lengths.max().item()
-    sm_scale = 1.0 / (head_dim**0.5)

    # NOTE BLOCK_KV could be considered as block splitting the sequence on k/v
    # For now, BLOCK_KV is supposed to be equivalent with the size of physical cache block (i.e.`block_size`)
    assert block_size in {16, 32, 64, 128}
    BLOCK_KV = block_size

-    kv_max_split_num = (max_seq_len + BLOCK_KV - 1) // BLOCK_KV
-    mid_o = torch.zeros(size=(bsz, num_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device)
-    mid_o_lse = torch.zeros(size=(bsz, num_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
-
-    if q.dim() == 4:
-        assert q.size(1) == 1, f"q_len is supposed to be 1 but is {q.size(1)}"
-        q = q.squeeze(1)
+    sm_scale = 1.0 / (head_dim**0.5) if sm_scale is None else sm_scale
+    max_seq_len_in_batch = kv_seq_len.max().item() if max_seq_len_in_batch is None else max_seq_len_in_batch
+    # For compatibility (TODO revise modeling in future)
+    kv_max_split_num = (max_seq_len_in_batch + BLOCK_KV - 1) // BLOCK_KV
+    mid_output = (
+        torch.zeros(size=(bsz, num_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device)
+        if mid_output is None
+        else mid_output
+    )
+    mid_output_lse = (
+        torch.zeros(size=(bsz, num_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
+        if mid_output_lse is None
+        else mid_output_lse
+    )

-    grid = (bsz, num_heads, triton.cdiv(max_seq_len, BLOCK_KV))
+    grid = (triton.next_power_of_2(bsz), num_heads, triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV))
    _flash_decoding_fwd_kernel[grid](
        q,
        k_cache,
        v_cache,
        block_tables,
-        mid_o,
-        mid_o_lse,
-        context_lengths,
+        mid_output,
+        mid_output_lse,
+        kv_seq_len,
        q.stride(0),
        q.stride(1),
        q.stride(2),
+        q.stride(3),
        k_cache.stride(0),
        k_cache.stride(1),
        k_cache.stride(2),
        k_cache.stride(3),
        block_tables.stride(0),
        block_tables.stride(1),
-        mid_o.stride(0),
-        mid_o.stride(1),
-        mid_o.stride(2),
-        mid_o.stride(3),
-        mid_o_lse.stride(0),
-        mid_o_lse.stride(1),
-        mid_o_lse.stride(2),
+        mid_output.stride(0),
+        mid_output.stride(1),
+        mid_output.stride(2),
+        mid_output.stride(3),
+        mid_output_lse.stride(0),
+        mid_output_lse.stride(1),
+        mid_output_lse.stride(2),
        sm_scale,
-        KV_GROUPS=num_kv_group,
+        KV_GROUPS=kv_group_num,
        BLOCK_KV=block_size,
        BLOCK_SIZE=block_size,
        HEAD_DIM=head_dim,
    )

-    output = torch.zeros_like(q)
-    output = output.view(-1, output.size(-2), output.size(-1))
+    output = torch.empty((bsz, 1, num_heads, head_dim), dtype=q.dtype, device=q.device)  # already overlapped

    grid = (bsz, num_heads)
    _flash_decoding_fwd_reduce_kernel[grid](
-        mid_o,
-        mid_o_lse,
+        mid_output,
+        mid_output_lse,
        output,
-        context_lengths,
-        mid_o.stride(0),
-        mid_o.stride(1),
-        mid_o.stride(2),
-        mid_o.stride(3),
-        mid_o_lse.stride(0),
-        mid_o_lse.stride(1),
-        mid_o_lse.stride(2),
+        kv_seq_len,
+        mid_output.stride(0),
+        mid_output.stride(1),
+        mid_output.stride(2),
+        mid_output.stride(3),
+        mid_output_lse.stride(0),
+        mid_output_lse.stride(1),
+        mid_output_lse.stride(2),
        output.stride(0),
        output.stride(1),
        output.stride(2),
+        output.stride(3),
        BLOCK_KV=block_size,
        HEAD_DIM=head_dim,
    )
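
The reduce kernel launched above folds the per-split partial results held in mid_output and mid_output_lse into the final attention output, but most of its body falls outside this hunk. Below is a rough PyTorch reference of the usual flash-decoding combine step for a single (sequence, head) pair; it assumes (not confirmed by this diff) that each mid_output row is the softmax-normalized partial output of one KV split and that mid_output_lse stores that split's log-sum-exp of attention scores.

import torch

def combine_kv_splits_reference(mid_output: torch.Tensor, mid_output_lse: torch.Tensor, num_splits: int) -> torch.Tensor:
    # mid_output:     [kv_max_split_num, head_dim], per-split partial outputs (assumed normalized within each split)
    # mid_output_lse: [kv_max_split_num], per-split log-sum-exp of the attention scores (assumed convention)
    lse = mid_output_lse[:num_splits]
    partials = mid_output[:num_splits]
    total_lse = torch.logsumexp(lse, dim=0)
    weights = torch.exp(lse - total_lse)              # relative weight of each split
    return (weights[:, None] * partials).sum(dim=0)   # [head_dim]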
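For context, here is a minimal sketch of how the renamed flash_decoding_attention entry point might be invoked with the shapes documented in its docstring. All sizes are made up for illustration, and the block table is random rather than produced by a real cache manager; mid_output and mid_output_lse are left as None so the function allocates them itself.

import torch

# Hypothetical sizes, for illustration only.
bsz, num_heads, num_kv_heads, head_dim = 4, 32, 32, 128
block_size, max_blocks_per_seq = 16, 32
num_blocks = bsz * max_blocks_per_seq

q = torch.randn(bsz, num_heads, 1, head_dim, dtype=torch.float16, device="cuda")
k_cache = torch.randn(num_blocks, num_kv_heads, head_dim, block_size, dtype=torch.float16, device="cuda")
v_cache = torch.randn_like(k_cache)

# Per-sequence kv lengths and a (random) block table mapping logical blocks to physical cache blocks.
kv_seq_len = torch.randint(1, block_size * max_blocks_per_seq, (bsz,), dtype=torch.int32, device="cuda")
block_tables = torch.randint(0, num_blocks, (bsz, max_blocks_per_seq), dtype=torch.int32, device="cuda")

output = flash_decoding_attention(
    q,
    k_cache,
    v_cache,
    kv_seq_len,
    block_tables,
    block_size,
    kv_group_num=num_heads // num_kv_heads,
)
# The buffer allocated inside is [bsz, 1, num_heads, head_dim]; mid_output and mid_output_lse
# can also be preallocated once with max_bsz rows and passed in to avoid per-call allocation.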