
Commit fcfbdde

revise flash decoding triton kernel in/out shapes
1 parent 8682115 commit fcfbdde

File tree

4 files changed: +70 -78 lines changed


colossalai/kernel/triton/flash_decoding.py

Lines changed: 10 additions & 10 deletions
@@ -9,7 +9,7 @@
 # Triton 2.1.0
 @triton.jit
 def _flash_decoding_fwd_kernel(
-    Q,  # [batch_size, head_num, head_dim]
+    Q,  # [batch_size, head_num, q_len(1), head_dim]
     KCache,  # [num_blocks, num_kv_heads, head_dim, block_size]
     VCache,  # [num_blocks, num_kv_heads, head_dim, block_size]
     block_tables,  # [batch_size, max_blocks_per_sequence]
@@ -18,6 +18,7 @@ def _flash_decoding_fwd_kernel(
     kv_seq_len,  # [batch_size]
     stride_qt,
     stride_qh,
+    stride_ql,
     stride_qd,
     stride_cacheb,
     stride_cacheh,
@@ -140,6 +141,7 @@ def _flash_decoding_fwd_reduce_kernel(
     stride_o_lseh,
     stride_o_lseb,
     stride_ob,
+    stride_ol,
     stride_oh,
     stride_od,
     BLOCK_KV: tl.constexpr,
@@ -197,7 +199,7 @@ def flash_decoding_attention(
     Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage.

     Args:
-        q (torch.Tensor): [bsz, 1, num_heads, head_dim]
+        q (torch.Tensor): [bsz, num_heads, q_len(1), head_dim]
         k_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size]
         v_cache (torch.Tensor): [num_blocks, num_kv_heads, head_dim, block_size]
         kv_seq_len (torch.Tensor): [batch_size]
@@ -211,9 +213,9 @@ def flash_decoding_attention(
         num_kv_group (int, optional): Number of key/value groups. Defaults to 1.

     Returns:
-        Output tensor with shape [bsz, num_heads, head_dim]
+        Output tensor with shape [bsz, num_heads, q_len, head_dim]
     """
-    bsz, _, num_heads, head_dim = q.shape
+    bsz, num_heads, _, head_dim = q.shape

     assert head_dim in {32, 64, 128, 256}
     assert kv_seq_len.shape[0] == block_tables.shape[0] == bsz, (
@@ -234,10 +236,6 @@ def flash_decoding_attention(
     assert block_size in {16, 32, 64, 128}
     BLOCK_KV = block_size

-    if q.dim() == 4:
-        assert q.size(1) == 1, f"q_len is supposed to be 1 but is {q.size(1)}"
-        q = q.squeeze(1)
-
     grid = (triton.next_power_of_2(bsz), num_heads, triton.cdiv(triton.next_power_of_2(max_seq_len_in_batch), BLOCK_KV))
     _flash_decoding_fwd_kernel[grid](
         q,
@@ -250,6 +248,7 @@ def flash_decoding_attention(
         q.stride(0),
         q.stride(1),
         q.stride(2),
+        q.stride(3),
         k_cache.stride(0),
         k_cache.stride(1),
         k_cache.stride(2),
@@ -270,8 +269,8 @@ def flash_decoding_attention(
         HEAD_DIM=head_dim,
     )

-    output = torch.empty_like(q)
-    output = output.view(-1, output.size(-2), output.size(-1))
+    output = torch.empty_like(q)  # already overlapped
+    output = torch.empty((bsz, 1, num_heads, head_dim), dtype=q.dtype, device=q.device)

     grid = (bsz, num_heads)
     _flash_decoding_fwd_reduce_kernel[grid](
@@ -289,6 +288,7 @@ def flash_decoding_attention(
         output.stride(0),
         output.stride(1),
         output.stride(2),
+        output.stride(3),
         BLOCK_KV=block_size,
         HEAD_DIM=head_dim,
     )
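For reference, here is a minimal sketch (not part of the commit) of how the four query strides forwarded to _flash_decoding_fwd_kernel line up with the revised [bsz, num_heads, q_len(1), head_dim] layout. The tensor construction mirrors the updated test; the concrete sizes (bsz=4, num_heads=16, head_dim=128) are illustrative assumptions.

    import torch

    bsz, num_heads, q_len, head_dim = 4, 16, 1, 128
    # Allocate as [bsz, q_len, num_heads, head_dim] and transpose, as the revised test does,
    # to obtain the kernel's expected [bsz, num_heads, q_len, head_dim] layout.
    q = torch.empty(bsz, q_len, num_heads, head_dim, dtype=torch.float16).transpose(1, 2)

    # The call site now forwards these four values as (stride_qt, stride_qh, stride_ql, stride_qd).
    print(q.shape)    # torch.Size([4, 16, 1, 128])
    print(q.stride())  # (2048, 128, 2048, 1)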

tests/test_infer_ops/triton/kernel_utils.py

Lines changed: 3 additions & 9 deletions
@@ -21,9 +21,9 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
 # src/transformers/models/llama/modeling_llama.py
 # https://github.com/huggingface/transformers/blob/633215ba58fe5114d8c8d32e415a04600e010701/src/transformers/models/llama/modeling_llama.py#L350
 def torch_attn_ref(
-    q: torch.Tensor,  # [bsz, seq_len, num_heads, head_dim]
-    k: torch.Tensor,  # [bsz, kv_seq_len, num_heads, head_dim]
-    v: torch.Tensor,  # [bsz, kv_seq_len, num_heads, head_dim]
+    q: torch.Tensor,  # [bsz, num_heads, q_len, head_dim]
+    k: torch.Tensor,  # [bsz, num_heads, kv_seq_len, head_dim]
+    v: torch.Tensor,  # [bsz, num_heads, kv_seq_len, head_dim]
     attention_mask: torch.Tensor,  # [bsz, 1, seq_len, kv_seq_len]
     bsz: int,
     seq_len: int,
@@ -33,12 +33,6 @@ def torch_attn_ref(
     head_dim: int,
 ):
     assert q.shape[-1] == k.shape[-1] == v.shape[-1] == head_dim
-    q = q.view(bsz, seq_len, num_heads, head_dim)
-    k = k.view(bsz, kv_seq_len, num_kv_heads, head_dim)
-    v = v.view(bsz, kv_seq_len, num_kv_heads, head_dim)
-    q = q.transpose(1, 2)
-    k = k.transpose(1, 2)
-    v = v.transpose(1, 2)

     # repeat kv for GQA and MQA
     # k/v won't change if kv_group_num is 1

tests/test_infer_ops/triton/test_context_attn_unpad.py

Lines changed: 3 additions & 3 deletions
@@ -34,9 +34,9 @@ def torch_attn_unpad(
         mask[mask == 0.0] = float("-inf")

         torch_attn_ref_out = torch_attn_ref(
-            q[start_idx:end_idx].unsqueeze(0),
-            k[start_idx:end_idx].unsqueeze(0),
-            v[start_idx:end_idx].unsqueeze(0),
+            q[start_idx:end_idx].unsqueeze(0).transpose(1, 2),
+            k[start_idx:end_idx].unsqueeze(0).transpose(1, 2),
+            v[start_idx:end_idx].unsqueeze(0).transpose(1, 2),
             mask,
             1,  # set bsz as 1 as we're processing sequence one by one
             seq_len,
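As a quick illustration of the caller-side change implied by the kernel_utils.py and test_context_attn_unpad.py diffs above (a sketch with assumed toy sizes, not code from the commit): torch_attn_ref no longer views or transposes its inputs, so callers hand it tensors already laid out as [bsz, num_heads, seq_len, head_dim].

    import torch

    bsz, seq_len, num_heads, head_dim = 2, 8, 4, 16
    q = torch.randn(bsz, seq_len, num_heads, head_dim)

    # torch_attn_ref used to view/transpose internally; after this commit the caller
    # transposes to [bsz, num_heads, seq_len, head_dim] before passing tensors in.
    q_for_ref = q.transpose(1, 2)
    print(q_for_ref.shape)  # torch.Size([2, 4, 8, 16])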

tests/test_infer_ops/triton/test_decoding_attn.py

Lines changed: 54 additions & 56 deletions
@@ -16,6 +16,9 @@

 TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")

+Q_LEN = 1
+HEAD_DIM = 128
+

 def prepare_padding_mask(kv_lengths: torch.Tensor, bsz: int, max_seq_len: int, device="cuda"):
     padding_mask = torch.zeros((bsz, 1, 1, max_seq_len), dtype=torch.float32, device=device)
@@ -48,74 +51,72 @@ def test_flash_decoding(

     num_kv_heads = num_attn_heads // kv_group_num
     assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
-    q_len = 1
-    head_dim = 128
     max_seq_len = block_size * max_num_blocks_per_seq
     dtype = torch.float16
     device = get_current_device()

+    # Use the provided maximum sequence length for each sequence when testing with the same context length,
+    # otherwise generate random context lengths.
     context_lengths = (
         torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
         if same_context_len
         else torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device)
     )
     num_tokens = torch.sum(context_lengths).item()

-    q_size = (bsz, q_len, num_attn_heads, head_dim)
-    q = torch.empty(size=q_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
-    q = q.view(bsz, q_len, num_attn_heads, head_dim)
+    q_size = (bsz, Q_LEN, num_attn_heads, HEAD_DIM)
+    q = torch.empty(size=q_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5).transpose(1, 2)
+    kv_unpad_size = (num_tokens, 2 * num_kv_heads, HEAD_DIM)
+    kv_unpad = torch.empty(size=kv_unpad_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
+    k_unpad, v_unpad = torch.split(kv_unpad, [num_kv_heads, num_kv_heads], dim=-2)

-    kv_size = (num_tokens, 2 * num_kv_heads, head_dim)
-    kv = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
-    k, v = torch.split(kv, [num_kv_heads, num_kv_heads], dim=-2)
-
-    cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim, block_size)
+    cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, HEAD_DIM, block_size)
     k_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
     v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
     # Mock allocation on block tables as well as blocked kv caches
     block_tables = mock_alloc_block_table_and_kvcache(
-        k, v, k_cache, v_cache, context_lengths, bsz, max_num_blocks_per_seq, block_size
+        k_unpad, v_unpad, k_cache, v_cache, context_lengths, bsz, max_num_blocks_per_seq, block_size
     )
     block_tables = block_tables.to(device=device)
-
-    max_seq_len = context_lengths.max().item()
-    # the maximum block length splitted on kv should be the kv cache block size
-    kv_max_split_num = (max_seq_len + block_size - 1) // block_size
+    # The maximum sequence length in the batch (if context lengths randomly generated)
+    max_seq_len_in_b = context_lengths.max().item()
+    # The maximum block length splitted on kv should be the kv cache block size
+    kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
     mid_output = torch.empty(
-        size=(bsz, num_attn_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device
+        size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
     )
     mid_output_lse = torch.empty(size=(bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
-    sm_scale = 1.0 / (head_dim**0.5)
+    sm_scale = 1.0 / (HEAD_DIM**0.5)
     out_triton = flash_decoding_attention(
         q,
         k_cache,
         v_cache,
         context_lengths,
         block_tables,
-        max_seq_len,
+        max_seq_len_in_b,
         mid_output,
         mid_output_lse,
         block_size=block_size,
         sm_scale=sm_scale,
         kv_group_num=kv_group_num,
-    )
-    out_triton = out_triton.unsqueeze(1)  # [bsz, 1, num_heads, head_dim]
+    )  # [bsz, 1, num_heads, head_dim]

     # rebuild (batched) kv with padding for torch attention
-    # q   [bsz, 1, num_heads, head_dim]
-    # k/v [num_tokens, num_kv_heads, head_dim]
-    k_torch = torch.zeros((bsz, max_seq_len, num_kv_heads, head_dim), dtype=k.dtype, device=k.device)
+    # q   [bsz, num_heads, q_len, head_dim]
+    # k/v [bsz, max_seq_len_in_b, num_kv_heads, head_dim]
+    k_torch = torch.zeros((bsz, max_seq_len_in_b, num_kv_heads, HEAD_DIM), dtype=k_unpad.dtype, device=k_unpad.device)
     v_torch = torch.zeros_like(k_torch)
     prev_len_sum = 0
     for i, seq_len in enumerate(context_lengths.tolist()):
-        # mock left-side padding
-        k_torch[i, -seq_len:, :, :] = k[prev_len_sum : prev_len_sum + seq_len]
-        v_torch[i, -seq_len:, :, :] = v[prev_len_sum : prev_len_sum + seq_len]
+        # left-side padding
+        k_torch[i, -seq_len:, :, :] = k_unpad[prev_len_sum : prev_len_sum + seq_len]
+        v_torch[i, -seq_len:, :, :] = v_unpad[prev_len_sum : prev_len_sum + seq_len]
         prev_len_sum += seq_len
-    # k/v [bsz, max_seq_len, num_kv_heads, head_dim]
-    torch_padding_mask = prepare_padding_mask(context_lengths, bsz, k_torch.size(1), q.device)
+    torch_padding_mask = prepare_padding_mask(context_lengths, bsz, max_seq_len_in_b, q.device)
+    k_torch = k_torch.transpose(1, 2)
+    v_torch = v_torch.transpose(1, 2)
     out_torch = torch_attn_ref(
-        q, k_torch, v_torch, torch_padding_mask, bsz, 1, k_torch.size(1), num_attn_heads, num_kv_heads, head_dim
+        q, k_torch, v_torch, torch_padding_mask, bsz, 1, max_seq_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM
     )

     assert out_torch.shape == out_triton.shape
@@ -128,7 +129,7 @@ def test_flash_decoding(
 configs = [
     triton.testing.Benchmark(
         x_names=["KV_LEN"],
-        x_vals=[2**i for i in range(8, 14)],
+        x_vals=[2**i for i in range(8, 12)],
         # x_vals=[x for x in range(256, 8192, 256)],
         line_arg="provider",
         line_vals=["torch", "triton"],
@@ -154,30 +155,28 @@ def bench_kernel(
     rep = 100

     num_attn_heads = 16
-    max_num_blocks_per_seq = max(32, triton.cdiv(KV_LEN, block_size))
+    max_num_blocks_per_seq = triton.cdiv(KV_LEN, block_size)

     num_kv_heads = num_attn_heads // kv_group_num
     assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
-    q_len = 1
-    head_dim = 128
-    max_seq_len = block_size * max_num_blocks_per_seq
+    block_size * max_num_blocks_per_seq
     dtype = torch.float16
     device = get_current_device()

     kv_lengths = (
-        torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
+        torch.tensor([KV_LEN for _ in range(bsz)], dtype=torch.int32, device=device)
         if same_context_len
-        else torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device)
+        else torch.randint(low=1, high=KV_LEN, size=(bsz,), dtype=torch.int32, device=device)
     )
     num_tokens = torch.sum(kv_lengths).item()

-    q_size = (bsz, q_len, num_attn_heads, head_dim)
-    q = torch.empty(size=q_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
-    kv_size = (num_tokens, 2 * num_kv_heads, head_dim)
+    q_size = (bsz, Q_LEN, num_attn_heads, HEAD_DIM)
+    q = torch.empty(size=q_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5).transpose(1, 2)
+    kv_size = (num_tokens, 2 * num_kv_heads, HEAD_DIM)
     kv = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
     k, v = torch.split(kv, [num_kv_heads, num_kv_heads], dim=-2)

-    cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim, block_size)
+    cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, HEAD_DIM, block_size)
     k_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
     v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
     # Mock allocation on block tables as well as blocked kv caches
@@ -186,55 +185,54 @@ def bench_kernel(
     )
     block_tables = block_tables.to(device=device)

-    q = q.view(bsz, q_len, num_attn_heads, head_dim)
-    max_seq_len = kv_lengths.max().item()  # for random lengths
+    max_seq_len_in_b = kv_lengths.max().item()  # for random lengths

     quantiles = [0.5, 0.2, 0.8]
     if provider == "torch":
         # rebuild (batched) kv with padding for torch attention
-        # q   [bsz, 1, num_heads, head_dim]
-        # k/v [num_tokens, num_kv_heads, head_dim]
-        k_torch = torch.zeros((bsz, max_seq_len, num_kv_heads, head_dim), dtype=k.dtype, device=k.device)
+        # q   [bsz, num_heads, q_len, head_dim]
+        # k/v [bsz, max_seq_len_in_b, num_kv_heads, head_dim]
+        k_torch = torch.zeros((bsz, max_seq_len_in_b, num_kv_heads, HEAD_DIM), dtype=k.dtype, device=k.device)
         v_torch = torch.zeros_like(k_torch)
         prev_len_sum = 0
         for i, seq_len in enumerate(kv_lengths.tolist()):
             # mock left-side padding
             k_torch[i, -seq_len:, :, :] = k[prev_len_sum : prev_len_sum + seq_len]
             v_torch[i, -seq_len:, :, :] = v[prev_len_sum : prev_len_sum + seq_len]
             prev_len_sum += seq_len
-        # k/v [bsz, max_seq_len, num_kv_heads, head_dim]
-        torch_padding_mask = prepare_padding_mask(kv_lengths, bsz, k_torch.size(1), q.device)
+        torch_padding_mask = prepare_padding_mask(kv_lengths, bsz, max_seq_len_in_b, q.device)
+        k_torch = k_torch.transpose(1, 2)
+        v_torch = v_torch.transpose(1, 2)
         fn = lambda: torch_attn_ref(
-            q, k_torch, v_torch, torch_padding_mask, bsz, 1, k_torch.size(1), num_attn_heads, num_kv_heads, head_dim
+            q, k_torch, v_torch, torch_padding_mask, bsz, 1, max_seq_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM
         )
         ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=quantiles)
     if provider == "triton":
         # the maximum block length splitted on kv should be the kv cache block size
-        kv_max_split_num = (max_seq_len + block_size - 1) // block_size
+        kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
         mid_output = torch.empty(
-            size=(bsz, num_attn_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device
+            size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
         )
         mid_output_lse = torch.empty(size=(bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
-        sm_scale = 1.0 / (head_dim**0.5)
+        sm_scale = 1.0 / (HEAD_DIM**0.5)
         fn = lambda: flash_decoding_attention(
             q,
             k_cache,
             v_cache,
             kv_lengths,
             block_tables,
-            max_seq_len,
+            max_seq_len_in_b,
             mid_output,
             mid_output_lse,
             block_size=block_size,
             sm_scale=sm_scale,
             kv_group_num=kv_group_num,
-        ).unsqueeze(1)
-
+        )  # [bsz, 1, num_heads, head_dim]
         ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=quantiles)

     return ms, min_ms, max_ms


 if __name__ == "__main__":
-    test_flash_decoding(16, 32, 32, 16, 1, True)
-    # bench_kernel.run(save_path=".", print_data=True)
+    # test_flash_decoding(16, 32, 32, 16, 1, True)
+    bench_kernel.run(save_path=".", print_data=True)
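To make the sizing of the split-KV intermediates in the updated test and benchmark concrete, here is a small worked sketch (the numbers are illustrative assumptions, not values from the commit).

    # Each (batch, head) pair produces at most one partial result per KV-cache block,
    # so the intermediate buffers are sized by the ceiling division below.
    bsz, num_heads, head_dim, block_size = 4, 16, 128, 32
    max_seq_len_in_b = 1000  # longest context length in the batch

    kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size  # ceil(1000 / 32) = 32
    mid_output_shape = (bsz, num_heads, kv_max_split_num, head_dim)       # partial attention outputs
    mid_output_lse_shape = (bsz, num_heads, kv_max_split_num)             # partial log-sum-exp values
    print(kv_max_split_num, mid_output_shape, mid_output_lse_shape)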
