 tilelang.disable_cache()
 
 
-def cannon(MESH, M, N, K, block_M, block_N, block_K, dtype="float16"):
+def cannon(MESH, M, N, K, block_M, block_N, block_K, dtype="float16", specialize=False):
 
     M_local = T.ceildiv(M, MESH)
     N_local = T.ceildiv(N, MESH)
     K_local = T.ceildiv(K, MESH)
     accum_dtype = "float32"
 
+    sm_num = 132  # 132 SMs for H100
+    total_tiles = T.ceildiv(M_local, block_M) * T.ceildiv(N_local, block_N)
+
     @T.prim_func
     def main(
         A: T.Tensor((2, M_local, K_local), dtype),
@@ -30,8 +33,11 @@ def main(
         B_signal_from: T.Tensor((T.ceildiv(N, block_N),), "uint64"),
         C: T.Tensor((M_local, N_local), dtype),
     ):
-        with T.Kernel(
-                T.ceildiv(M_local, block_M), T.ceildiv(N_local, block_N), threads=128) as (bx, by):
+        grid_size = T.min(sm_num, total_tiles)
+        A_rows_per_block = T.ceildiv(M_local, grid_size)
+        B_cols_per_block = T.ceildiv(N_local, grid_size)
+        waves = T.ceildiv(total_tiles, sm_num)
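+        # Persistent-kernel launch: at most one block per SM; each block then
+        # sweeps the output tiles in `waves` passes instead of one block per tile.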
40+ with T .Kernel (grid_size , threads = 256 ) as (block_id ):
3541 mype = T .alloc_local ([1 ], "int32" )
3642 npes = T .alloc_local ([1 ], "int32" )
3743 a_peer_from = T .alloc_local ([1 ], "int32" )
@@ -54,71 +60,180 @@ def main(
             for ko in T.serial(MESH):
                 if tx == 0:
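+                    # Flow control: the puts below reuse buffer (ko + 1) % 2 on the
+                    # neighbor, so wait until it has finished every tile of the stage
+                    # that last read that buffer (counted via T.signal_op below).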
                     T.signal_wait_until(
-                        T.address_of(A_signal_from[bx]),
-                        T.NVSHMEM_CMP_EQ,
-                        T.ceildiv(N_local, block_N) * ko,
+                        T.address_of(A_signal_from[0]),
+                        T.NVSHMEM_CMP_GE,
+                        total_tiles * ko,
                     )
                     T.signal_wait_until(
-                        T.address_of(B_signal_from[by]),
-                        T.NVSHMEM_CMP_EQ,
-                        T.ceildiv(M_local, block_M) * ko,
+                        T.address_of(B_signal_from[0]),
+                        T.NVSHMEM_CMP_GE,
+                        total_tiles * ko,
                     )
 
-                if by == 0:
+                if block_id < T.ceildiv(M_local, A_rows_per_block):
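+                    # Shift A: send this block's slice to the left-neighbor PE's
+                    # spare buffer for stage ko + 1.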
                     T.putmem_signal_nbi_block(
-                        T.address_of(A[(ko + 1) % 2, bx * block_M, 0]),
-                        T.address_of(A[ko % 2, bx * block_M,
-                                     0]), block_M * K_local * dsize_map[dtype],
-                        T.address_of(A_signal_to[bx]), ko + 1, T.NVSHMEM_SIGNAL_SET, a_peer_to[0])
-                if bx == 0:
+                        T.address_of(A[(ko + 1) % 2, A_rows_per_block * block_id, 0]),
+                        T.address_of(A[ko % 2, A_rows_per_block * block_id,
+                                     0]), A_rows_per_block * K_local * dsize_map[dtype],
+                        T.address_of(A_signal_to[0]), 1, T.NVSHMEM_SIGNAL_ADD, a_peer_to[0])
+                if block_id < T.ceildiv(N_local, B_cols_per_block):
                     T.putmem_signal_nbi_block(
-                        T.address_of(B[(ko + 1) % 2, by * block_N, 0]),
-                        T.address_of(B[ko % 2, by * block_N,
-                                     0]), block_N * K_local * dsize_map[dtype],
-                        T.address_of(B_signal_to[by]), ko + 1, T.NVSHMEM_SIGNAL_SET, b_peer_to[0])
-
-                for ki in T.Pipelined(T.ceildiv(K_local, block_K)):
-                    T.copy(
-                        A[ko % 2, bx * block_M:(bx + 1) * block_M, ki * block_K:(ki + 1) * block_K],
-                        A_shared)
-                    T.copy(
-                        B[ko % 2, by * block_N:(by + 1) * block_N, ki * block_K:(ki + 1) * block_K],
-                        B_shared)
-                    T.gemm(A_shared, B_shared, C_local, transpose_B=True)
+                        T.address_of(B[(ko + 1) % 2, B_cols_per_block * block_id, 0]),
+                        T.address_of(B[ko % 2, B_cols_per_block * block_id,
+                                     0]), B_cols_per_block * K_local * dsize_map[dtype],
+                        T.address_of(B_signal_to[0]), 1, T.NVSHMEM_SIGNAL_ADD, b_peer_to[0])
+
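+                # Compute phase: decompose the linear tile index of this wave
+                # into a (bx, by) output-tile coordinate.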
+                for w in T.serial(waves):
+
+                    bx = (grid_size * w + block_id) // T.ceildiv(N_local, block_N)
+                    by = (grid_size * w + block_id) % T.ceildiv(N_local, block_N)
+
+                    if bx < T.ceildiv(M_local, block_M) and by < T.ceildiv(N_local, block_N):
+                        T.copy(C[bx * block_M, by * block_N], C_local)
+                        for ki in T.Pipelined(T.ceildiv(K_local, block_K), num_stages=4):
+                            T.copy(A[ko % 2, bx * block_M, ki * block_K], A_shared)
+                            T.copy(B[ko % 2, by * block_N, ki * block_K], B_shared)
+                            T.gemm(A_shared, B_shared, C_local, transpose_B=True)
+
+                        T.copy(C_local, C[bx * block_M, by * block_N])
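+                        # Tell the PE that feeds us that one tile of stage ko is
+                        # consumed, unblocking its next-stage put.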
+                        if tx == 0:
+                            T.signal_op(
+                                T.address_of(A_signal_from[0]),
+                                1,
+                                T.NVSHMEM_SIGNAL_ADD,
+                                a_peer_from[0],
+                            )
+                            T.signal_op(
+                                T.address_of(B_signal_from[0]),
+                                1,
+                                T.NVSHMEM_SIGNAL_ADD,
+                                b_peer_from[0],
+                            )
+
+                # TODO: check if __syncthreads() is needed
+                T.signal_wait_until(
+                    T.address_of(A_signal_to[0]),
+                    T.NVSHMEM_CMP_GE,
+                    (ko + 1) * T.ceildiv(M_local, A_rows_per_block),
+                )
+                T.signal_wait_until(
+                    T.address_of(B_signal_to[0]),
+                    T.NVSHMEM_CMP_GE,
+                    (ko + 1) * T.ceildiv(N_local, B_cols_per_block),
+                )
+
+    # TODO: fix correctness
+    @T.prim_func
+    def main_specialize(
+        A: T.Tensor((2, M_local, K_local), dtype),
+        B: T.Tensor((2, N_local, K_local), dtype),
+        A_signal_to: T.Tensor((T.ceildiv(M, block_M),), "uint64"),
+        A_signal_from: T.Tensor((T.ceildiv(M, block_M),), "uint64"),
+        B_signal_to: T.Tensor((T.ceildiv(N, block_N),), "uint64"),
+        B_signal_from: T.Tensor((T.ceildiv(N, block_N),), "uint64"),
+        C: T.Tensor((M_local, N_local), dtype),
+    ):
+        # blocks [0, compute_blocks): compute
+        # blocks [compute_blocks, grid_size): copy
+        copy_blocks = 20
+        compute_blocks = T.min(sm_num - copy_blocks, total_tiles)
+        grid_size = copy_blocks + compute_blocks
+        A_rows_per_block = T.ceildiv(M_local, copy_blocks)
+        B_cols_per_block = T.ceildiv(N_local, copy_blocks)
+        waves = T.ceildiv(total_tiles, compute_blocks)
+        with T.Kernel(grid_size, threads=256) as (block_id):
+            mype = T.alloc_local([1], "int32")
+            npes = T.alloc_local([1], "int32")
+            a_peer_from = T.alloc_local([1], "int32")
+            a_peer_to = T.alloc_local([1], "int32")
+            b_peer_from = T.alloc_local([1], "int32")
+            b_peer_to = T.alloc_local([1], "int32")
+            mype[0] = T.get_pe()
+            npes[0] = T.get_pe_num()
+
+            A_shared = T.alloc_shared((block_M, block_K), dtype)
+            B_shared = T.alloc_shared((block_N, block_K), dtype)
+            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+
+            tx = T.get_thread_binding(0)
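+            # Cannon's neighbors: A circulates along the mesh row, B along the
+            # mesh column (both with wrap-around).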
+            a_peer_from[0] = (mype[0] + 1) % MESH + MESH * (mype[0] // MESH)
+            a_peer_to[0] = (mype[0] - 1 + MESH) % MESH + MESH * (mype[0] // MESH)
+            b_peer_from[0] = (mype[0] + MESH) % npes[0]
+            b_peer_to[0] = (mype[0] - MESH + npes[0]) % npes[0]
+            T.clear(C_local)
+            for ko in T.serial(MESH):
+                if block_id >= compute_blocks:
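+                    # Producer blocks: forward this PE's A and B slices to the
+                    # neighbors' spare buffers for stage ko + 1.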
+                    if tx == 0:
+                        T.signal_wait_until(
+                            T.address_of(A_signal_from[0]),
+                            T.NVSHMEM_CMP_GE,
+                            total_tiles * ko,
+                        )
+                        T.signal_wait_until(
+                            T.address_of(B_signal_from[0]),
+                            T.NVSHMEM_CMP_GE,
+                            total_tiles * ko,
+                        )
+                    T.putmem_signal_nbi_block(
+                        T.address_of(A[(ko + 1) % 2, A_rows_per_block * (block_id - compute_blocks),
+                                       0]),
+                        T.address_of(A[ko % 2, A_rows_per_block * (block_id - compute_blocks),
+                                       0]), A_rows_per_block * K_local * dsize_map[dtype],
+                        T.address_of(A_signal_to[0]), 1, T.NVSHMEM_SIGNAL_ADD, a_peer_to[0])
+                    T.putmem_signal_nbi_block(
+                        T.address_of(B[(ko + 1) % 2, B_cols_per_block * (block_id - compute_blocks),
+                                       0]),
+                        T.address_of(B[ko % 2, B_cols_per_block * (block_id - compute_blocks),
+                                       0]), B_cols_per_block * K_local * dsize_map[dtype],
+                        T.address_of(B_signal_to[0]), 1, T.NVSHMEM_SIGNAL_ADD, b_peer_to[0])
+
+                if block_id < compute_blocks:
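+                    # Consumer blocks: compute output tiles in waves, as in main.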
+                    for w in T.serial(waves):
+
+                        bx = (compute_blocks * w + block_id) // T.ceildiv(N_local, block_N)
+                        by = (compute_blocks * w + block_id) % T.ceildiv(N_local, block_N)
+
+                        if bx < T.ceildiv(M_local, block_M) and by < T.ceildiv(N_local, block_N):
+                            T.copy(C[bx * block_M, by * block_N], C_local)
+                            for ki in T.Pipelined(T.ceildiv(K_local, block_K), num_stages=4):
+                                T.copy(A[ko % 2, bx * block_M, ki * block_K], A_shared)
+                                T.copy(B[ko % 2, by * block_N, ki * block_K], B_shared)
+                                T.gemm(A_shared, B_shared, C_local, transpose_B=True)
+
+                            T.copy(C_local, C[bx * block_M, by * block_N])
+                            if tx == 0:
+                                T.signal_op(
+                                    T.address_of(A_signal_from[0]),
+                                    1,
+                                    T.NVSHMEM_SIGNAL_ADD,
+                                    a_peer_from[0],
+                                )
+                                T.signal_op(
+                                    T.address_of(B_signal_from[0]),
+                                    1,
+                                    T.NVSHMEM_SIGNAL_ADD,
+                                    b_peer_from[0],
+                                )
 
-                if tx == 0:
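+                    # Wait until the copy blocks have delivered all of stage
+                    # ko + 1 (copy_blocks puts per side and stage) before advancing.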
                     T.signal_wait_until(
-                        T.address_of(A_signal_to[bx]),
-                        T.NVSHMEM_CMP_EQ,
-                        ko + 1,
+                        T.address_of(A_signal_to[0]),
+                        T.NVSHMEM_CMP_GE,
+                        (ko + 1) * copy_blocks,
                     )
                     T.signal_wait_until(
-                        T.address_of(B_signal_to[by]),
-                        T.NVSHMEM_CMP_EQ,
-                        ko + 1,
-                    )
-                    T.signal_op(
-                        T.address_of(A_signal_from[bx]),
-                        1,
-                        T.NVSHMEM_SIGNAL_ADD,
-                        a_peer_from[0],
+                        T.address_of(B_signal_to[0]),
+                        T.NVSHMEM_CMP_GE,
+                        (ko + 1) * copy_blocks,
                     )
-                    T.signal_op(
-                        T.address_of(B_signal_from[by]),
-                        1,
-                        T.NVSHMEM_SIGNAL_ADD,
-                        b_peer_from[0],
-                    )
-            T.copy(C_local, C[bx * block_M:(bx + 1) * block_M, by * block_N:(by + 1) * block_N])
 
-    return main
+    return main_specialize if specialize else main
 
 
 def parse_args():
     parser = argparse.ArgumentParser()
-    parser.add_argument("--M", default=256, type=int)
-    parser.add_argument("--N", default=256, type=int)
-    parser.add_argument("--K", default=256, type=int)
+    parser.add_argument("--M", default=16384, type=int)
+    parser.add_argument("--N", default=16384, type=int)
+    parser.add_argument("--K", default=16384, type=int)
     parser.add_argument("--warmup", default=20, type=int, help="warmup iterations")
     parser.add_argument("--iters", default=100, type=int, help="perf iterations")
     parser.add_argument("--dtype", default="float16", type=str, help="data type")
@@ -135,14 +250,15 @@ def parse_args():
 assert MESH * MESH == WORLD_SIZE, "Mesh size must match world size"
 
 M, N, K = args.M, args.N, args.K
-block_M, block_N, block_K = 64, 64, 64
+specialize = False
+block_M, block_N, block_K = 128, 256, 64
 dtype = dtype_map[args.dtype]
 
 M_local = math.ceil(M / MESH)
 N_local = math.ceil(N / MESH)
 K_local = math.ceil(K / MESH)
 
-func = cannon(MESH, M, N, K, block_M, block_N, block_K, args.dtype)
+func = cannon(MESH, M, N, K, block_M, block_N, block_K, args.dtype, specialize)
 kernel = tilelang.compile(
     func, pass_configs={
         "tl.disable_tma_lower": True,
@@ -210,8 +326,67 @@ def parse_args():
     print('-' * 100)
     print(f"[Rank {RANK}] ✅ Tilelang and Torch match")
 else:
+    abs_error = torch.abs(C_tilelang - ref)
+    rel_error = abs_error / (torch.abs(ref) + 1e-8)
+
+    max_abs_error = abs_error.max().item()
+    max_rel_error = rel_error.max().item()
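+    # fraction of elements outside an atol=1e-2 / rtol=1e-2 tolerance band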
+    mismatch_ratio = (abs_error > (1e-2 + 1e-2 * torch.abs(ref))).float().mean().item()
+
     print('-' * 100)
     print(f"[Rank {RANK}] ❌ Tilelang and Torch mismatch")
     print(f"[Rank {RANK}] ref:\n{ref}")
     print(f"[Rank {RANK}] tilelang:\n{C_tilelang}")
+    print(f"[Rank {RANK}] Mismatch ratio: {mismatch_ratio:.4f}")
+    print(f"[Rank {RANK}] Max absolute error: {max_abs_error:.6f}")
+    print(f"[Rank {RANK}] Max relative error: {max_rel_error:.6f}")
 dist.barrier()
+
+
+def bench(func, *args):
+    bench_iters = 10
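+    # Keep the GPU busy for ~1e9 cycles so the launches below queue up behind
+    # it and CPU-side launch overhead stays out of the measurement.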
+    torch.cuda._sleep(1000000000)
+
+    def preprocess():
+        # clear signals
+        args[2].fill_(0)
+        args[3].fill_(0)
+        args[4].fill_(0)
+        args[5].fill_(0)
+
+    # warmup
+    for _ in range(20):
+        preprocess()
+        _ = func(*args)
+
+    st = torch.cuda.Event(enable_timing=True)
+    ed = torch.cuda.Event(enable_timing=True)
+    # bench
+    st.record()
+    for _ in range(bench_iters):
+        preprocess()
+        _ = func(*args)
+    ed.record()
+    torch.cuda.synchronize()
+    avg_time = st.elapsed_time(ed) / bench_iters
+
+    return avg_time
+
+
+def reduce_local_time(local_time):
+    tensor = torch.tensor([local_time], dtype=torch.float32).to("cuda")
+    dist.reduce(tensor, dst=0, op=dist.ReduceOp.SUM)
+    if dist.get_rank() == 0:
+        world_size = dist.get_world_size()
+        mean_time = (tensor / world_size).item()
+        return mean_time
+    return None
+
385+
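+# 2*M*N*K FLOPs for the full distributed GEMM; dividing by the mean per-rank
+# time gives aggregate throughput across all ranks.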
+total_flops = 2 * M * N * K
+avg_time = reduce_local_time(
+    bench(kernel, A, B, A_signal_to, A_signal_from, B_signal_to, B_signal_from, C_tilelang))
+
+if RANK == 0:
+    print(f"avg time across ranks: {avg_time} ms")
+    print(f"TFlops: {total_flops / avg_time * 1e-9} TFlops")