Skip to content

Commit 2d62aca

Browse files
[Infer] Revise and Adapt Triton Kernels for Spec-Dec (#5401)
* [Infer/Fix] Fix Dependency in test - RMSNorm kernel (#5399) fix dependency in pytest * resolve conflicts for revising flash-attn * adapt kv cache copy kernel for spec-dec * fix seqlen-n kvcache copy kernel/tests * test kvcache copy - use torch.equal * add assertions * (trivial) comment out
1 parent bc1da87 commit 2d62aca

File tree

7 files changed

+305
-152
lines changed

7 files changed

+305
-152
lines changed

colossalai/kernel/triton/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from .flash_decoding import flash_decoding_attention
1313
from .fused_rotary_embedding import fused_rotary_embedding
1414
from .gptq_triton import gptq_fused_linear_triton
15-
from .kvcache_copy import copy_kv_to_blocked_cache
15+
from .kvcache_copy import copy_k_to_blocked_cache, copy_kv_to_blocked_cache
1616
from .no_pad_rotary_embedding import decoding_fused_rotary_embedding, rotary_embedding
1717
from .rms_layernorm import rms_layernorm
1818
from .rotary_cache_copy import get_xine_cache
@@ -21,6 +21,7 @@
2121
__all__ = [
2222
"context_attention_unpadded",
2323
"flash_decoding_attention",
24+
"copy_k_to_blocked_cache",
2425
"copy_kv_to_blocked_cache",
2526
"softmax",
2627
"rms_layernorm",

colossalai/kernel/triton/flash_decoding.py

Lines changed: 60 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,14 @@
99
# Triton 2.1.0
1010
@triton.jit
1111
def _flash_decoding_fwd_kernel(
12-
Q, # [batch_size, head_num, q_len(1), head_dim]
12+
Q, # [batch_size * q_len, head_num, head_dim]
1313
KCache, # [num_blocks, num_kv_heads, block_size, head_dim]
1414
VCache, # [num_blocks, num_kv_heads, block_size, head_dim]
1515
block_tables, # [batch_size, max_blocks_per_sequence]
16-
mid_o, # [batch_size, head_num, kv_split_num, head_dim]
17-
mid_o_lse, # [batch_size, head_num, kv_split_num]
16+
mid_o, # [batch_size * q_len, head_num, kv_split_num, head_dim]
17+
mid_o_lse, # [batch_size * q_len, head_num, kv_split_num]
1818
kv_seq_len, # [batch_size]
19+
q_len,
1920
batch_size,
2021
stride_qt,
2122
stride_qh,
@@ -39,44 +40,37 @@ def _flash_decoding_fwd_kernel(
3940
BLOCK_SIZE: tl.constexpr,
4041
HEAD_DIM: tl.constexpr,
4142
):
42-
cur_seq_idx = tl.program_id(0)
43+
cur_token_idx = tl.program_id(0)
44+
cur_seq_idx = cur_token_idx // q_len
4345
if cur_seq_idx >= batch_size:
4446
return
4547
cur_head_idx = tl.program_id(1)
4648
block_start_kv = tl.program_id(2) # for splitting k/v
4749

48-
cur_kv_head_idx = cur_head_idx // KV_GROUPS
49-
offsets_dmodel = tl.arange(0, HEAD_DIM)
50-
5150
# NOTE It requires BLOCK_KV and BLOCK_SIZE to be the same
5251
# TODO might want to replace with BLOCK_KV % BLOCK_SIZE == 0 (optimize BLOCK_KV as multiple of BLOCK_SIZE)
5352
# and then support calculating multiple kv cache blocks on an instance
5453
tl.static_assert(BLOCK_KV == BLOCK_SIZE)
55-
56-
# get the current (kv) sequence length from provided context lengths tensor
54+
# get the current (kv) sequence length
5755
cur_kv_seq_len = tl.load(kv_seq_len + cur_seq_idx)
56+
if block_start_kv * BLOCK_KV >= cur_kv_seq_len:
57+
return
5858

59-
offsets_q = cur_seq_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd
59+
offsets_dmodel = tl.arange(0, HEAD_DIM)
60+
offsets_q = cur_token_idx * stride_qt + cur_head_idx * stride_qh + offsets_dmodel * stride_qd
6061
q = tl.load(Q + offsets_q)
61-
6262
# block table for the current sequence
6363
block_table_ptr = block_tables + cur_seq_idx * stride_bts
64-
65-
# actually current block table current block start idx
6664
# cur_bt_start_idx = block_start_kv * (BLOCK_KV // BLOCK_SIZE)
67-
cur_bt_start_idx = block_start_kv
68-
cur_block_id = tl.load(block_table_ptr + cur_bt_start_idx * stride_btb)
69-
70-
if block_start_kv * BLOCK_KV >= cur_kv_seq_len:
71-
return
72-
65+
# cur_block_id = tl.load(block_table_ptr + cur_bt_start_idx * stride_btb)
66+
cur_block_id = tl.load(block_table_ptr + block_start_kv * stride_btb)
7367
cur_occupied_size = tl.where(
7468
(block_start_kv + 1) * BLOCK_SIZE <= cur_kv_seq_len, BLOCK_SIZE, cur_kv_seq_len - block_start_kv * BLOCK_SIZE
7569
)
7670
tl.device_assert(cur_occupied_size >= 0)
7771

72+
cur_kv_head_idx = cur_head_idx // KV_GROUPS
7873
offset_kvcache = cur_block_id * stride_cacheb + cur_kv_head_idx * stride_cacheh
79-
8074
K_block_ptr = tl.make_block_ptr(
8175
base=KCache + offset_kvcache,
8276
shape=(cur_occupied_size, HEAD_DIM),
@@ -115,14 +109,14 @@ def _flash_decoding_fwd_kernel(
115109
acc = acc / l
116110

117111
offsets_mid_o = (
118-
cur_seq_idx * stride_mid_ot
112+
cur_token_idx * stride_mid_ot
119113
+ cur_head_idx * stride_mid_oh
120114
+ block_start_kv * stride_mid_ob
121115
+ offsets_dmodel * stride_mid_od
122116
)
123117
tl.store(mid_o + offsets_mid_o, acc)
124118
offsets_mid_o_lse = (
125-
cur_seq_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb
119+
cur_token_idx * stride_mid_o_lset + cur_head_idx * stride_mid_o_lseh + block_start_kv * stride_mid_o_lseb
126120
)
127121
# logsumexp L^(j) = m^(j) + log(l^(j))
128122
tl.store(mid_o_lse + offsets_mid_o_lse, m + tl.log(l))
@@ -135,6 +129,7 @@ def _flash_decoding_fwd_reduce_kernel(
135129
mid_o_lse, # [batch_size, head_num, kv_split_num]
136130
O, # [batch_size, num_heads, head_dim] or [batch_size, 1, num_heads, head_dim]
137131
kv_seq_len,
132+
q_len,
138133
batch_size,
139134
stride_mid_ot,
140135
stride_mid_oh,
@@ -149,7 +144,8 @@ def _flash_decoding_fwd_reduce_kernel(
149144
BLOCK_KV: tl.constexpr,
150145
HEAD_DIM: tl.constexpr,
151146
):
152-
cur_seq_idx = tl.program_id(0)
147+
cur_token_idx = tl.program_id(0)
148+
cur_seq_idx = cur_token_idx // q_len
153149
if cur_seq_idx >= batch_size:
154150
return
155151
cur_head_idx = tl.program_id(1)
@@ -164,8 +160,8 @@ def _flash_decoding_fwd_reduce_kernel(
164160
l = 0.0 # sum exp
165161
acc = tl.zeros([HEAD_DIM], dtype=tl.float32)
166162

167-
offsets_mid_o = cur_seq_idx * stride_mid_ot + cur_head_idx * stride_mid_oh + offsets_dmodel
168-
offset_mid_lse = cur_seq_idx * stride_o_lset + cur_head_idx * stride_o_lseh
163+
offsets_mid_o = cur_token_idx * stride_mid_ot + cur_head_idx * stride_mid_oh + offsets_dmodel
164+
offset_mid_lse = cur_token_idx * stride_o_lset + cur_head_idx * stride_o_lseh
169165
for block_i in range(0, kv_split_num, 1):
170166
mid_o_block = tl.load(mid_o + offsets_mid_o + block_i * stride_mid_ob)
171167
lse = tl.load(mid_o_lse + offset_mid_lse + block_i * stride_o_lseb)
@@ -179,7 +175,7 @@ def _flash_decoding_fwd_reduce_kernel(
179175
m_i = m_ij
180176

181177
acc = acc / l
182-
offsets_O = cur_seq_idx * stride_ot + cur_head_idx * stride_oh + offsets_dmodel
178+
offsets_O = cur_token_idx * stride_ot + cur_head_idx * stride_oh + offsets_dmodel
183179
tl.store(O + offsets_O, acc.to(O.type.element_ty))
184180
return
185181

@@ -199,32 +195,40 @@ def flash_decoding_attention(
199195
mid_output_lse: torch.Tensor = None,
200196
sm_scale: int = None,
201197
kv_group_num: int = 1,
198+
q_len: int = 1,
202199
):
203200
"""
204201
Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage.
205202
206203
Args:
207-
q (torch.Tensor): [bsz, num_heads, head_dim]
204+
q (torch.Tensor): [bsz * q_len, num_heads, head_dim]
205+
q_len > 1 only for verification process in speculative-decoding.
208206
k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim]
209207
v_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim]
210208
kv_seq_len (torch.Tensor): [batch_size]
211209
records the (kv) sequence lengths incorporating past kv sequence lengths.
212210
block_tables (torch.Tensor): [batch_size, max_blocks_per_sequence]
213211
max_seq_len_in_batch (int): Maximum sequence length in the batch.
214212
output (torch.Tensor): [bsz, num_heads * head_dim]
215-
mid_output (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num, head_dim]
213+
mid_output (torch.Tensor): [max_bsz * q_len, num_heads, kv_max_split_num, head_dim]
216214
Intermediate output tensor. `max_bsz` should be greater than or equal to `bsz`.
217-
mid_output_lse (torch.Tensor): [ max_bsz , num_heads, kv_max_split_num]
215+
q_len > 1 only for verification process in speculative-decoding.
216+
mid_output_lse (torch.Tensor): [max_bsz * q_len, num_heads, kv_max_split_num]
218217
Log-sum-exp of intermediate output. `max_bsz` should be greater than or equal to `bsz`.
218+
q_len > 1 only for verification process in speculative-decoding.
219219
block_size (int): Size of each block in the blocked key/value cache.
220220
num_kv_group (int, optional): Number of key/value groups. Defaults to 1.
221+
q_length (int): Query length. Use for speculative decoding when `q_length` > 1 (i.e. the last n tokens).
222+
Defaults to 1.
221223
222224
Returns:
223-
Output tensor with shape [bsz, num_heads * head_dim]
225+
Output tensor with shape [bsz * q_len, num_heads * head_dim]
224226
"""
225227
q = q.squeeze() if q.dim() == 4 else q
226228
assert q.dim() == 3, f"Incompatible q dim: {q.dim()}"
227-
bsz, num_heads, head_dim = q.shape
229+
n_tokens, num_heads, head_dim = q.shape
230+
assert n_tokens % q_len == 0, "Invalid q_len"
231+
bsz = n_tokens // q_len
228232

229233
assert head_dim in {32, 64, 128, 256}
230234
assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, (
@@ -247,22 +251,31 @@ def flash_decoding_attention(
247251
max_seq_len_in_batch = kv_seq_len.max().item() if max_seq_len_in_batch is None else max_seq_len_in_batch
248252
# For compatibility (TODO revise modeling in future)
249253
kv_max_split_num = (max_seq_len_in_batch + BLOCK_KV - 1) // BLOCK_KV
250-
mid_output = (
251-
torch.zeros(size=(bsz, num_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device)
252-
if mid_output is None
253-
else mid_output
254-
)
255-
mid_output_lse = (
256-
torch.zeros(size=(bsz, num_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
257-
if mid_output_lse is None
258-
else mid_output_lse
259-
)
254+
255+
if mid_output is None:
256+
mid_output = torch.empty(
257+
(bsz * q_len, num_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device
258+
)
259+
if mid_output_lse is None:
260+
mid_output_lse = torch.empty((bsz * q_len, num_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
261+
if output is None:
262+
# A hack to prevent `view` operation in modeling
263+
output = torch.empty((bsz * q_len, num_heads * head_dim), dtype=q.dtype, device=q.device)
264+
265+
assert (
266+
mid_output.size(2) == mid_output_lse.size(2) >= kv_max_split_num
267+
), "Incompatible kv split number of intermediate output tensors"
268+
assert (
269+
mid_output.size(0) == mid_output_lse.size(0) >= output.size(0) == n_tokens
270+
), f"Incompatible first dimension of output tensors"
260271

261272
# NOTE use `triton.next_power_of_2` here to utilize the cache mechanism of triton
262273
# To optimize, revise batching/scheduling to batch 2^n sequences in a batch (preferred)
263-
grid = (triton.next_power_of_2(bsz), num_heads, triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV))
264-
output = torch.empty((bsz, num_heads * head_dim), dtype=q.dtype, device=q.device) if output is None else output
265-
274+
grid = (
275+
triton.next_power_of_2(bsz * q_len),
276+
num_heads,
277+
triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV),
278+
)
266279
_flash_decoding_fwd_kernel[grid](
267280
q,
268281
k_cache,
@@ -271,6 +284,7 @@ def flash_decoding_attention(
271284
mid_output,
272285
mid_output_lse,
273286
kv_seq_len,
287+
q_len,
274288
bsz,
275289
q.stride(0),
276290
q.stride(1),
@@ -295,13 +309,13 @@ def flash_decoding_attention(
295309
HEAD_DIM=head_dim,
296310
)
297311

298-
grid = (triton.next_power_of_2(bsz), num_heads)
299-
312+
grid = (triton.next_power_of_2(bsz * q_len), num_heads)
300313
_flash_decoding_fwd_reduce_kernel[grid](
301314
mid_output,
302315
mid_output_lse,
303316
output,
304317
kv_seq_len,
318+
q_len,
305319
bsz,
306320
mid_output.stride(0),
307321
mid_output.stride(1),

colossalai/kernel/triton/kvcache_copy.py

Lines changed: 106 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,50 @@
33
import triton.language as tl
44

55

6+
# Triton 2.1.0
7+
@triton.jit
8+
def _copy_to_kcache_seqlen_n_kernel(
9+
KV, # K or V
10+
KVCache, # KCache or VCache
11+
BLOCK_TABLES,
12+
context_lengths,
13+
stride_kt,
14+
stride_kh,
15+
stride_kd,
16+
stride_cacheb,
17+
stride_cacheh,
18+
stride_cachebs,
19+
stride_cached,
20+
stride_bts,
21+
stride_btb,
22+
block_size,
23+
n,
24+
HEAD_DIM: tl.constexpr,
25+
):
26+
cur_token_idx = tl.program_id(0)
27+
cur_seq_idx = cur_token_idx // n
28+
cur_token_shift = cur_token_idx - (n * (cur_seq_idx + 1))
29+
# cur_token_shift = cur_token_idx - n * cur_seq_idx
30+
cur_kv_head_idx = tl.program_id(1)
31+
32+
past_kv_seq_len = tl.load(context_lengths + cur_seq_idx) + cur_token_shift
33+
last_bt_block_idx = past_kv_seq_len // block_size
34+
block_table_ptr = BLOCK_TABLES + cur_seq_idx * stride_bts
35+
block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb)
36+
offset_last_block = past_kv_seq_len % block_size
37+
offsets_dmodel = tl.arange(0, HEAD_DIM)
38+
offsets_kv = cur_token_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
39+
kv = tl.load(KV + offsets_kv)
40+
offsets_kvcache = (
41+
block_id * stride_cacheb
42+
+ cur_kv_head_idx * stride_cacheh
43+
+ offset_last_block * stride_cachebs
44+
+ offsets_dmodel * stride_cached
45+
)
46+
tl.store(KVCache + offsets_kvcache, kv)
47+
return
48+
49+
650
# Triton 2.1.0
751
@triton.jit
852
def _copy_to_kvcache_seqlen1_kernel(
@@ -40,10 +84,11 @@ def _copy_to_kvcache_seqlen1_kernel(
4084
block_id = tl.load(block_table_ptr + last_bt_block_idx * stride_btb)
4185
offsets_in_last_block = past_kv_seq_len % block_size
4286
offsets_dmodel = tl.arange(0, HEAD_DIM)
43-
offsets_kv = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
87+
offsets_k = cur_seq_idx * stride_kt + cur_kv_head_idx * stride_kh + offsets_dmodel * stride_kd
88+
offsets_v = cur_seq_idx * stride_vt + cur_kv_head_idx * stride_vh + offsets_dmodel * stride_vd
4489

45-
k = tl.load(K + offsets_kv)
46-
v = tl.load(V + offsets_kv)
90+
k = tl.load(K + offsets_k)
91+
v = tl.load(V + offsets_v)
4792

4893
offsets_kcache = (
4994
block_id * stride_cachekb
@@ -63,6 +108,64 @@ def _copy_to_kvcache_seqlen1_kernel(
63108
return
64109

65110

111+
def copy_k_to_blocked_cache(
112+
k: torch.Tensor, k_cache: torch.Tensor, kv_lengths: torch.Tensor, block_tables: torch.Tensor, n: int = 1
113+
):
114+
"""
115+
Copy keys or values to the blocked key/value cache during decoding stage.
116+
117+
Args:
118+
k (torch.Tensor): [bsz, 1, num_kv_heads, head_dim]/[bsz, num_kv_heads, head_dim] - Keys or values during decoding with seq len 1.
119+
[bsz * n, num_kv_heads, head_dim] - Keys or values with seq len n
120+
k_cache (torch.Tensor): [num_blocks, num_kv_heads, block_size, head_dim] - Blocked key or value cache.
121+
kv_lengths (torch.Tensor): [bsz] - Past key/value sequence lengths plus current sequence length for each sequence.
122+
block_tables (torch.Tensor): [bsz, max_blocks_per_sequence] - Block tables for each sequence.
123+
n (int): Number of tokens to copy for each sequence. Default to 1.
124+
"""
125+
assert k.size(-1) == k_cache.size(-1), "Incompatible head dim"
126+
assert k.dtype == k_cache.dtype, "Expected consistent dtype for tensor and cache."
127+
128+
k = k.reshape(-1, k.size(-2), k.size(-1)) if k.dim() == 4 else k
129+
assert k.dim() == 3, f"Invalid k dim {k.dim()}"
130+
bsz, num_kv_heads, head_dim = k.shape
131+
# NOTE when n > 1, the shape of k is [bsz * n, num_kv_heads, head_dim]
132+
if n > 1:
133+
assert bsz % n == 0, "Each sequence should have the same number of tokens to be copied"
134+
bsz = bsz // n
135+
136+
assert kv_lengths.shape[0] == block_tables.shape[0] == bsz, (
137+
f"Got incompatible batch size (number of seqs):\n"
138+
f" Past kv sequence lengths bsz {kv_lengths.shape[0]}; "
139+
f" block tables bsz {block_tables.shape[0]}, input k batch size {bsz}"
140+
)
141+
142+
# Modify if the shape of kv cahce is changed.
143+
block_size = k_cache.size(-2)
144+
145+
num_warps = 8 if head_dim > 128 else 4
146+
147+
grid = (bsz * n, num_kv_heads)
148+
_copy_to_kcache_seqlen_n_kernel[grid](
149+
k,
150+
k_cache,
151+
block_tables,
152+
kv_lengths,
153+
k.stride(0),
154+
k.stride(1),
155+
k.stride(2),
156+
k_cache.stride(0),
157+
k_cache.stride(1),
158+
k_cache.stride(2),
159+
k_cache.stride(3),
160+
block_tables.stride(0),
161+
block_tables.stride(1),
162+
block_size,
163+
n=n,
164+
HEAD_DIM=head_dim,
165+
num_warps=num_warps,
166+
)
167+
168+
66169
def copy_kv_to_blocked_cache(
67170
k: torch.Tensor,
68171
v: torch.Tensor,

0 commit comments

Comments
 (0)