2222
2323
@tilelang.jit(pass_configs={"tl.disable_warp_specialized": True, "tl.disable_tma_lower": True})
def set_signal_kernel(local_rank, num_local_ranks, threads):
    """Build a kernel that initializes this rank's intra-node signal buffer.

    The returned prim_func launches a single block and has each of the first
    ``num_local_ranks`` threads write one slot: ``signal_buffer[local_rank]``
    is set to 1 (our own shard is already in place), every other slot to 0.
    Peers later flip their slot to 1 as their shard arrives via the
    copy-engine all-gather, which the GEMM kernel waits on.

    Args:
        local_rank: this process's rank within the node.
        num_local_ranks: number of ranks on the node (length of the buffer).
        threads: threads per block.
            # NOTE(review): assumes threads >= num_local_ranks so every slot
            # gets a writer — not checked here; confirm at the call site.

    Returns:
        A ``T.prim_func`` taking ``signal_buffer: T.Tensor((num_local_ranks), "uint32")``.
    """

    @T.prim_func
    def _set_signal_kernel(signal_buffer: T.Tensor((num_local_ranks), "uint32"),):
        with T.Kernel(1, threads=threads):
            tx = T.get_thread_binding(0)
            if tx < num_local_ranks:
                if tx == local_rank:
                    signal_buffer[tx] = 1
                else:
                    signal_buffer[tx] = 0

    return _set_signal_kernel
