 import torch
 import torch.nn.functional as F
 import tilelang
-from tilelang.autotuner import *
 import tilelang.language as T
 from einops import rearrange, einsum
 import argparse
 
 tilelang.disable_cache()
 
 
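+# Autotuning search space: every combination of the tile sizes (block_N,
+# block_H), KV-split factors, and thread counts below gets benchmarked.
+# num_split > 1 partitions the KV context into chunks whose partial outputs
+# (see glse / Output_partial in ref_program) are combined afterwards.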
+def get_configs():
+    import itertools
+    BLOCK_N = [16, 32, 64, 128]
+    BLOCK_H = [16, 32, 64, 128]
+    num_split = [1, 2, 4, 8, 16, 32]
+    threads = [128, 256]
+
+    _configs = list(itertools.product(BLOCK_N, BLOCK_H, num_split, threads))
+
+    return [{
+        "block_N": c[0],
+        "block_H": c[1],
+        "num_split": c[2],
+        "threads": c[3],
+    } for c in _configs]
+
+
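+# Stacking @tilelang.autotune on top of @tilelang.jit: when the tunable
+# parameters (block_N, block_H, num_split, threads) are omitted at call time,
+# the configs above are swept and the fastest variant is compiled.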
+@tilelang.autotune(configs=get_configs())
 @tilelang.jit(
     out_idx=[6], pass_configs={
         tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True,
@@ -273,26 +290,39 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument('--batch', type=int, default=1, help='batch size')
+    parser.add_argument('--batch', type=int, default=128, help='batch size')
     parser.add_argument('--heads', type=int, default=128, help='q heads number')
     parser.add_argument('--kv_heads', type=int, default=1, help='kv heads number')
-    parser.add_argument('--kv_ctx', type=int, default=1024, help='kv context length')
+    parser.add_argument('--kv_ctx', type=int, default=8192, help='kv context length')
     parser.add_argument('--dim', type=int, default=512, help='head dim')
     parser.add_argument('--pe_dim', type=int, default=64, help='pe head dim')
-    parser.add_argument('--auto_tune', action='store_true', help='auto tune')
+    parser.add_argument('--autotune', action='store_true', help='enable auto-tuning')
     args = parser.parse_args()
     batch, heads, kv_heads, kv_ctx, dim, pe_dim = args.batch, args.heads, args.kv_heads, args.kv_ctx, args.dim, args.pe_dim
-    enable_autotune = args.auto_tune
+    enable_autotune = args.autotune
 
     qk_flops = 2 * batch * heads * kv_ctx * (dim + pe_dim)
     pv_flops = 2 * batch * heads * kv_ctx * dim
     total_flops = qk_flops + pv_flops
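     # FLOPs accounting: Q.K^T touches dim + pe_dim channels per (head, kv) pair,
     # attention.V touches dim channels; the factor 2 counts multiply and add.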
     BLOCK_N = 32
     BLOCK_H = 64
     num_split = 4
+    threads = 128
 
-    kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim, BLOCK_N, BLOCK_H,
-                             num_split)
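+    # With --autotune, the tunable parameters are omitted so the autotuner
+    # selects them; otherwise the fixed defaults above are compiled directly.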
+    if enable_autotune:
+        kernel = flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim)
+    else:
+        kernel = flashmla_decode(
+            batch,
+            heads,
+            kv_heads,
+            kv_ctx,
+            dim,
+            pe_dim,
+            BLOCK_N,
+            BLOCK_H,
+            num_split,
+            threads=threads)
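     # The profiler supplies random normal inputs (TensorSupplyType.Randn) and
     # is used below to benchmark the compiled kernel.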
     profiler = kernel.get_profiler(tensor_supply_type=tilelang.TensorSupplyType.Randn)
     input_tensors = profiler._get_inputs()
     tilelang_output = kernel(*input_tensors)
@@ -303,35 +333,3 @@ def ref_program(q, q_pe, kv, k_pe, glse, Output_partial):
     latency = profiler.do_bench(warmup=500)
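     # do_bench reports latency in milliseconds; GFLOPs per millisecond equals
     # TFLOPS, hence the 1e-9 factor below.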
     print(f"Latency: {latency} ms")
     print(f"TFlops: {total_flops / latency * 1e-9}")
-
-    # Enable Auto Tuning
-
-
-    def get_configs():
-        import itertools
-        BLOCK_N = [16, 32, 64, 128]
-        BLOCK_H = [16, 32, 64, 128]
-        num_split = [1, 2, 4, 8, 16, 32]
-        thread_num = [128, 256]
-
-        _configs = list(itertools.product(BLOCK_N, BLOCK_H, num_split, thread_num))
-
-        return [{
-            "block_N": c[0],
-            "block_H": c[1],
-            "num_split": c[2],
-            "thread_num": c[3],
-        } for c in _configs]
-
-    def wrapped_kernel(block_N=None, block_H=None, num_split=None, thread_num=None):
-        return flashmla_decode(batch, heads, kv_heads, kv_ctx, dim, pe_dim, block_N, block_H,
-                               num_split, thread_num)
-
-    if enable_autotune:
-        autotuner = AutoTuner.from_kernel(kernel=wrapped_kernel, configs=get_configs())
-        tune_result = autotuner.run(warmup=3, rep=20)
-        best_latency = tune_result.latency
-        best_config = tune_result.config
-        print(f"Best latency: {best_latency}")
-        print(f"Best TFlops: {total_flops / best_latency * 1e-9}")
-        print(f"Best config: {best_config}")