 from tilelang import tvm as tvm
 import tilelang.language as T
 from tilelang.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout
-from tilelang.intrinsics.mfma_macro_generator import (
-    MatrixCoreIntrinEmitter,)
+from tilelang.intrinsics.mfma_macro_generator import MatrixCorePreshuffleIntrinEmitter
 from tilelang.transform import simplify_prim_func
 
 tilelang.testing.set_random_seed(0)
@@ -22,16 +21,8 @@ def tl_matmul(
     b_transposed=True,
     k_pack=1,
     b_preshuffle=False,
+    b_g2l_load=False,
 ):
-    assert in_dtype in [
-        "float16",
-        "int8",
-    ], "Currently only float16 and int8 are supported"
-    assert out_dtype in [
-        "float16",
-        "float32",
-        "int32",
-    ], "Currently only float16, float32 and int32 are supported"
 
     micro_size_x = micro_size_y = micro_size_k = 16
 
@@ -47,15 +38,14 @@ def tl_matmul(
     if b_preshuffle:
         block_row_warps = 1
         block_col_warps = 4
-        warp_row_tiles = 128
-        warp_col_tiles = 32
+        warp_row_tiles = 64
+        warp_col_tiles = 16
 
-    chunk = 32 * k_pack
+    chunk = 256 * k_pack
 
     pack_size_k = micro_size_k * k_pack
 
     shared_scope = "shared"
-    cache_write_shared = False
 
     block_M = block_row_warps * warp_row_tiles
     block_N = block_col_warps * warp_col_tiles
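
Note: with the smaller preshuffle warp tiles and the larger chunk above, the derived block shape and thread count work out as follows (plain Python; block_K is assumed to equal chunk, as elsewhere in this file, and warp_size is the 64 used further down):

block_row_warps, block_col_warps = 1, 4
warp_row_tiles, warp_col_tiles = 64, 16
k_pack = 1
chunk = 256 * k_pack

block_M = block_row_warps * warp_row_tiles  # 64
block_N = block_col_warps * warp_col_tiles  # 64
block_K = chunk                             # 256
threads = 64 * (block_row_warps * block_col_warps)  # 256 threads, i.e. 4 wavefronts
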
@@ -68,6 +58,7 @@ def tl_matmul(
                                                       pack_size_k, micro_size_y)
     else:
         B_shape = (N, K) if b_transposed else (K, N)
+
     A_shared_shape = (block_K, block_M) if a_transposed else (block_M, block_K)
     if b_preshuffle:
         B_shared_shape = (block_N // micro_size_y, block_K // pack_size_k, micro_size_y,
@@ -76,12 +67,6 @@ def tl_matmul(
                                                              micro_size_y)
     else:
         B_shared_shape = (block_N, block_K) if b_transposed else (block_K, block_N)
-    C_shared_shape = (
-        block_M // micro_size_x,
-        block_N // micro_size_y,
-        micro_size_x,
-        micro_size_y,
-    )
 
     warp_size = 64
     threads = warp_size * (block_row_warps * block_col_warps)
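
Note: when b_preshuffle is enabled, the global B handed to the kernel is already a 4-D tensor whose last two axes form the (micro_size_y, pack_size_k) micro-tile, matching B_shape above. For orientation only, a host-side rearrangement that produces such a layout from a row-major (N, K) matrix could look like the sketch below; preshuffle_b is an illustrative helper (not the helper this test uses) and assumes b_transposed=True with k_pack=1.

import torch


def preshuffle_b(b, micro_size_y=16, pack_size_k=16):
    # b: (N, K) -> (N // micro_size_y, K // pack_size_k, micro_size_y, pack_size_k)
    n, k = b.shape
    return (b.reshape(n // micro_size_y, micro_size_y, k // pack_size_k,
                      pack_size_k).permute(0, 2, 1, 3).contiguous())
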
@@ -92,7 +77,7 @@ def tl_matmul(
     warp_cols = warp_col_tiles // micro_size_y
 
     # MMA Wrapper to Auto Generate Code for MMA
-    mfma_emitter = MatrixCoreIntrinEmitter(
+    mfma_emitter = MatrixCorePreshuffleIntrinEmitter(
         a_dtype=in_dtype,
         b_dtype=in_dtype,
         accum_dtype=accum_dtype,
@@ -117,7 +102,6 @@ def main(
 
             A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
             B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
-            C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
             A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
             B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
             C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)
@@ -126,12 +110,15 @@ def main(
                 A_shared: make_swizzle_layout(A_shared),
             })
 
+            num_ko = K // block_K
+            num_ki = block_K // (k_pack * micro_size_k)
+
             # Improve L2 Cache
             T.use_swizzle(panel_size=10)
 
             T.clear(C_local)
 
-            for ko in T.Pipelined((K // block_K), num_stages=0):
+            for ko in T.Pipelined(num_ko, num_stages=0):
 
                 # Load A into shared memory
                 if a_transposed:
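
Note: num_ko and num_ki simply hoist the trip counts that were previously inlined in the Pipelined and serial loop bounds. For the new 256x256x512, k_pack=2 test case added below (and again assuming block_K equals chunk), they evaluate to:

K, k_pack, micro_size_k = 512, 2, 16
block_K = 256 * k_pack                       # 512
num_ko = K // block_K                        # 1 pipelined K iteration
num_ki = block_K // (k_pack * micro_size_k)  # 16 inner mfma steps per iteration
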
@@ -140,7 +127,7 @@ def main(
                     T.copy(A[by * block_M, ko * block_K], A_shared)
 
                 # Load B into shared memory
-                if b_preshuffle:
+                if b_g2l_load is False:
                     if b_transposed:
                         for j, k, jj, kk in T.Parallel(block_N // micro_size_y,
                                                        block_K // pack_size_k, micro_size_y,
@@ -153,53 +140,37 @@ def main(
                                                        micro_size_y):
                             B_shared[k, j, kk, jj] = B[ko * block_K // pack_size_k + k,
                                                        bx * block_N // micro_size_y + j, kk, jj]
-                else:
-                    if b_transposed:
-                        T.copy(B[bx * block_N, ko * block_K], B_shared)
-                    else:
-                        T.copy(B[ko * block_K, bx * block_N], B_shared)
 
-                for ki in T.serial(0, (block_K // (k_pack * micro_size_k))):
+                for ki in T.serial(0, num_ki):
 
-                    # Load A into fragment
+                    # Load A S2L
                     mfma_emitter.ldmatrix_a(
                         A_local,
                         A_shared,
                         ki,
                     )
 
-                    # Load B into fragment
-                    mfma_emitter.ldmatrix_b(
-                        B_local,
-                        B_shared,
-                        ki,
-                    )
+                    if b_g2l_load:
+                        # Load B G2L
+                        mfma_emitter.ldmatrix_b(B_local, B, ki + ko * num_ki, pid_m=by, pid_n=bx)
+                    else:
+                        # Load B S2L
+                        mfma_emitter.ldmatrix_b(
+                            B_local,
+                            B_shared,
+                            ki,
+                        )
 
                     # Perform Matrix Multiplication
                     mfma_emitter.mfma(A_local, B_local, C_local)
 
             # Perform STMatrix
-            if cache_write_shared:
-                mfma_emitter.stmatrix(
-                    C_local,
-                    C_shared,
-                )
-
-                # Store shared into global
-                for i, j in T.Parallel(block_M, block_N):
-                    C[by * block_M + i, bx * block_N + j] = C_shared[
-                        i // micro_size_x,
-                        j // micro_size_y,
-                        i % micro_size_x,
-                        j % micro_size_y,
-                    ]
-            else:
-                mfma_emitter.stmatrix(
-                    C_local,
-                    C,
-                    pid_m=by,
-                    pid_n=bx,
-                )
+            mfma_emitter.stmatrix(
+                C_local,
+                C,
+                pid_m=by,
+                pid_n=bx,
+            )
 
     return main
 
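
Note: b_g2l_load switches ldmatrix_b between the two paths shown above: staging the preshuffled B tile through B_shared first (S2L) or reading it from global memory directly into the register fragment (G2L), in which case the shared-memory copy loop is skipped entirely. A minimal sketch of exercising both paths through the existing entry points (illustrative only; it relies on the imports at the top of this file and mirrors the compile flow in assert_tl_matmul_correctness below):

for g2l in (False, True):
    func = tl_matmul(
        256, 256, 256, "int8", "int32", "int32", b_preshuffle=True, b_g2l_load=g2l)
    kernel = tilelang.compile(func)
    print(kernel.get_kernel_source()[:200])  # spot-check the generated source
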
@@ -232,9 +203,10 @@ def assert_tl_matmul_correctness(M,
                                  a_transposed=False,
                                  b_transposed=True,
                                  k_pack=1,
-                                 b_preshuffle=False):
+                                 b_preshuffle=False,
+                                 b_g2l_load=False):
     matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed,
-                       k_pack, b_preshuffle)
+                       k_pack, b_preshuffle, b_g2l_load)
     print(matmul)
     kernel = tilelang.compile(matmul)
     src_code = kernel.get_kernel_source()
@@ -285,30 +257,25 @@ def assert_tl_matmul_correctness(M,
 
     print(C)
     print(ref_c)
+
     torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
 
 
 @tilelang.testing.requires_rocm
 def test_assert_tl_matmul():
-    assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", accum_dtype="int32")
-    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32")
-    assert_tl_matmul_correctness(
-        128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32")
-    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2)
-
     assert_tl_matmul_correctness(
-        128, 128, 128, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
+        256, 256, 256, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
     assert_tl_matmul_correctness(
-        128, 256, 256, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
+        256, 256, 256, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
     assert_tl_matmul_correctness(
-        128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True)
+        256, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True)
 
     assert_tl_matmul_correctness(
-        128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2, b_preshuffle=True)
+        256, 256, 512, "int8", "int32", accum_dtype="int32", k_pack=2, b_preshuffle=True)
     assert_tl_matmul_correctness(
-        128,
         256,
         256,
+        512,
         "int8",
         "int32",
         b_transposed=False,