@@ -34,9 +34,7 @@ def laneid_kernel(A: T.Tensor((num_threads,), "int32")):
3434
3535
3636@tilelang .jit (out_idx = [- 1 ]) 
37- def  _get_warp_idx_sync_kernel (
38-     num_threads : int  =  128 , warp_size : Optional [int ] =  None 
39- ):
37+ def  _get_warp_idx_sync_kernel (num_threads : int  =  128 , warp_size : Optional [int ] =  None ):
4038
4139    @T .prim_func  
4240    def  warp_idx_sync_kernel (A : T .Tensor ((num_threads ,), "int32" )):
@@ -76,9 +74,7 @@ def warp_group_idx_kernel(A: T.Tensor((num_threads,), "int32")):
7674
7775
7876@tilelang .jit (out_idx = [- 1 ]) 
79- def  _shuffle_elect_kernel (
80-     num_threads : int  =  128 , thread_extent : int  =  64 
81- ):
77+ def  _shuffle_elect_kernel (num_threads : int  =  128 , thread_extent : int  =  64 ):
8278
8379    @T .prim_func  
8480    def  shuffle_elect_kernel (A : T .Tensor ((num_threads ,), "int32" )):
@@ -96,24 +92,18 @@ def run_get_lane_id(num_threads: int = 128, warp_size: Optional[int] = None):
9692    print (kernel .get_kernel_source ())
9793    print (A )
9894    expected_warp_size  =  _resolve_warp_size (warp_size )
99-     ref  =  torch .arange (
100-         num_threads , dtype = A .dtype , device = A .device 
101-     ) %  expected_warp_size 
95+     ref  =  torch .arange (num_threads , dtype = A .dtype , device = A .device ) %  expected_warp_size 
10296    torch .testing .assert_close (A .cpu (), ref .cpu ())
10397    return  A 
10498
10599
106- def  run_get_warp_idx_sync (
107-     num_threads : int  =  128 , warp_size : Optional [int ] =  None 
108- ):
100+ def  run_get_warp_idx_sync (num_threads : int  =  128 , warp_size : Optional [int ] =  None ):
109101    kernel  =  _get_warp_idx_sync_kernel (num_threads , warp_size )
110102    A  =  kernel ()
111103    print (kernel .get_kernel_source ())
112104    print (A )
113105    expected_warp_size  =  _resolve_warp_size (warp_size )
114-     ref  =  torch .arange (
115-         num_threads , dtype = A .dtype , device = A .device 
116-     ) //  expected_warp_size 
106+     ref  =  torch .arange (num_threads , dtype = A .dtype , device = A .device ) //  expected_warp_size 
117107    torch .testing .assert_close (A .cpu (), ref .cpu ())
118108    return  A 
119109
@@ -124,9 +114,7 @@ def run_get_warp_idx(num_threads: int = 128, warp_size: Optional[int] = None):
124114    print (kernel .get_kernel_source ())
125115    print (A )
126116    expected_warp_size  =  _resolve_warp_size (warp_size )
127-     ref  =  torch .arange (
128-         num_threads , dtype = A .dtype , device = A .device 
129-     ) //  expected_warp_size 
117+     ref  =  torch .arange (num_threads , dtype = A .dtype , device = A .device ) //  expected_warp_size 
130118    torch .testing .assert_close (A .cpu (), ref .cpu ())
131119    return  A 
132120
@@ -145,25 +133,19 @@ def run_get_warp_group_idx(
145133    threads_per_group  =  expected_warp_size  *  expected_warps_per_group 
146134    if  threads_per_group  <=  0 :
147135        raise  ValueError ("threads_per_group must be positive." )
148-     ref  =  torch .arange (
149-         num_threads , dtype = A .dtype , device = A .device 
150-     ) //  threads_per_group 
136+     ref  =  torch .arange (num_threads , dtype = A .dtype , device = A .device ) //  threads_per_group 
151137    torch .testing .assert_close (A .cpu (), ref .cpu ())
152138    return  A 
153139
154140
155- def  run_shuffle_elect (
156-     num_threads : int  =  128 , thread_extent : int  =  64 
157- ):
141+ def  run_shuffle_elect (num_threads : int  =  128 , thread_extent : int  =  64 ):
158142    if  thread_extent  <  0 :
159143        raise  ValueError ("thread_extent must be non-negative." )
160144    kernel  =  _shuffle_elect_kernel (num_threads , thread_extent )
161145    A  =  kernel ()
162146    print (kernel .get_kernel_source ())
163147    print (A )
164-     indices  =  torch .arange (
165-         num_threads , device = A .device , dtype = torch .int64 
166-     )
148+     indices  =  torch .arange (num_threads , device = A .device , dtype = torch .int64 )
167149    if  thread_extent  ==  0 :
168150        mask  =  indices  ==  0 
169151    elif  thread_extent  >  0 :
@@ -224,6 +206,7 @@ def test_shuffle_elect_default():
224206def  test_shuffle_elect_block_leader ():
225207    run_shuffle_elect (num_threads = 128 , thread_extent = 0 )
226208
209+ 
227210if  __name__  ==  "__main__" :
228211    tilelang .testing .main ()
229212    # run_get_lane_id() 
0 commit comments