8338
8439
8540@tilelang .jit  
8641def  gemm_kernel (M ,
8742                N ,
8843                K ,
89-                 num_rank ,
9044                local_rank ,
45+                 num_local_rank ,
9146                block_M ,
9247                block_N ,
9348                block_K ,
9449                threads ,
50+                 persistent = False ,
9551                dtype = "float16" ,
9652                accum_dtype = "float" ):
9753
98-     M_per_rank  =  T .ceildiv (M , num_rank )
54+     sm_num  =  driver .get_num_sms ()
55+     m_blocks  =  T .ceildiv (M , block_M )
56+     n_blocks  =  T .ceildiv (N  //  num_local_rank , block_N )
57+     waves  =  T .ceildiv (m_blocks  *  n_blocks , sm_num )
58+     M_per_rank  =  T .ceildiv (M , num_local_rank )
9959    GROUP_SIZE_M  =  8 
10060
10161    @T .prim_func  
10262    def  main (
10363            A : T .Tensor ((M , K ), dtype ),
104-             B : T .Tensor ((K , N  //  num_rank ), dtype ),
105-             signal_buffer : T .Tensor ((num_rank ), "uint32" ),
106-             C : T .Tensor ((M , N  //  num_rank ), dtype ),
64+             B : T .Tensor ((K , N  //  num_local_rank ), dtype ),
65+             signal_buffer : T .Tensor ((num_local_rank ), "uint32" ),
66+             C : T .Tensor ((M , N  //  num_local_rank ), dtype ),
10767    ):
10868        with  T .Kernel (
109-                 T .ceildiv (M , block_M ) *  T .ceildiv (N  //  num_rank , block_N ),
69+                 T .ceildiv (M , block_M ) *  T .ceildiv (N  //  num_local_rank , block_N ),
11070                threads = threads ) as  (bid ):
11171            A_shared  =  T .alloc_shared ((block_M , block_K ), dtype )
11272            B_shared  =  T .alloc_shared ((block_K , block_N ), dtype )
11373            C_shared  =  T .alloc_shared ((block_M , block_N ), dtype )
11474            C_local  =  T .alloc_fragment ((block_M , block_N ), accum_dtype )
11575
11676            num_pid_m  =  T .ceildiv (M , block_M )
117-             num_pid_n  =  T .ceildiv (N  //  num_rank , block_N )
77+             num_pid_n  =  T .ceildiv (N  //  num_local_rank , block_N )
11878            num_pid_in_group  =  GROUP_SIZE_M  *  num_pid_n 
11979            group_id  =  bid  //  num_pid_in_group 
12080            first_pid_m  =  group_id  *  GROUP_SIZE_M 
@@ -140,55 +100,94 @@ def main(
140100            T .copy (C_local , C_shared )
141101            T .copy (C_shared , C [pid_m  *  block_M , pid_n  *  block_N ])
142102
143-     return  main 
103+     @T .prim_func  
104+     def  main_persistent (
105+             A : T .Tensor ((M , K ), dtype ),
106+             B : T .Tensor ((K , N  //  num_local_rank ), dtype ),
107+             signal_buffer : T .Tensor ((num_local_rank ), "uint32" ),
108+             C : T .Tensor ((M , N  //  num_local_rank ), dtype ),
109+     ):
110+         with  T .Kernel (sm_num , threads = threads ) as  (bid ):
111+             A_shared  =  T .alloc_shared ((block_M , block_K ), dtype )
112+             B_shared  =  T .alloc_shared ((block_K , block_N ), dtype )
113+             C_shared  =  T .alloc_shared ((block_M , block_N ), dtype )
114+             C_local  =  T .alloc_fragment ((block_M , block_N ), accum_dtype )
115+ 
116+             for  w  in  T .serial (waves ):
117+                 tile_id  =  bid  +  w  *  sm_num 
118+                 num_pid_m  =  T .ceildiv (M , block_M )
119+                 num_pid_n  =  T .ceildiv (N  //  num_local_rank , block_N )
120+                 num_pid_in_group  =  GROUP_SIZE_M  *  num_pid_n 
121+                 group_id  =  tile_id  //  num_pid_in_group 
122+                 first_pid_m  =  group_id  *  GROUP_SIZE_M 
123+                 group_size_m  =  T .min (num_pid_m  -  first_pid_m , GROUP_SIZE_M )
124+                 pid_m_  =  first_pid_m  +  ((tile_id  %  num_pid_in_group ) %  group_size_m )
125+                 pid_n_  =  (tile_id  %  num_pid_in_group ) //  group_size_m 
126+ 
127+                 # threadblock swizzle 
128+                 #  no stream-k support. only split by m x n 
129+                 m_offset  =  M_per_rank  *  local_rank 
130+                 pid_m_offset  =  T .ceildiv (m_offset , block_M )
131+                 pid_m  =  (pid_m_  +  pid_m_offset ) %  num_pid_m 
132+                 pid_n  =  pid_n_ 
133+ 
134+                 if  pid_n_  *  block_N  <  (N  //  num_local_rank ) and  pid_m_  *  block_M  <  M :
135+                     tid  =  T .get_thread_binding (0 )
136+                     T .clear (C_local )
137+                     if  tid  ==  0 :
138+                         T .wait_eq (signal_buffer [pid_m  *  block_M  //  M_per_rank ], 1 )
139+                     for  k  in  T .Pipelined (T .ceildiv (K , block_K ), num_stages = 3 ):
140+                         T .copy (A [pid_m  *  block_M , k  *  block_K ], A_shared )
141+                         T .copy (B [k  *  block_K , pid_n  *  block_N ], B_shared )
142+                         T .gemm (A_shared , B_shared , C_local )
143+                     T .copy (C_local , C_shared )
144+                     T .copy (C_shared , C [pid_m  *  block_M , pid_n  *  block_N ])
145+ 
146+     return  main  if  not  persistent  else  main_persistent 
144147
145148
def cp_engine_producer_all_gather_full_mesh_pull(
    ag_buffer,
    signal_buffer,
    M_per_rank,
    signal_target,
    local_rank,
    local_world_size,
    intranode_ag_stream,
):
    """Pull every peer's shard into the local all-gather buffer via the copy engine.

    For each remote rank (visited in a rotated order starting after
    ``local_rank`` so peers stagger their reads), copy that rank's row block
    from its symmetric buffer into our own ``ag_buffer``, then write
    ``signal_target`` into our signal slot for that rank with
    ``cuStreamWriteValue32`` so the consumer GEMM kernel can observe, in
    stream order, that the shard has landed.

    Args:
        ag_buffer: per-rank list of symmetric (M, K) tensors; index by rank.
        signal_buffer: per-rank list of uint32 signal tensors; index by rank.
        M_per_rank: number of rows owned by each rank.
        signal_target: value written to signal a completed shard (typically 1).
        local_rank: this process's rank within the node.
        local_world_size: number of ranks on the node.
        intranode_ag_stream: CUDA stream the copies and signal writes run on.
    """
    # Rotate the visit order by local_rank so all ranks don't hammer rank 0 first.
    rank_orders = [(local_rank + i) % local_world_size for i in range(local_world_size)]

    with torch.cuda.stream(intranode_ag_stream):
        for src_rank in rank_orders:
            if src_rank == local_rank:
                # Our own shard is already resident (written by the caller).
                continue
            dst = ag_buffer[local_rank][src_rank * M_per_rank:(src_rank + 1) * M_per_rank, :]
            src = ag_buffer[src_rank][src_rank * M_per_rank:(src_rank + 1) * M_per_rank, :]
            dst.copy_(src)

            # Stream-ordered signal: becomes visible only after the copy above
            # completes on this stream.
            (err,) = cuda.cuStreamWriteValue32(
                intranode_ag_stream.cuda_stream,
                signal_buffer[local_rank][src_rank].data_ptr(),
                signal_target,
                cuda.CUstreamWriteValue_flags.CU_STREAM_WRITE_VALUE_DEFAULT,
            )
174174
175175
176- def  ag_gemm_op (A , B , C , ag_buffer , signal_buffer , sync_buffer , M_per_rank , N , signal_target , rank ,
177-                group , local_world_size , world_size , local_copy_kernel , gemm_kernel , gemm_stream ,
178-                ag_stream ):
176+ def  ag_gemm_op (A , B , C , ag_buffer , signal_buffer , M_per_rank , N , signal_target , local_rank ,
177+                local_world_size , set_signal_kernel , gemm_kernel , gemm_stream , ag_stream ):
179178
180179    with  torch .cuda .stream (gemm_stream ):
181-         local_copy_kernel (
182-             A , ag_buffer [rank ], signal_buffer [rank ], sync_buffer , stream = gemm_stream .cuda_stream )
180+         set_signal_kernel (signal_buffer [local_rank ], stream = gemm_stream .cuda_stream )
183181
184182    ag_stream .wait_stream (gemm_stream )
185183
186-     cp_engine_producer_all_gather_full_mesh_pull (A ,  ag_buffer , signal_buffer , M_per_rank ,  N ,
187-                                                  signal_target , rank , local_world_size ,  world_size ,
184+     cp_engine_producer_all_gather_full_mesh_pull (ag_buffer , signal_buffer , M_per_rank ,
185+                                                  signal_target , local_rank , local_world_size ,
188186                                                 ag_stream )
189187
190188    with  torch .cuda .stream (gemm_stream ):
191-         gemm_kernel (ag_buffer [rank ], B , signal_buffer [rank ], C , stream = gemm_stream .cuda_stream )
189+         gemm_kernel (
190+             ag_buffer [local_rank ], B , signal_buffer [local_rank ], C , stream = gemm_stream .cuda_stream )
192191
193192    gemm_stream .wait_stream (ag_stream )
194193    current_stream  =  torch .cuda .current_stream ()
@@ -212,6 +211,7 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
212211    M  =  args .M  if  args  else  8192 
213212    N  =  args .N  if  args  else  8192 
214213    K  =  args .K  if  args  else  8192 
214+     persistent  =  args .persistent 
215215    M_per_rank  =  M  //  num_local_ranks 
216216    N_per_rank  =  N  //  num_local_ranks 
217217
@@ -221,48 +221,45 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
221221    threads  =  256 
222222
223223    rank , num_ranks , group  =  init_dist (local_rank , num_local_ranks )
224+     assert  rank  ==  local_rank  and  num_ranks  ==  num_local_ranks , "only support single node for now" 
224225    allocator  =  tilelang .get_allocator (
225226        size = 2 ** 30 ,
226227        device = "cuda" ,
227228        is_distributed = True ,
228229        local_rank = local_rank ,
229230        num_local_ranks = num_local_ranks ,
230231        group = group )
231-     kernel  =  gemm_kernel (M , N , K , num_ranks , rank , BLOCK_M , BLOCK_N , BLOCK_K , threads )
232-     local_copy_kernel  =  copy_and_barrier_all_intra_node_kernel (
232+     gemm_func  =  gemm_kernel (M , N , K , local_rank , num_local_ranks , BLOCK_M , BLOCK_N , BLOCK_K ,
233+                             threads , persistent )
234+     set_signal_func  =  set_signal_kernel (
233235        local_rank = local_rank ,
234-         rank = local_rank ,
235-         num_ranks = num_ranks ,
236-         M = M ,
237-         K = K ,
238-         block_M = 64 ,
239-         block_K = 64 ,
240-         threads = 128 ,
236+         num_local_ranks = num_local_ranks ,
237+         threads = 32 ,
241238    )
242-     kernel .initialize (allocator = allocator )
243-     local_copy_kernel .initialize (allocator = allocator )
239+     gemm_func .initialize (allocator = allocator )
240+     set_signal_func .initialize (allocator = allocator )
244241    if  local_rank  ==  1 :
245-         print (kernel .get_kernel_source ())
246-         print (local_copy_kernel .get_kernel_source ())
242+         print (gemm_func .get_kernel_source ())
243+         print (set_signal_func .get_kernel_source ())
247244
248-     A  =  tilelang .tensor ((M_per_rank , K ), dtype , allocator = allocator ).normal_ ()
249245    B  =  tilelang .tensor ((K , N_per_rank ), dtype , allocator = allocator ).normal_ ()
250246    C  =  tilelang .tensor ((M , N_per_rank ), dtype , allocator = allocator )
251247    ag_buffer  =  tilelang .tensor ((M , K ), dtype , allocator = allocator , return_peers = True )
248+     A  =  ag_buffer [local_rank ][M_per_rank  *  local_rank :M_per_rank  *  (local_rank  +  1 ), :].normal_ ()
252249    signal_buffer  =  tilelang .tensor ((num_local_ranks ,),
253250                                    torch .uint32 ,
254251                                    allocator = allocator ,
255252                                    return_peers = True )
256-     signal_buffer [rank ].fill_ (0 )  # check if needed 
257-     sync_buffer  =  tilelang .tensor ((3  *  num_ranks ,), torch .uint32 , allocator = allocator )
258253
259254    gemm_stream  =  torch .cuda .Stream ()
260255    ag_stream  =  torch .cuda .Stream (priority = - 1 )
261256    signal_target  =  1 
262257
263-     tilelang_C  =  ag_gemm_op (A , B , C , ag_buffer , signal_buffer , sync_buffer , M_per_rank , K ,
264-                             signal_target , rank , group , num_local_ranks , num_local_ranks ,
265-                             local_copy_kernel , kernel , gemm_stream , ag_stream )
258+     dist .barrier ()
259+ 
260+     tilelang_C  =  ag_gemm_op (A , B , C , ag_buffer , signal_buffer , M_per_rank , K , signal_target ,
261+                             local_rank , num_local_ranks , set_signal_func , gemm_func , gemm_stream ,
262+                             ag_stream )
266263
267264    torch_ag_buffer  =  torch .empty ([M , K ], dtype = dtype , device = "cuda" )
268265    torch_C  =  torch_ag_gemm (group , A , B , torch_ag_buffer )
@@ -273,10 +270,10 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
273270        print (f"rank { local_rank }  )
274271        print (f"torch_C: { torch_C } { tilelang_C }  )
275272
276-     tl_out , tl_t  =  perf_fn (
277-         lambda :  ag_gemm_op ( A ,  B ,  C ,  ag_buffer ,  signal_buffer ,  sync_buffer ,  M_per_rank ,  K , 
278-                             signal_target ,  rank ,  group ,  num_local_ranks ,  num_local_ranks ,
279-                             local_copy_kernel ,  kernel , gemm_stream , ag_stream ),
273+     _ , tl_t  =  perf_fn (
274+         lambda :
275+         ag_gemm_op ( A ,  B ,  C ,  ag_buffer ,  signal_buffer ,  M_per_rank ,  K ,  signal_target ,  local_rank ,
276+                    num_local_ranks ,  set_signal_func ,  gemm_func , gemm_stream , ag_stream ),
280277        warmup = 5 ,
281278        rep = 10 )
282279
@@ -294,6 +291,7 @@ def main(local_rank: int, num_local_ranks: int, args: argparse.Namespace):
294291    parser .add_argument ('--M' , type = int , default = 8192 , help = 'M dimension' )
295292    parser .add_argument ('--N' , type = int , default = 28672 , help = 'N dimension' )
296293    parser .add_argument ('--K' , type = int , default = 8192 , help = 'K dimension' )
294+     parser .add_argument ('--persistent' , action = 'store_true' , help = 'Use persistent kernel' )
297295    args  =  parser .parse_args ()
298296    num_processes  =  args .num_processes 
299297
0 commit comments