@@ -136,14 +136,14 @@ def main(
                     KV_shared,
                     acc_s,
                     transpose_B=True,
-                    policy=T.GemmWarpPolicy.FullCol,
+                    policy=T.GemmWarpPolicy.FullRow,
                 )
                 T.gemm(
                     Q_tail_shared,
                     K_tail_shared,
                     acc_s,
                     transpose_B=True,
-                    policy=T.GemmWarpPolicy.FullCol,
+                    policy=T.GemmWarpPolicy.FullRow,
                 )
                 T.copy(m_i, m_i_prev)
                 T.reduce_max(acc_s, m_i, dim=1, clear=False)
@@ -158,7 +158,7 @@ def main(
                     acc_o[h_i, d_i] = acc_o[h_i, d_i] * alpha[h_i]
 
                 T.copy(acc_s, S_shared)
-                T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullCol)
+                T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow)
 
             # Rescale
             for h_i, d_i in T.Parallel(H_per_block, D):
@@ -174,7 +174,15 @@ def main(
     return main
 
 
-def sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, return_p_sum: bool = False, d_v=512):
+def sparse_mla_fwd_interface(q,
+                             kv,
+                             indices,
+                             sm_scale=None,
+                             return_p_sum: bool = False,
+                             d_v=512,
+                             block_I=64,
+                             num_stages=2,
+                             threads=256):
     is_casual = True
     assert return_p_sum == False, "This kernel file is for fwd only"
     assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous()
@@ -190,7 +198,17 @@ def sparse_mla_fwd_interface(q, kv, indices, sm_scale=None, return_p_sum: bool =
     _, _, _, topk = indices.shape
     assert indices.shape == (batch, seq_len, kv_group, topk)
 
-    kernel = sparse_mla_fwd(heads, dim, tail_dim, topk, kv_group, sm_scale, is_casual)
+    kernel = sparse_mla_fwd(
+        heads,
+        dim,
+        tail_dim,
+        topk,
+        kv_group,
+        sm_scale,
+        is_casual,
+        block_I=block_I,
+        num_stages=num_stages,
+        threads=threads)
     out, lse = kernel(q, kv, indices)
     return out, lse
 
@@ -241,7 +259,10 @@ def test_sparse_mla_fwd(B=1,
                         DV=512,
                         topk=2048,
                         dtype=torch.bfloat16,
-                        check_correctness=True):
+                        check_correctness=True,
+                        block_I=64,
+                        num_stages=2,
+                        threads=256):
     torch.random.manual_seed(0)
     q = torch.randn((B, S, H, DQK), dtype=dtype, device="cuda").requires_grad_(True)
     kv = torch.randn((B, SKV, HKV, DQK), dtype=dtype, device="cuda").requires_grad_(True)
@@ -253,7 +274,8 @@ def test_sparse_mla_fwd(B=1,
                 i_i = torch.randperm(max(1, t))[:topk]
                 indices[b, t, h, :len(i_i)] = i_i
 
-    tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices)
+    tl_out, tl_lse = sparse_mla_fwd_interface(
+        q, kv, indices, block_I=block_I, num_stages=num_stages, threads=threads)
 
     if check_correctness:
         # otherwise may cause out of memory
@@ -262,7 +284,8 @@ def test_sparse_mla_fwd(B=1,
         print("assert_tensors_similar passed")
 
     def fn():
-        return sparse_mla_fwd_interface(q, kv, indices)
+        return sparse_mla_fwd_interface(
+            q, kv, indices, block_I=block_I, num_stages=num_stages, threads=threads)
 
     from tilelang.profiler import do_bench
 
@@ -287,4 +310,7 @@ def fn():
         DV=512,
         topk=2048,
         dtype=torch.bfloat16,
-        check_correctness=True)
+        check_correctness=True,
+        block_I=64,
+        num_stages=2,
+        threads=256)
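
Usage note: with this change, block_I, num_stages, and threads are threaded from test_sparse_mla_fwd through sparse_mla_fwd_interface into the sparse_mla_fwd kernel factory, so they can be tuned without editing the kernel body. A minimal sketch of overriding them from the test entry point, relying on the file's defaults for all other arguments; the candidate values below are illustrative assumptions, not configurations validated by this commit:

    # Hypothetical tuning sweep over the new knobs (the value combinations are
    # assumptions and may not all be valid for every shape/hardware setup).
    for block_I, num_stages, threads in [(64, 2, 256), (64, 3, 256), (64, 2, 128)]:
        test_sparse_mla_fwd(
            check_correctness=False,  # skip the reference check while sweeping
            block_I=block_I,
            num_stages=num_stages,
            threads=threads)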