def get_configs():
    """
    Generate a list of hyperparameter configuration dictionaries for tuning.

    Each configuration is a dict with keys: 'block_M', 'block_N', 'block_K',
    'num_stages', 'threads', and 'split'. The function returns the Cartesian
    product of the parameter value lists:
    - block_M, block_N, block_K: tiling sizes
    - num_stages: pipeline stages
    - threads: thread counts
    - split: K-splitting factor

    Returns:
        List[dict]: A list of configuration dictionaries covering all combinations.
    """
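The value lists themselves live in the function body (not part of this hunk). A minimal sketch of the pattern the docstring describes, using placeholder candidate values and itertools.product (the real script defines its own tuned lists), might look like:

import itertools

def get_configs():
    # Placeholder candidate lists; the actual example uses its own tuned values.
    params = {
        "block_M": [64, 128],
        "block_N": [64, 128],
        "block_K": [64, 128],
        "num_stages": [1, 2],
        "threads": [128, 256],
        "split": [1],
    }
    # Cartesian product of all value lists, one dict per combination.
    return [dict(zip(params, combo)) for combo in itertools.product(*params.values())]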
@@ -309,17 +309,20 @@ def main(
                C_local[i, j] = Bias_shared[j]

            tx = T.get_thread_binding()

            for k in T.Pipelined(K // block_K, num_stages=num_stages):
                for copy_i in T.serial(block_M * block_K // threads // 16):
                    base = copy_i * threads * 16 + tx * 16
                    if sorted_token_ids_shared[base // block_K] != -1:
                        for copy_j in T.vectorized(16):
-                            A_shared[base // block_K, base % block_K + copy_j] = A[sorted_token_ids_shared[base // block_K] // topk, k * block_K + base % block_K + copy_j]
+                            A_shared[base // block_K, base % block_K +
+                                     copy_j] = A[sorted_token_ids_shared[base // block_K] // topk,
+                                                 k * block_K + base % block_K + copy_j]

                T.copy(B[expert_id[0], bx * block_N, k * block_K // num_elems_per_byte], B_shared)
                if fast_dequant:
-                    get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared, k)
+                    get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared, Scale_shared,
+                                                      k)
                else:
                    get_simple_dequant_func()(B_shared, B_dequantize_shared, Scale_shared, k)

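The guarded copy in this hunk gathers rows of A into shared memory according to sorted_token_ids: entries equal to -1 mark padding and are skipped, and each valid id is mapped back to a source row by integer division by topk. A rough, hypothetical PyTorch equivalent of that per-tile gather (the helper name and the zero-fill of padding rows are assumptions for illustration, not the kernel's actual handling) could be:

import torch

def gather_a_tile(A, sorted_token_ids, topk, block_M, block_K, k):
    # Row i of the tile comes from token sorted_token_ids[i] // topk of A;
    # ids equal to -1 mark padding and are skipped (left as zeros here).
    A_shared = torch.zeros(block_M, block_K, dtype=A.dtype, device=A.device)
    for i in range(block_M):
        if sorted_token_ids[i] != -1:
            A_shared[i] = A[sorted_token_ids[i] // topk, k * block_K:(k + 1) * block_K]
    return A_shared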
@@ -331,7 +334,7 @@ def main(
            T.copy(C_local, C_shared)
            for i, j in T.Parallel(block_M, block_N):
                C[sorted_token_ids_shared[i] // topk, sorted_token_ids_shared[i] % topk,
-                 bx * block_N + j] = C_shared[i, j]
+                   bx * block_N + j] = C_shared[i, j]

    return main

@@ -366,7 +369,8 @@ def ref_moe(A, qB, Scale, Bias, topk_weights, sorted_token_ids, expert_ids, bloc

        # Compute the output for this token-expert pair
        # token_embedding @ B.T + bias
-        output = torch.matmul(token_embedding.to(torch.bfloat16), B.T.to(torch.bfloat16)) + Bias[expert_id]
+        output = torch.matmul(token_embedding.to(torch.bfloat16), B.T.to(
+            torch.bfloat16)) + Bias[expert_id]
        output = output.to(torch.__getattribute__(dtypeC))

        # Apply the topk weight
@@ -491,7 +495,9 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, with_bias=False,
    max_val = diff.max()
    max_idx = diff.argmax()
    print(f"max abs diff: {max_val} {max_idx}")
-    assert_similar(output, ref_output, name="output", eps=1e-5)  # We care about the similarity rather than abs. difference
+    assert_similar(
+        output, ref_output, name="output",
+        eps=1e-5)  # We care about the similarity rather than abs. difference
    print("All checks pass. ✅")

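The implementation of assert_similar is not shown in this diff. As a sketch of the idea behind a similarity-based check (an assumed helper for illustration, not the repository's actual function), it could compare a normalized metric such as cosine similarity against eps rather than element-wise absolute error:

import torch
import torch.nn.functional as F

def assert_similar_sketch(a, b, name="tensor", eps=1e-5):
    # Hypothetical stand-in: pass if the flattened tensors point in almost the
    # same direction, regardless of small element-wise absolute differences.
    sim = F.cosine_similarity(a.float().flatten(), b.float().flatten(), dim=0).item()
    assert 1.0 - sim < eps, f"{name}: cosine similarity {sim:.6f} not within {eps}"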