@@ -54,10 +54,11 @@ def test_flash_decoding(
     dtype = torch.float16
     device = get_current_device()
 
-    if same_context_len:
-        context_lengths = torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
-    else:
-        context_lengths = torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device)
+    context_lengths = (
+        torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
+        if same_context_len
+        else torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device)
+    )
     num_tokens = torch.sum(context_lengths).item()
 
     q_size = (bsz, q_len, num_attn_heads, head_dim)
@@ -127,7 +128,8 @@ def test_flash_decoding(
 configs = [
     triton.testing.Benchmark(
         x_names=["KV_LEN"],
-        x_vals=[2**i for i in range(8, 12)],
+        x_vals=[2**i for i in range(8, 14)],
+        # x_vals=[x for x in range(256, 8192, 256)],
         line_arg="provider",
         line_vals=["torch", "triton"],
         line_names=["Torch", "Triton"],
@@ -162,10 +164,11 @@ def bench_kernel(
     dtype = torch.float16
     device = get_current_device()
 
-    if same_context_len:
-        kv_lengths = torch.tensor([KV_LEN for _ in range(bsz)], dtype=torch.int32, device=device)
-    else:
-        kv_lengths = torch.randint(low=1, high=KV_LEN, size=(bsz,), dtype=torch.int32, device=device)
+    kv_lengths = (
+        torch.tensor([KV_LEN for _ in range(bsz)], dtype=torch.int32, device=device)
+        if same_context_len
+        else torch.randint(low=1, high=KV_LEN, size=(bsz,), dtype=torch.int32, device=device)
+    )
     num_tokens = torch.sum(kv_lengths).item()
 
     q_size = (bsz, q_len, num_attn_heads, head_dim)
@@ -186,6 +189,7 @@ def bench_kernel(
     q = q.view(bsz, q_len, num_attn_heads, head_dim)
     max_seq_len = kv_lengths.max().item()  # for random lengths
 
+    quantiles = [0.5, 0.2, 0.8]
     if provider == "torch":
         # rebuild (batched) kv with padding for torch attention
         # q [bsz, 1, num_heads, head_dim]
@@ -203,9 +207,8 @@ def bench_kernel(
         fn = lambda: torch_attn_ref(
             q, k_torch, v_torch, torch_padding_mask, bsz, 1, k_torch.size(1), num_attn_heads, num_kv_heads, head_dim
         )
-        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
-        return ms
-    elif provider == "triton":
+        ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=quantiles)
+    if provider == "triton":
         # the maximum block length splitted on kv should be the kv cache block size
         kv_max_split_num = (max_seq_len + block_size - 1) // block_size
         mid_output = torch.empty(
@@ -227,10 +230,11 @@ def bench_kernel(
             kv_group_num=kv_group_num,
         ).unsqueeze(1)
 
-        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
-        return ms
+        ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=quantiles)
+
+    return ms, min_ms, max_ms
 
 
 if __name__ == "__main__":
-    # test_flash_decoding(16, 32, 32, 16, 1, True)
-    bench_kernel.run(save_path=".", print_data=True)
+    test_flash_decoding(16, 32, 32, 16, 1, True)
+    # bench_kernel.run(save_path=".", print_data=True)