
Commit 4529b87

fix kernel api and add benchmark
1 parent b322cf9 commit 4529b87

2 files changed: +25 -24 lines changed

2 files changed

+25
-24
lines changed

colossalai/kernel/triton/flash_decoding.py

Lines changed: 3 additions & 3 deletions
@@ -190,7 +190,8 @@ def flash_decoding_attention(
     mid_output: torch.Tensor,
     mid_output_lse: torch.Tensor,
     block_size: int,
-    num_kv_group: int = 1,
+    sm_scale: int,
+    kv_group_num: int = 1,
 ):
     """
     Flash decoding implemented with a blocked KV Cache (PagedAttention) during decoding stage.
@@ -227,7 +228,6 @@ def flash_decoding_attention(
     )
     # NOTE `kv_seq_len` records the (kv) sequence lengths incorporating past kv sequence lengths.
     bsz = kv_seq_len.size(0)  # e.g. the number of seqs
-    sm_scale = 1.0 / (head_dim**0.5)

     # NOTE BLOCK_KV could be considered as block splitting the sequence on k/v
     # For now, BLOCK_KV is supposed to be equivalent with the size of physical cache block (i.e.`block_size`)
@@ -264,7 +264,7 @@ def flash_decoding_attention(
         mid_output_lse.stride(1),
         mid_output_lse.stride(2),
         sm_scale,
-        KV_GROUPS=num_kv_group,
+        KV_GROUPS=kv_group_num,
         BLOCK_KV=block_size,
         BLOCK_SIZE=block_size,
         HEAD_DIM=head_dim,
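
For orientation, a minimal sketch of what the API change asks of callers: the split-KV intermediate buffers are sized as before, but the softmax scale is now computed outside the kernel and passed in explicitly, alongside the renamed kv_group_num keyword. The concrete sizes below are illustrative assumptions, not values taken from this commit.

import torch

# Illustrative sizes only (assumptions for this sketch, not from the commit).
bsz, num_attn_heads, head_dim, block_size = 4, 16, 128, 32
max_seq_len = 256
kv_group_num = 2  # ratio of attention heads to kv heads (num_attn_heads // num_kv_heads)

# Intermediate buffers for the split-KV (flash decoding) reduction, one slot per kv block.
kv_max_split_num = (max_seq_len + block_size - 1) // block_size
mid_output = torch.empty(bsz, num_attn_heads, kv_max_split_num, head_dim, dtype=torch.float32)
mid_output_lse = torch.empty(bsz, num_attn_heads, kv_max_split_num, dtype=torch.float32)

# After this commit the caller computes the softmax scale and passes it as
# flash_decoding_attention(..., block_size=block_size, sm_scale=sm_scale, kv_group_num=kv_group_num)
# instead of the kernel deriving it from head_dim internally.
sm_scale = 1.0 / (head_dim**0.5)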

tests/test_infer_ops/triton/test_decoding_attn.py

Lines changed: 22 additions & 21 deletions
@@ -62,6 +62,8 @@ def test_flash_decoding(

     q_size = (bsz, q_len, num_attn_heads, head_dim)
     q = torch.empty(size=q_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
+    q = q.view(bsz, q_len, num_attn_heads, head_dim)
+
     kv_size = (num_tokens, 2 * num_kv_heads, head_dim)
     kv = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
     k, v = torch.split(kv, [num_kv_heads, num_kv_heads], dim=-2)
@@ -75,16 +77,14 @@ def test_flash_decoding(
     )
     block_tables = block_tables.to(device=device)

-    q = q.view(bsz, q_len, num_attn_heads, head_dim)
-
     max_seq_len = context_lengths.max().item()
     # the maximum block length splitted on kv should be the kv cache block size
     kv_max_split_num = (max_seq_len + block_size - 1) // block_size
     mid_output = torch.empty(
         size=(bsz, num_attn_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device
     )
     mid_output_lse = torch.empty(size=(bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
-
+    sm_scale = 1.0 / (head_dim**0.5)
     out_triton = flash_decoding_attention(
         q,
         k_cache,
@@ -94,15 +94,15 @@ def test_flash_decoding(
         max_seq_len,
         mid_output,
         mid_output_lse,
-        block_size,
-        kv_group_num,
+        block_size=block_size,
+        sm_scale=sm_scale,
+        kv_group_num=kv_group_num,
     )
     out_triton = out_triton.unsqueeze(1)  # [bsz, 1, num_heads, head_dim]

     # rebuild (batched) kv with padding for torch attention
     # q [bsz, 1, num_heads, head_dim]
     # k/v [num_tokens, num_kv_heads, head_dim]
-    max_seq_len = context_lengths.max().item()
     k_torch = torch.zeros((bsz, max_seq_len, num_kv_heads, head_dim), dtype=k.dtype, device=k.device)
     v_torch = torch.zeros_like(k_torch)
     prev_len_sum = 0
@@ -126,11 +126,11 @@ def test_flash_decoding(
 SAME_LEN = True
 configs = [
     triton.testing.Benchmark(
-        x_names=["PAST_KVLEN"],
-        x_vals=[2**i - 1 for i in range(8, 16)],
+        x_names=["KV_LEN"],
+        x_vals=[2**i for i in range(8, 12)],
         line_arg="provider",
         line_vals=["torch", "triton"],
-        line_names=["torch", "triton"],
+        line_names=["Torch", "Triton"],
         styles=[("red", "-"), ("blue", "-")],
         ylabel="ms",
         plot_name=f"decoding-block_size-{BLOCK_SIZE}-batch{BATCH}",
@@ -142,7 +142,7 @@ def test_flash_decoding(
 @triton.testing.perf_report(configs)
 def bench_kernel(
     bsz,
-    PAST_KVLEN,
+    KV_LEN,
     provider,
     block_size: int,
     kv_group_num: int,
@@ -152,7 +152,7 @@ def bench_kernel(
     rep = 100

     num_attn_heads = 16
-    max_num_blocks_per_seq = max(32, triton.cdiv(PAST_KVLEN + 1, block_size))
+    max_num_blocks_per_seq = max(32, triton.cdiv(KV_LEN, block_size))

     num_kv_heads = num_attn_heads // kv_group_num
     assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
@@ -163,11 +163,9 @@ def bench_kernel(
     device = get_current_device()

     if same_context_len:
-        past_kv_lengths = torch.tensor([PAST_KVLEN for _ in range(bsz)], dtype=torch.int32, device=device)
+        kv_lengths = torch.tensor([KV_LEN for _ in range(bsz)], dtype=torch.int32, device=device)
     else:
-        past_kv_lengths = torch.randint(low=1, high=PAST_KVLEN, size=(bsz,), dtype=torch.int32, device=device)
-
-    kv_lengths = past_kv_lengths + 1
+        kv_lengths = torch.randint(low=1, high=KV_LEN, size=(bsz,), dtype=torch.int32, device=device)
     num_tokens = torch.sum(kv_lengths).item()

     q_size = (bsz, q_len, num_attn_heads, head_dim)
@@ -186,12 +184,12 @@ def bench_kernel(
     block_tables = block_tables.to(device=device)

     q = q.view(bsz, q_len, num_attn_heads, head_dim)
+    max_seq_len = kv_lengths.max().item()  # for random lengths

     if provider == "torch":
         # rebuild (batched) kv with padding for torch attention
         # q [bsz, 1, num_heads, head_dim]
         # k/v [num_tokens, num_kv_heads, head_dim]
-        max_seq_len = kv_lengths.max().item()
         k_torch = torch.zeros((bsz, max_seq_len, num_kv_heads, head_dim), dtype=k.dtype, device=k.device)
         v_torch = torch.zeros_like(k_torch)
         prev_len_sum = 0
@@ -205,14 +203,16 @@ def bench_kernel(
         fn = lambda: torch_attn_ref(
             q, k_torch, v_torch, torch_padding_mask, bsz, 1, k_torch.size(1), num_attn_heads, num_kv_heads, head_dim
         )
+        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+        return ms
     elif provider == "triton":
-        max_seq_len = kv_lengths.max().item()
         # the maximum block length splitted on kv should be the kv cache block size
         kv_max_split_num = (max_seq_len + block_size - 1) // block_size
         mid_output = torch.empty(
             size=(bsz, num_attn_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device
         )
         mid_output_lse = torch.empty(size=(bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
+        sm_scale = 1.0 / (head_dim**0.5)
         fn = lambda: flash_decoding_attention(
             q,
             k_cache,
@@ -222,12 +222,13 @@ def bench_kernel(
             max_seq_len,
             mid_output,
             mid_output_lse,
-            block_size,
-            kv_group_num,
+            block_size=block_size,
+            sm_scale=sm_scale,
+            kv_group_num=kv_group_num,
         ).unsqueeze(1)

-    ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
-    return ms
+        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+        return ms


 if __name__ == "__main__":
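
The benchmark added above follows the standard triton.testing pattern: a Benchmark config swept over KV_LEN, a perf_report-decorated function that builds inputs per provider, and a do_bench timing returned from each branch. Below is a self-contained toy sketch of that pattern; the timed workload is a stand-in softmax (not flash_decoding_attention), and the sizes are assumptions for illustration only.

import torch
import triton

configs = [
    triton.testing.Benchmark(
        x_names=["KV_LEN"],
        x_vals=[2**i for i in range(8, 12)],
        line_arg="provider",
        line_vals=["torch", "triton"],
        line_names=["Torch", "Triton"],
        styles=[("red", "-"), ("blue", "-")],
        ylabel="ms",
        plot_name="toy-decoding-benchmark",
        args={},
    )
]

@triton.testing.perf_report(configs)
def bench_toy(KV_LEN, provider):
    # Stand-in workload; the real test builds the KV cache and calls
    # torch_attn_ref / flash_decoding_attention in the respective branch.
    x = torch.randn(16, KV_LEN, device="cuda")
    fn = lambda: torch.softmax(x, dim=-1)
    # Each provider branch times its own fn and returns the measured milliseconds,
    # mirroring the per-branch do_bench calls added in this commit.
    if provider in ("torch", "triton"):
        return triton.testing.do_bench(fn, warmup=10, rep=100)

if __name__ == "__main__":
    bench_toy.run(save_path=".", print_data=True)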
