@@ -9,13 +9,14 @@
 # Triton 2.1.0
 @triton.jit
 def _flash_decoding_fwd_kernel(
-    Q,  # [batch_size, head_num, q_len(1), head_dim]
+    Q,  # [batch_size * q_len, head_num, head_dim]
     KCache,  # [num_blocks, num_kv_heads, block_size, head_dim]
     VCache,  # [num_blocks, num_kv_heads, block_size, head_dim]
     block_tables,  # [batch_size, max_blocks_per_sequence]
-    mid_o,  # [batch_size, head_num, kv_split_num, head_dim]
-    mid_o_lse,  # [batch_size, head_num, kv_split_num]
+    mid_o,  # [batch_size * q_len, head_num, kv_split_num, head_dim]
+    mid_o_lse,  # [batch_size * q_len, head_num, kv_split_num]
     kv_seq_len,  # [batch_size]
+    q_len,
     batch_size,
     stride_qt,
     stride_qh,
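Editor's note on the layout change in this hunk: the kernel drops the per-sequence `q_len` axis and instead flattens all query tokens into the leading dimension, so one program instance serves one token. A minimal shape sketch (plain `torch`; sizes are hypothetical):

```python
import torch

bsz, q_len, head_num, head_dim = 2, 5, 8, 64  # hypothetical sizes
# During the verification step of speculative decoding, the last q_len tokens
# of every sequence are processed at once; flattening them into the leading
# dimension lets each kernel instance own exactly one (token, head) pair.
q = torch.randn(bsz, q_len, head_num, head_dim)
q_flat = q.view(bsz * q_len, head_num, head_dim)  # layout this kernel expects
print(q_flat.shape)  # torch.Size([10, 8, 64]); token t belongs to sequence t // q_len
```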
@@ -39,44 +40,37 @@ def _flash_decoding_fwd_kernel(
     BLOCK_SIZE: tl.constexpr,
     HEAD_DIM: tl.constexpr,
 ):
-    cur_seq_idx = tl.program_id(0)
+    cur_token_idx = tl.program_id(0)
+    cur_seq_idx = cur_token_idx // q_len
     if cur_seq_idx >= batch_size:
         return
     cur_head_idx = tl.program_id(1)
     block_start_kv = tl.program_id(2)  # for splitting k/v

-    cur_kv_head_idx = cur_head_idx // KV_GROUPS
-    offsets_dmodel = tl.arange(0, HEAD_DIM)
-
     # NOTE It requires BLOCK_KV and BLOCK_SIZE to be the same
     # TODO might want to replace with BLOCK_KV % BLOCK_SIZE == 0 (optimize BLOCK_KV as multiple of BLOCK_SIZE)
     # and then support calculating multiple kv cache blocks on an instance
     tl.static_assert(BLOCK_KV == BLOCK_SIZE)
-
-    # get the current (kv) sequence length from provided context lengths tensor
+    # get the current (kv) sequence length
     cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx)
+    if block_start_kv * BLOCK_KV >= cur_kv_seq_len:
+        return

-    offsets_q = cur_seq_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd
+    offsets_dmodel = tl.arange(0, HEAD_DIM)
+    offsets_q = cur_token_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd
     q = tl.load(Q + offsets_q)
-
     # block table for the current sequence
     block_table_ptr = block_tables + cur_seq_idx * stride_bts
-
-    # actually current block table current block start idx
     # cur_bt_start_idx = block_start_kv * (BLOCK_KV // BLOCK_SIZE)
-    cur_bt_start_idx = block_start_kv
-    cur_block_id = tl.load(block_table_ptr + cur_bt_start_idx * stride_btb)
-
-    if block_start_kv * BLOCK_KV >= cur_kv_seq_len:
-        return
-
+    # cur_block_id = tl.load(block_table_ptr + cur_bt_start_idx * stride_btb)
+    cur_block_id = tl.load(block_table_ptr + block_start_kv * stride_btb)
     cur_occupied_size = tl.where(
         (block_start_kv + 1) * BLOCK_SIZE <= cur_kv_seq_len, BLOCK_SIZE, cur_kv_seq_len - block_start_kv * BLOCK_SIZE
     )
     tl.device_assert(cur_occupied_size >= 0)

+    cur_kv_head_idx = cur_head_idx // KV_GROUPS
     offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh
-
     K_block_ptr = tl.make_block_ptr(
         base=KCache + offset_kvcache,
         shape=(cur_occupied_size, HEAD_DIM),
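The rewritten prologue does two things: it recovers the sequence index from the flattened token index (all `q_len` tokens of a sequence share one block table and one KV length), and it moves the out-of-range check ahead of the block-table load, so splits beyond the end of a short sequence exit without a wasted read. A plain-Python sketch of the index mapping (illustrative only):

```python
# Hypothetical launch: batch_size=2 sequences, q_len=3 draft tokens each.
q_len, batch_size = 3, 2
for cur_token_idx in range(batch_size * q_len):
    cur_seq_idx = cur_token_idx // q_len
    print(cur_token_idx, "->", cur_seq_idx)  # tokens 0,1,2 -> seq 0; tokens 3,4,5 -> seq 1
```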
@@ -115,14 +109,14 @@ def _flash_decoding_fwd_kernel(
     acc = acc / l

     offsets_mid_o = (
-        cur_seq_idx * stride_mid_ot
+        cur_token_idx * stride_mid_ot
         + cur_head_idx * stride_mid_oh
         + block_start_kv * stride_mid_ob
         + offsets_dmodel * stride_mid_od
     )
     tl.store(mid_o + offsets_mid_o, acc)
     offsets_mid_o_lse = (
-        cur_seq_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb
+        cur_token_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb
     )
     # logsumexp L^(j) = m^(j) + log(l^(j))
     tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l))
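Editor's note: as the trailing comment says, each split writes its log-sum-exp alongside the normalized partial output. In the comment's notation (the merged form below is my restatement, not from the source), split $j$ stores

```latex
L^{(j)} = m^{(j)} + \log \ell^{(j)} = \log \sum_{i \,\in\, \text{split } j} e^{s_i}
```

so the reduce kernel can reconstruct the exact softmax output as $O = \sum_j e^{L^{(j)} - L}\, O^{(j)}$ with $L = \log \sum_j e^{L^{(j)}}$, using only one extra scalar per split.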
@@ -135,6 +129,7 @@ def _flash_decoding_fwd_reduce_kernel(
     mid_o_lse,  # [batch_size, head_num, kv_split_num]
     O,  # [batch_size, num_heads, head_dim] or [batch_size, 1, num_heads, head_dim]
     kv_seq_len,
+    q_len,
     batch_size,
     stride_mid_ot,
     stride_mid_oh,
@@ -149,7 +144,8 @@ def _flash_decoding_fwd_reduce_kernel(
     BLOCK_KV: tl.constexpr,
     HEAD_DIM: tl.constexpr,
 ):
-    cur_seq_idx = tl.program_id(0)
+    cur_token_idx = tl.program_id(0)
+    cur_seq_idx = cur_token_idx // q_len
     if cur_seq_idx >= batch_size:
         return
     cur_head_idx = tl.program_id(1)
@@ -164,8 +160,8 @@ def _flash_decoding_fwd_reduce_kernel(
     l = 0.0  # sum exp
     acc = tl.zeros([HEAD_DIM], dtype=tl.float32)

-    offsets_mid_o = cur_seq_idx * stride_mid_ot + cur_head_idx * stride_mid_oh + offsets_dmodel
-    offset_mid_lse = cur_seq_idx * stride_o_lset + cur_head_idx * stride_o_lseh
+    offsets_mid_o = cur_token_idx * stride_mid_ot + cur_head_idx * stride_mid_oh + offsets_dmodel
+    offset_mid_lse = cur_token_idx * stride_o_lset + cur_head_idx * stride_o_lseh
     for block_i in range(0, kv_split_num, 1):
         mid_o_block = tl.load(mid_o + offsets_mid_o + block_i * stride_mid_ob)
         lse = tl.load(mid_o_lse + offset_mid_lse + block_i * stride_o_lseb)
@@ -179,7 +175,7 @@ def _flash_decoding_fwd_reduce_kernel(
         m_i = m_ij

     acc = acc / l
-    offsets_O = cur_seq_idx * stride_ot + cur_head_idx * stride_oh + offsets_dmodel
+    offsets_O = cur_token_idx * stride_ot + cur_head_idx * stride_oh + offsets_dmodel
     tl.store(O + offsets_O, acc.to(O.type.element_ty))
     return

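The loop above is a streaming log-sum-exp merge over the splits of one (token, head) pair. As a hedged CPU reference (the NumPy setting and helper name are mine, not from the kernel), the same recurrence can be written and checked against a direct softmax:

```python
import numpy as np

def merge_splits(mid_o, mid_lse):
    """mid_o: [kv_split_num, head_dim] per-split normalized outputs;
    mid_lse: [kv_split_num] per-split log-sum-exp values."""
    m_i, l = float("-inf"), 0.0  # running max / rescaled exp-sum
    acc = np.zeros(mid_o.shape[1])
    for o_j, lse in zip(mid_o, mid_lse):
        m_ij = max(m_i, lse)
        scale = np.exp(m_i - m_ij)      # rescale the old accumulator
        exp_logic = np.exp(lse - m_ij)  # weight of the incoming split
        acc = acc * scale + exp_logic * o_j
        l = scale * l + exp_logic
        m_i = m_ij
    return acc / l  # exact softmax-weighted output over all splits

# Quick check: two splits of random scores reproduce the full softmax average.
rng = np.random.default_rng(0)
s, v = rng.normal(size=6), rng.normal(size=(6, 4))
halves = [slice(0, 3), slice(3, 6)]
mid_o = np.stack([(np.exp(s[h] - s[h].max()) / np.exp(s[h] - s[h].max()).sum()) @ v[h] for h in halves])
mid_lse = np.array([np.log(np.exp(s[h]).sum()) for h in halves])
full = (np.exp(s - s.max()) / np.exp(s - s.max()).sum()) @ v
assert np.allclose(merge_splits(mid_o, mid_lse), full)
```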
@@ -199,32 +195,40 @@ def flash_decoding_attention(
     mid_output_lse: torch.Tensor = None,
     sm_scale: int = None,
     kv_group_num: int = 1,
+    q_len: int = 1,
 ):
     """
     Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage.

     Args:
-        q (torch.Tensor): [bsz, num_heads, head_dim]
+        q (torch.Tensor): [bsz * q_len, num_heads, head_dim]
+            q_len > 1 only for the verification process in speculative decoding.
         k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim]
         v_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim]
         kv_seq_len (torch.Tensor): [batch_size]
             records the (kv) sequence lengths incorporating past kv sequence lengths.
         block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
         max_seq_len_in_batch (int): Maximum sequence length in the batch.
         output (torch.Tensor): [bsz, num_heads * head_dim]
-        mid_output (torch.Tensor): [max_bsz, num_heads, kv_max_split_num, head_dim]
+        mid_output (torch.Tensor): [max_bsz * q_len, num_heads, kv_max_split_num, head_dim]
             Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`.
-        mid_output_lse (torch.Tensor): [max_bsz, num_heads, kv_max_split_num]
+            q_len > 1 only for the verification process in speculative decoding.
+        mid_output_lse (torch.Tensor): [max_bsz * q_len, num_heads, kv_max_split_num]
             Log-sum-exp of intermediate output. `max_bsz` should be greater than or equal to `bsz`.
+            q_len > 1 only for the verification process in speculative decoding.
         block_size (int): Size of each block in the blocked key/value cache.
         num_kv_group (int, optional): Number of key/value groups. Defaults to 1.
+        q_len (int): Query length. Used for speculative decoding when `q_len` > 1 (i.e. the last n tokens).
+            Defaults to 1.

     Returns:
-        Output tensor with shape [bsz, num_heads * head_dim]
+        Output tensor with shape [bsz * q_len, num_heads * head_dim]
     """
     q = q.squeeze() if q.dim() == 4 else q
     assert q.dim() == 3, f"Incompatible q dim: {q.dim()}"
-    bsz, num_heads, head_dim = q.shape
+    n_tokens, num_heads, head_dim = q.shape
+    assert n_tokens % q_len == 0, "Invalid q_len"
+    bsz = n_tokens // q_len

     assert head_dim in {32, 64, 128, 256}
     assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, (
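For context, a hedged usage sketch of the new `q_len` path. The import path, cache contents, and sizes below are assumptions for illustration; keyword arguments are used for everything whose position the diff does not show:

```python
import torch
# from colossalai.kernel.triton import flash_decoding_attention  # import path assumed

bsz, q_len, num_heads, head_dim = 2, 3, 8, 64  # hypothetical sizes
block_size, max_blocks = 16, 4
device = "cuda"

# Flattened queries for the last q_len draft tokens of each sequence.
q = torch.randn(bsz * q_len, num_heads, head_dim, dtype=torch.float16, device=device)
k_cache = torch.randn(bsz * max_blocks, num_heads, block_size, head_dim, dtype=torch.float16, device=device)
v_cache = torch.randn_like(k_cache)
kv_seq_len = torch.tensor([40, 50], dtype=torch.int32, device=device)  # includes the q_len new tokens
block_tables = torch.arange(bsz * max_blocks, dtype=torch.int32, device=device).view(bsz, max_blocks)

out = flash_decoding_attention(
    q, k_cache, v_cache, kv_seq_len, block_tables,
    block_size=block_size,
    q_len=q_len,
)  # expected shape: [bsz * q_len, num_heads * head_dim]
```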
@@ -247,22 +251,31 @@ def flash_decoding_attention(
     max_seq_len_in_batch = kv_seq_len.max().item() if max_seq_len_in_batch is None else max_seq_len_in_batch
     # For compatibility (TODO revise modeling in future)
     kv_max_split_num = (max_seq_len_in_batch + BLOCK_KV - 1) // BLOCK_KV
-    mid_output = (
-        torch.zeros(size=(bsz, num_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device)
-        if mid_output is None
-        else mid_output
-    )
-    mid_output_lse = (
-        torch.zeros(size=(bsz, num_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
-        if mid_output_lse is None
-        else mid_output_lse
-    )
+
+    if mid_output is None:
+        mid_output = torch.empty(
+            (bsz * q_len, num_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device
+        )
+    if mid_output_lse is None:
+        mid_output_lse = torch.empty((bsz * q_len, num_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
+    if output is None:
+        # A hack to prevent `view` operation in modeling
+        output = torch.empty((bsz * q_len, num_heads * head_dim), dtype=q.dtype, device=q.device)
+
+    assert (
+        mid_output.size(2) == mid_output_lse.size(2) >= kv_max_split_num
+    ), "Incompatible kv split number of intermediate output tensors"
+    assert (
+        mid_output.size(0) == mid_output_lse.size(0) >= output.size(0) == n_tokens
+    ), "Incompatible first dimension of output tensors"

     # NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton
     # To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred)
-    grid = (triton.next_power_of_2(bsz), num_heads, triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV))
-    output = torch.empty((bsz, num_heads * head_dim), dtype=q.dtype, device=q.device) if output is None else output
-
+    grid = (
+        triton.next_power_of_2(bsz * q_len),
+        num_heads,
+        triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV),
+    )
     _flash_decoding_fwd_kernel[grid](
         q,
         k_cache,
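A note on the grid change above: the first axis now covers tokens rather than sequences, rounded up to a power of two to reuse Triton's compilation cache, and padded instances exit early through the `cur_seq_idx >= batch_size` guard. Worked numbers (hypothetical):

```python
import triton

bsz, q_len = 3, 4
print(triton.next_power_of_2(bsz * q_len))  # 12 tokens -> grid axis of 16
# Instances 12..15 compute cur_seq_idx = idx // 4 == 3 >= bsz and return immediately.
```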
@@ -271,6 +284,7 @@ def flash_decoding_attention(
         mid_output,
         mid_output_lse,
         kv_seq_len,
+        q_len,
         bsz,
         q.stride(0),
         q.stride(1),
@@ -295,13 +309,13 @@ def flash_decoding_attention(
         HEAD_DIM=head_dim,
     )

-    grid = (triton.next_power_of_2(bsz), num_heads)
-
+    grid = (triton.next_power_of_2(bsz * q_len), num_heads)
     _flash_decoding_fwd_reduce_kernel[grid](
         mid_output,
         mid_output_lse,
         output,
         kv_seq_len,
+        q_len,
         bsz,
         mid_output.stride(0),
         mid_output.stride(1),