Commit 5cd298b

(trivial) pytest format
1 parent 4529b87 commit 5cd298b

File tree

1 file changed: +20 -16 lines changed

tests/test_infer_ops/triton/test_decoding_attn.py

Lines changed: 20 additions & 16 deletions
@@ -54,10 +54,11 @@ def test_flash_decoding(
     dtype = torch.float16
     device = get_current_device()
 
-    if same_context_len:
-        context_lengths = torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
-    else:
-        context_lengths = torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device)
+    context_lengths = (
+        torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
+        if same_context_len
+        else torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device)
+    )
     num_tokens = torch.sum(context_lengths).item()
 
     q_size = (bsz, q_len, num_attn_heads, head_dim)
@@ -127,7 +128,8 @@ def test_flash_decoding(
 configs = [
     triton.testing.Benchmark(
         x_names=["KV_LEN"],
-        x_vals=[2**i for i in range(8, 12)],
+        x_vals=[2**i for i in range(8, 14)],
+        # x_vals=[x for x in range(256, 8192, 256)],
         line_arg="provider",
         line_vals=["torch", "triton"],
         line_names=["Torch", "Triton"],
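
Note: these `Benchmark` configs are consumed by a `triton.testing.perf_report`-decorated bench function, which is invoked once per `x_vals` entry and provider; the widened `range(8, 14)` sweeps KV lengths from 256 up to 8192. A minimal, self-contained sketch of that harness follows; `demo_bench` and its toy workload are hypothetical stand-ins for this file's `bench_kernel`, not the repo's actual code:

import torch
import triton

configs = [
    triton.testing.Benchmark(
        x_names=["KV_LEN"],
        x_vals=[2**i for i in range(8, 14)],  # 256, 512, ..., 8192
        line_arg="provider",
        line_vals=["torch", "triton"],
        line_names=["Torch", "Triton"],
        ylabel="ms",
        plot_name="demo-decoding-attn",  # hypothetical plot name
        args={},
    )
]

@triton.testing.perf_report(configs)
def demo_bench(KV_LEN, provider):
    # Hypothetical toy workload; bench_kernel times real attention here
    # and dispatches on `provider` ("torch" vs "triton").
    x = torch.randn(KV_LEN, KV_LEN, device="cuda", dtype=torch.float16)
    fn = lambda: x @ x
    # Returning (median, min, max) lets perf_report draw error bars.
    return triton.testing.do_bench(fn, quantiles=[0.5, 0.2, 0.8])

# demo_bench.run(save_path=".", print_data=True)
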
@@ -162,10 +164,11 @@ def bench_kernel(
     dtype = torch.float16
     device = get_current_device()
 
-    if same_context_len:
-        kv_lengths = torch.tensor([KV_LEN for _ in range(bsz)], dtype=torch.int32, device=device)
-    else:
-        kv_lengths = torch.randint(low=1, high=KV_LEN, size=(bsz,), dtype=torch.int32, device=device)
+    kv_lengths = (
+        torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
+        if same_context_len
+        else torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device)
+    )
     num_tokens = torch.sum(kv_lengths).item()
 
     q_size = (bsz, q_len, num_attn_heads, head_dim)
@@ -186,6 +189,7 @@ def bench_kernel(
     q = q.view(bsz, q_len, num_attn_heads, head_dim)
     max_seq_len = kv_lengths.max().item()  # for random lengths
 
+    quantiles = [0.5, 0.2, 0.8]
     if provider == "torch":
         # rebuild (batched) kv with padding for torch attention
         # q [bsz, 1, num_heads, head_dim]
@@ -203,9 +207,8 @@ def bench_kernel(
         fn = lambda: torch_attn_ref(
             q, k_torch, v_torch, torch_padding_mask, bsz, 1, k_torch.size(1), num_attn_heads, num_kv_heads, head_dim
         )
-        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
-        return ms
-    elif provider == "triton":
+        ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=quantiles)
+    if provider == "triton":
         # the maximum block length splitted on kv should be the kv cache block size
         kv_max_split_num = (max_seq_len + block_size - 1) // block_size
         mid_output = torch.empty(
@@ -227,10 +230,11 @@ def bench_kernel(
             kv_group_num=kv_group_num,
         ).unsqueeze(1)
 
-        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
-        return ms
+        ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=quantiles)
+
+    return ms, min_ms, max_ms
 
 
 if __name__ == "__main__":
-    # test_flash_decoding(16, 32, 32, 16, 1, True)
-    bench_kernel.run(save_path=".", print_data=True)
+    test_flash_decoding(16, 32, 32, 16, 1, True)
+    # bench_kernel.run(save_path=".", print_data=True)
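
Note on the `do_bench` change: with `quantiles=[0.5, 0.2, 0.8]`, `triton.testing.do_bench` returns one latency per quantile (median, then 20th and 80th percentiles) instead of a single scalar, so both provider branches now assign `ms, min_ms, max_ms` and share a single tail return; switching `elif` to a plain `if` is safe because each call receives exactly one provider. A standalone sketch of the idiom, with a toy matmul standing in for the attention kernels (the `warmup`/`rep` keywords match the Triton 2.x `do_bench` signature and have shifted in later versions):

import torch
import triton

def bench_toy_matmul():
    # Toy workload standing in for the decoding-attention kernels.
    a = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
    b = torch.randn(1024, 1024, device="cuda", dtype=torch.float16)
    fn = lambda: a @ b

    quantiles = [0.5, 0.2, 0.8]
    # With `quantiles` set, do_bench returns the latency (ms) at each
    # requested quantile: median first, then the 20th and 80th percentiles,
    # which perf_report renders as error bars around the median line.
    ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=25, rep=100, quantiles=quantiles)
    return ms, min_ms, max_ms
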
