
TRITON_CUDA_SUPPORT = version.parse(torch.version.cuda) > version.parse("11.4")

+Q_LEN = 1
+HEAD_DIM = 128
+

def prepare_padding_mask(kv_lengths: torch.Tensor, bsz: int, max_seq_len: int, device="cuda"):
    padding_mask = torch.zeros((bsz, 1, 1, max_seq_len), dtype=torch.float32, device=device)
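(The rest of prepare_padding_mask falls outside this hunk. For orientation, a left-padding additive mask of this shape is typically completed along the lines of the sketch below; the loop and the -inf fill are an assumption about the helper's intent, not the file's actual code.)

import torch

def left_padding_mask_sketch(kv_lengths: torch.Tensor, bsz: int, max_seq_len: int, device="cuda"):
    # Hedged sketch: 0.0 on the right-aligned valid positions, -inf on the left-side
    # padding, so the mask can simply be added to the attention scores before softmax.
    mask = torch.zeros((bsz, 1, 1, max_seq_len), dtype=torch.float32, device=device)
    for i, kv_len in enumerate(kv_lengths.tolist()):
        mask[i, :, :, : max_seq_len - kv_len] = float("-inf")
    return mask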
@@ -48,74 +51,72 @@ def test_flash_decoding(

    num_kv_heads = num_attn_heads // kv_group_num
    assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
-    q_len = 1
-    head_dim = 128
    max_seq_len = block_size * max_num_blocks_per_seq
    dtype = torch.float16
    device = get_current_device()

+    # Use the provided maximum sequence length for every sequence when testing with the same context length;
+    # otherwise generate random context lengths.
    context_lengths = (
        torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
        if same_context_len
        else torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device)
    )
    num_tokens = torch.sum(context_lengths).item()

-    q_size = (bsz, q_len, num_attn_heads, head_dim)
-    q = torch.empty(size=q_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
-    q = q.view(bsz, q_len, num_attn_heads, head_dim)
+    q_size = (bsz, Q_LEN, num_attn_heads, HEAD_DIM)
+    q = torch.empty(size=q_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5).transpose(1, 2)
+    kv_unpad_size = (num_tokens, 2 * num_kv_heads, HEAD_DIM)
+    kv_unpad = torch.empty(size=kv_unpad_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
+    k_unpad, v_unpad = torch.split(kv_unpad, [num_kv_heads, num_kv_heads], dim=-2)

-    kv_size = (num_tokens, 2 * num_kv_heads, head_dim)
-    kv = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
-    k, v = torch.split(kv, [num_kv_heads, num_kv_heads], dim=-2)
-
-    cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim, block_size)
+    cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, HEAD_DIM, block_size)
    k_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
    v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
    # Mock allocation on block tables as well as blocked kv caches
    block_tables = mock_alloc_block_table_and_kvcache(
-        k, v, k_cache, v_cache, context_lengths, bsz, max_num_blocks_per_seq, block_size
+        k_unpad, v_unpad, k_cache, v_cache, context_lengths, bsz, max_num_blocks_per_seq, block_size
    )
    block_tables = block_tables.to(device=device)
-
-    max_seq_len = context_lengths.max().item()
-    # the maximum block length splitted on kv should be the kv cache block size
-    kv_max_split_num = (max_seq_len + block_size - 1) // block_size
+    # The maximum sequence length in the batch (when context lengths are randomly generated)
+    max_seq_len_in_b = context_lengths.max().item()
+    # The maximum length of each kv split should be the kv cache block size
+    kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
    mid_output = torch.empty(
-        size=(bsz, num_attn_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device
+        size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
    )
    mid_output_lse = torch.empty(size=(bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
-    sm_scale = 1.0 / (head_dim**0.5)
+    sm_scale = 1.0 / (HEAD_DIM**0.5)
    out_triton = flash_decoding_attention(
        q,
        k_cache,
        v_cache,
        context_lengths,
        block_tables,
-        max_seq_len,
+        max_seq_len_in_b,
        mid_output,
        mid_output_lse,
        block_size=block_size,
        sm_scale=sm_scale,
        kv_group_num=kv_group_num,
-    )
-    out_triton = out_triton.unsqueeze(1)  # [bsz, 1, num_heads, head_dim]
+    )  # [bsz, 1, num_heads, head_dim]

    # rebuild (batched) kv with padding for torch attention
-    # q   [bsz, 1, num_heads, head_dim]
-    # k/v [num_tokens, num_kv_heads, head_dim]
-    k_torch = torch.zeros((bsz, max_seq_len, num_kv_heads, head_dim), dtype=k.dtype, device=k.device)
+    # q   [bsz, num_heads, q_len, head_dim]
+    # k/v [bsz, max_seq_len_in_b, num_kv_heads, head_dim]
+    k_torch = torch.zeros((bsz, max_seq_len_in_b, num_kv_heads, HEAD_DIM), dtype=k_unpad.dtype, device=k_unpad.device)
    v_torch = torch.zeros_like(k_torch)
    prev_len_sum = 0
    for i, seq_len in enumerate(context_lengths.tolist()):
-        # mock left-side padding
-        k_torch[i, -seq_len:, :, :] = k[prev_len_sum : prev_len_sum + seq_len]
-        v_torch[i, -seq_len:, :, :] = v[prev_len_sum : prev_len_sum + seq_len]
+        # left-side padding
+        k_torch[i, -seq_len:, :, :] = k_unpad[prev_len_sum : prev_len_sum + seq_len]
+        v_torch[i, -seq_len:, :, :] = v_unpad[prev_len_sum : prev_len_sum + seq_len]
        prev_len_sum += seq_len
-    # k/v [bsz, max_seq_len, num_kv_heads, head_dim]
-    torch_padding_mask = prepare_padding_mask(context_lengths, bsz, k_torch.size(1), q.device)
+    torch_padding_mask = prepare_padding_mask(context_lengths, bsz, max_seq_len_in_b, q.device)
+    k_torch = k_torch.transpose(1, 2)
+    v_torch = v_torch.transpose(1, 2)
    out_torch = torch_attn_ref(
-        q, k_torch, v_torch, torch_padding_mask, bsz, 1, k_torch.size(1), num_attn_heads, num_kv_heads, head_dim
+        q, k_torch, v_torch, torch_padding_mask, bsz, 1, max_seq_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM
    )

    assert out_torch.shape == out_triton.shape
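(Context for the blocked cache used above: k_cache / v_cache have shape (num_blocks, num_kv_heads, head_dim, block_size) and block_tables maps each sequence's logical block index to a physical block. Assuming the mock allocator fills blocks contiguously per sequence, reading one cached key vector back works roughly like this hypothetical helper; it is illustrative only and not part of the change.)

def read_cached_k(k_cache, block_tables, seq_idx, token_idx, block_size):
    # k_cache:      [num_blocks, num_kv_heads, head_dim, block_size]
    # block_tables: [bsz, max_num_blocks_per_seq]
    block_id = block_tables[seq_idx, token_idx // block_size]  # physical block holding the token
    offset = token_idx % block_size                            # slot within that block
    return k_cache[block_id, :, :, offset]                     # [num_kv_heads, head_dim]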
@@ -128,7 +129,7 @@ def test_flash_decoding(
configs = [
    triton.testing.Benchmark(
        x_names=["KV_LEN"],
-        x_vals=[2**i for i in range(8, 14)],
+        x_vals=[2**i for i in range(8, 12)],
        # x_vals=[x for x in range(256, 8192, 256)],
        line_arg="provider",
        line_vals=["torch", "triton"],
@@ -154,30 +155,28 @@ def bench_kernel(
    rep = 100

    num_attn_heads = 16
-    max_num_blocks_per_seq = max(32, triton.cdiv(KV_LEN, block_size))
+    max_num_blocks_per_seq = triton.cdiv(KV_LEN, block_size)

    num_kv_heads = num_attn_heads // kv_group_num
    assert isinstance(num_kv_heads, int) and num_kv_heads > 0, "Invalid number of kv heads."
-    q_len = 1
-    head_dim = 128
-    max_seq_len = block_size * max_num_blocks_per_seq
+    block_size * max_num_blocks_per_seq
    dtype = torch.float16
    device = get_current_device()

    kv_lengths = (
-        torch.tensor([max_seq_len for _ in range(bsz)], dtype=torch.int32, device=device)
+        torch.tensor([KV_LEN for _ in range(bsz)], dtype=torch.int32, device=device)
        if same_context_len
-        else torch.randint(low=1, high=max_seq_len, size=(bsz,), dtype=torch.int32, device=device)
+        else torch.randint(low=1, high=KV_LEN, size=(bsz,), dtype=torch.int32, device=device)
    )
    num_tokens = torch.sum(kv_lengths).item()

-    q_size = (bsz, q_len, num_attn_heads, head_dim)
-    q = torch.empty(size=q_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
-    kv_size = (num_tokens, 2 * num_kv_heads, head_dim)
+    q_size = (bsz, Q_LEN, num_attn_heads, HEAD_DIM)
+    q = torch.empty(size=q_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5).transpose(1, 2)
+    kv_size = (num_tokens, 2 * num_kv_heads, HEAD_DIM)
    kv = torch.empty(size=kv_size, dtype=dtype, device=device).normal_(mean=0.0, std=0.5)
    k, v = torch.split(kv, [num_kv_heads, num_kv_heads], dim=-2)

-    cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, head_dim, block_size)
+    cache_shape = (bsz * max_num_blocks_per_seq, num_kv_heads, HEAD_DIM, block_size)
    k_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
    v_cache = torch.zeros(size=cache_shape, dtype=dtype, device=device)
    # Mock allocation on block tables as well as blocked kv caches
@@ -186,55 +185,54 @@ def bench_kernel(
    )
    block_tables = block_tables.to(device=device)

-    q = q.view(bsz, q_len, num_attn_heads, head_dim)
-    max_seq_len = kv_lengths.max().item()  # for random lengths
+    max_seq_len_in_b = kv_lengths.max().item()  # for random lengths

    quantiles = [0.5, 0.2, 0.8]
    if provider == "torch":
        # rebuild (batched) kv with padding for torch attention
-        # q   [bsz, 1, num_heads, head_dim]
-        # k/v [num_tokens, num_kv_heads, head_dim]
-        k_torch = torch.zeros((bsz, max_seq_len, num_kv_heads, head_dim), dtype=k.dtype, device=k.device)
+        # q   [bsz, num_heads, q_len, head_dim]
+        # k/v [bsz, max_seq_len_in_b, num_kv_heads, head_dim]
+        k_torch = torch.zeros((bsz, max_seq_len_in_b, num_kv_heads, HEAD_DIM), dtype=k.dtype, device=k.device)
        v_torch = torch.zeros_like(k_torch)
        prev_len_sum = 0
        for i, seq_len in enumerate(kv_lengths.tolist()):
            # mock left-side padding
            k_torch[i, -seq_len:, :, :] = k[prev_len_sum : prev_len_sum + seq_len]
            v_torch[i, -seq_len:, :, :] = v[prev_len_sum : prev_len_sum + seq_len]
            prev_len_sum += seq_len
-        # k/v [bsz, max_seq_len, num_kv_heads, head_dim]
-        torch_padding_mask = prepare_padding_mask(kv_lengths, bsz, k_torch.size(1), q.device)
+        torch_padding_mask = prepare_padding_mask(kv_lengths, bsz, max_seq_len_in_b, q.device)
+        k_torch = k_torch.transpose(1, 2)
+        v_torch = v_torch.transpose(1, 2)
        fn = lambda: torch_attn_ref(
-            q, k_torch, v_torch, torch_padding_mask, bsz, 1, k_torch.size(1), num_attn_heads, num_kv_heads, head_dim
+            q, k_torch, v_torch, torch_padding_mask, bsz, 1, max_seq_len_in_b, num_attn_heads, num_kv_heads, HEAD_DIM
        )
        ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=quantiles)
    if provider == "triton":
        # the maximum length of each kv split should be the kv cache block size
-        kv_max_split_num = (max_seq_len + block_size - 1) // block_size
+        kv_max_split_num = (max_seq_len_in_b + block_size - 1) // block_size
        mid_output = torch.empty(
-            size=(bsz, num_attn_heads, kv_max_split_num, head_dim), dtype=torch.float32, device=q.device
+            size=(bsz, num_attn_heads, kv_max_split_num, HEAD_DIM), dtype=torch.float32, device=q.device
        )
        mid_output_lse = torch.empty(size=(bsz, num_attn_heads, kv_max_split_num), dtype=torch.float32, device=q.device)
-        sm_scale = 1.0 / (head_dim**0.5)
+        sm_scale = 1.0 / (HEAD_DIM**0.5)
        fn = lambda: flash_decoding_attention(
            q,
            k_cache,
            v_cache,
            kv_lengths,
            block_tables,
-            max_seq_len,
+            max_seq_len_in_b,
            mid_output,
            mid_output_lse,
            block_size=block_size,
            sm_scale=sm_scale,
            kv_group_num=kv_group_num,
-        ).unsqueeze(1)
-
+        )  # [bsz, 1, num_heads, head_dim]
        ms, min_ms, max_ms = triton.testing.do_bench(fn, warmup=warmup, rep=rep, quantiles=quantiles)

    return ms, min_ms, max_ms
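(Background on the mid_output / mid_output_lse buffers allocated in both functions: flash-decoding style kernels walk the kv cache in splits of one block each and then merge the per-split partial results with a log-sum-exp rescale. A minimal sketch of that merge is shown below; it assumes each split stores its locally normalized output and the log-sum-exp of its scores, which is the standard split-kv reduction rather than this kernel's actual code.)

import torch

def merge_kv_splits(mid_output: torch.Tensor, mid_output_lse: torch.Tensor) -> torch.Tensor:
    # mid_output:     [bsz, num_heads, num_splits, head_dim], locally normalized partial outputs
    # mid_output_lse: [bsz, num_heads, num_splits], log-sum-exp of each split's attention scores
    max_lse = mid_output_lse.max(dim=-1, keepdim=True).values  # [bsz, H, 1], for numerical stability
    weights = torch.exp(mid_output_lse - max_lse)              # [bsz, H, S], relative split weights
    merged = (weights.unsqueeze(-1) * mid_output).sum(dim=-2)  # [bsz, H, D]
    return merged / weights.sum(dim=-1, keepdim=True)          # renormalize across splits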


if __name__ == "__main__":
-    test_flash_decoding(16, 32, 32, 16, 1, True)
-    # bench_kernel.run(save_path=".", print_data=True)
+    # test_flash_decoding(16, 32, 32, 16, 1, True)
+    bench_kernel.run(save_path=".", print_data=True)