@@ -62,6 +62,8 @@ def test_flash_decoding(
 
     q_size = (bsz, q_len, num_attn_heads, head_dim)
     q = torch.empty(size=q_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
+    q = q.view(bsz, q_len, num_attn_heads, head_dim)
+
     kv_size = (num_tokens, 2 * num_kv_heads, head_dim)
     kv = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
     k, v = torch.split(kv, [num_kv_heads, num_kv_heads], dim=-2)
@@ -75,16 +77,14 @@ def test_flash_decoding(
     )
     block_tables = block_tables.to(device=device)
 
-    q = q.view(bsz, q_len, num_attn_heads, head_dim)
-
     max_seq_len = context_lengths.max().item()
     # the maximum block length splitted on kv should be the kv cache block size
     kv_max_split_num = (max_seq_len + block_size - 1) // block_size
     mid_output = torch.empty(
         size=(bsz, num_attn_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device
     )
     mid_output_lse = torch.empty(size=(bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
-
+    sm_scale = 1.0 / (head_dim**0.5)
     out_triton = flash_decoding_attention(
         q,
         k_cache,
@@ -94,15 +94,15 @@ def test_flash_decoding(
         max_seq_len,
         mid_output,
         mid_output_lse,
-        block_size,
-        kv_group_num,
+        block_size=block_size,
+        sm_scale=sm_scale,
+        kv_group_num=kv_group_num,
     )
     out_triton = out_triton.unsqueeze(1)  # [bsz, 1, num_heads, head_dim]
 
     # rebuild (batched) kv with padding for torch attention
     # q [bsz, 1, num_heads, head_dim]
     # k/v [num_tokens, num_kv_heads, head_dim]
-    max_seq_len = context_lengths.max().item()
     k_torch = torch.zeros((bsz, max_seq_len, num_kv_heads, head_dim), dtype=k.dtype, device=k.device)
     v_torch = torch.zeros_like(k_torch)
     prev_len_sum = 0
@@ -126,11 +126,11 @@ def test_flash_decoding(
 SAME_LEN = True
 configs = [
     triton.testing.Benchmark(
-        x_names=["PAST_KVLEN"],
-        x_vals=[2**i - 1 for i in range(8, 16)],
+        x_names=["KV_LEN"],
+        x_vals=[2**i for i in range(8, 12)],
         line_arg="provider",
         line_vals=["torch", "triton"],
-        line_names=["torch", "triton"],
+        line_names=["Torch", "Triton"],
         styles=[("red", "-"), ("blue", "-")],
         ylabel="ms",
         plot_name=f"decoding-block_size-{BLOCK_SIZE}-batch{BATCH}",
@@ -142,7 +142,7 @@ def test_flash_decoding(
 @triton.testing.perf_report(configs)
 def bench_kernel(
     bsz,
-    PAST_KVLEN,
+    KV_LEN,
     provider,
     block_size: int,
     kv_group_num: int,
@@ -152,7 +152,7 @@ def bench_kernel(
     rep = 100
 
     num_attn_heads = 16
-    max_num_blocks_per_seq = max(32, triton.cdiv(PAST_KVLEN + 1, block_size))
+    max_num_blocks_per_seq = max(32, triton.cdiv(KV_LEN, block_size))
 
     num_kv_heads = num_attn_heads // kv_group_num
     assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
@@ -163,11 +163,9 @@ def bench_kernel(
     device = get_current_device()
 
     if same_context_len:
-        past_kv_lengths = torch.tensor([PAST_KVLEN for _ in range(bsz)], dtype=torch.int32, device=device)
+        kv_lengths = torch.tensor([KV_LEN for _ in range(bsz)], dtype=torch.int32, device=device)
     else:
-        past_kv_lengths = torch.randint(low=1, high=PAST_KVLEN, size=(bsz,), dtype=torch.int32, device=device)
-
-    kv_lengths = past_kv_lengths + 1
+        kv_lengths = torch.randint(low=1, high=KV_LEN, size=(bsz,), dtype=torch.int32, device=device)
     num_tokens = torch.sum(kv_lengths).item()
 
     q_size = (bsz, q_len, num_attn_heads, head_dim)
@@ -186,12 +184,12 @@ def bench_kernel(
     block_tables = block_tables.to(device=device)
 
     q = q.view(bsz, q_len, num_attn_heads, head_dim)
+    max_seq_len = kv_lengths.max().item()  # for random lengths
 
     if provider == "torch":
         # rebuild (batched) kv with padding for torch attention
         # q [bsz, 1, num_heads, head_dim]
         # k/v [num_tokens, num_kv_heads, head_dim]
-        max_seq_len = kv_lengths.max().item()
         k_torch = torch.zeros((bsz, max_seq_len, num_kv_heads, head_dim), dtype=k.dtype, device=k.device)
         v_torch = torch.zeros_like(k_torch)
         prev_len_sum = 0
@@ -205,14 +203,16 @@ def bench_kernel(
         fn = lambda: torch_attn_ref(
             q, k_torch, v_torch, torch_padding_mask, bsz, 1, k_torch.size(1), num_attn_heads, num_kv_heads, head_dim
         )
+        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+        return ms
     elif provider == "triton":
-        max_seq_len = kv_lengths.max().item()
         # the maximum block length splitted on kv should be the kv cache block size
         kv_max_split_num = (max_seq_len + block_size - 1) // block_size
         mid_output = torch.empty(
             size=(bsz, num_attn_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device
         )
         mid_output_lse = torch.empty(size=(bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
+        sm_scale = 1.0 / (head_dim**0.5)
         fn = lambda: flash_decoding_attention(
             q,
             k_cache,
@@ -222,12 +222,13 @@ def bench_kernel(
             max_seq_len,
             mid_output,
             mid_output_lse,
-            block_size,
-            kv_group_num,
+            block_size=block_size,
+            sm_scale=sm_scale,
+            kv_group_num=kv_group_num,
         ).unsqueeze(1)
 
-    ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
-    return ms
+        ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep)
+        return ms
 
 
 if __name__ == "__main__":