import torch
import tilelang.testing
from tilelang import tvm as tvm
import tilelang.language as T
from tilelang.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout
from tilelang.intrinsics.mfma_macro_generator import MatrixCoreIntrinEmitter
from tilelang.transform import simplify_prim_func

tilelang.testing.set_random_seed(0)


@simplify_prim_func
def tl_matmul(
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype,
    a_transposed=False,
    b_transposed=True,
    k_pack=1,
    b_preshuffle=False,
):
    assert in_dtype in [
        "float16",
        "int8",
    ], "Currently only float16 and int8 are supported"
    assert out_dtype in [
        "float16",
        "float32",
        "int32",
    ], "Currently only float16, float32 and int32 are supported"

    micro_size_x = micro_size_y = micro_size_k = 16

    if in_dtype in {"float8_e4m3fnuz", "int8"}:
        micro_size_k = 32

    block_row_warps = 2
    block_col_warps = 2
    warp_row_tiles = 32
    warp_col_tiles = 32

    # for b_preshuffle, use warp_layout = (1, 4) instead of (2, 2)
    if b_preshuffle:
        block_row_warps = 1
        block_col_warps = 4
        warp_row_tiles = 128
        warp_col_tiles = 32

    chunk = 32 * k_pack

    pack_size_k = micro_size_k * k_pack

    shared_scope = "shared"
    cache_write_shared = False

    block_M = block_row_warps * warp_row_tiles
    block_N = block_col_warps * warp_col_tiles
    block_K = chunk
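    # With the default (2, 2) warp layout and 32x32 warp tiles this gives
    # block_M = block_N = 64; with b_preshuffle the (1, 4) layout gives
    # block_M = block_N = 128. block_K equals chunk = 32 * k_pack in both cases.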

    A_shape = (K, M) if a_transposed else (M, K)
    if b_preshuffle:
        B_shape = (N // micro_size_y, K // pack_size_k, micro_size_y,
                   pack_size_k) if b_transposed else (K // pack_size_k, N // micro_size_y,
                                                      pack_size_k, micro_size_y)
    else:
        B_shape = (N, K) if b_transposed else (K, N)
    A_shared_shape = (block_K, block_M) if a_transposed else (block_M, block_K)
    if b_preshuffle:
        B_shared_shape = (block_N // micro_size_y, block_K // pack_size_k, micro_size_y,
                          pack_size_k) if b_transposed else (block_K // pack_size_k,
                                                             block_N // micro_size_y, pack_size_k,
                                                             micro_size_y)
    else:
        B_shared_shape = (block_N, block_K) if b_transposed else (block_K, block_N)
    C_shared_shape = (
        block_M // micro_size_x,
        block_N // micro_size_y,
        micro_size_x,
        micro_size_y,
    )

    warp_size = 64
    threads = warp_size * (block_row_warps * block_col_warps)
    local_size_a = (k_pack * micro_size_x * micro_size_k) // warp_size
    local_size_b = (k_pack * micro_size_y * micro_size_k) // warp_size
    local_size_c = (micro_size_x * micro_size_y) // warp_size
    warp_rows = warp_row_tiles // micro_size_x
    warp_cols = warp_col_tiles // micro_size_y
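
    # In both warp layouts a block runs 4 warps * 64 lanes = 256 threads; for the
    # int8 path (micro_size_k = 32) each lane holds 8 * k_pack elements of A and of
    # B per micro tile, and 4 accumulator elements per 16x16 output tile.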

    # MFMA emitter: auto-generates the Matrix Core (MFMA) intrinsic code
    mfma_emitter = MatrixCoreIntrinEmitter(
        a_dtype=in_dtype,
        b_dtype=in_dtype,
        accum_dtype=accum_dtype,
        a_transposed=a_transposed,
        b_transposed=b_transposed,
        block_row_warps=block_row_warps,
        block_col_warps=block_col_warps,
        warp_row_tiles=warp_row_tiles,
        warp_col_tiles=warp_col_tiles,
        chunk=chunk,
        k_pack=k_pack,
        b_preshuffle=b_preshuffle,
    )
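    # The emitter provides ldmatrix_a / ldmatrix_b (shared memory -> register
    # fragments), mfma (the Matrix Core multiply-accumulate) and stmatrix
    # (fragments -> shared or global memory); these are used in the kernel body below.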

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):

            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
            C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
            A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
            B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
            C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

            T.annotate_layout({
                A_shared: make_swizzle_layout(A_shared),
            })

            # Improve L2 Cache
            T.use_swizzle(panel_size=10)

            T.clear(C_local)

            for ko in T.Pipelined((K // block_K), num_stages=0):

                # Load A into shared memory
                if a_transposed:
                    T.copy(A[ko * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, ko * block_K], A_shared)

                # Load B into shared memory
                if b_preshuffle:
                    if b_transposed:
                        for j, k, jj, kk in T.Parallel(block_N // micro_size_y,
                                                       block_K // pack_size_k, micro_size_y,
                                                       pack_size_k):
                            B_shared[j, k, jj, kk] = B[bx * block_N // micro_size_y + j,
                                                       ko * block_K // pack_size_k + k, jj, kk]
                    else:
                        for k, j, kk, jj in T.Parallel(block_K // pack_size_k,
                                                       block_N // micro_size_y, pack_size_k,
                                                       micro_size_y):
                            B_shared[k, j, kk, jj] = B[ko * block_K // pack_size_k + k,
                                                       bx * block_N // micro_size_y + j, kk, jj]
                else:
                    if b_transposed:
                        T.copy(B[bx * block_N, ko * block_K], B_shared)
                    else:
                        T.copy(B[ko * block_K, bx * block_N], B_shared)

                for ki in T.serial(0, (block_K // (k_pack * micro_size_k))):

                    # Load A into fragment
                    mfma_emitter.ldmatrix_a(
                        A_local,
                        A_shared,
                        ki,
                    )

                    # Load B into fragment
                    mfma_emitter.ldmatrix_b(
                        B_local,
                        B_shared,
                        ki,
                    )

                    # Perform Matrix Multiplication
                    mfma_emitter.mfma(A_local, B_local, C_local)

            # Perform STMatrix
            if cache_write_shared:
                mfma_emitter.stmatrix(
                    C_local,
                    C_shared,
                )

                # Store shared into global
                for i, j in T.Parallel(block_M, block_N):
                    C[by * block_M + i, bx * block_N + j] = C_shared[
                        i // micro_size_x,
                        j // micro_size_y,
                        i % micro_size_x,
                        j % micro_size_y,
                    ]
            else:
                mfma_emitter.stmatrix(
                    C_local,
                    C,
                    pid_m=by,
                    pid_n=bx,
                )

    return main


def shuffle_weight(
    x: torch.Tensor,
    layout=(16, 32),
    k_pack=1,
    is_transpose=False,
) -> torch.Tensor:
    IN, IK = layout
    BK = IK * k_pack
    BN = IN

    N, K = (x.shape[-2], x.shape[-1]) if is_transpose else (x.shape[-1], x.shape[-2])
    assert N % BN == 0
    assert K % BK == 0

    x = x.view(N // BN, BN, K // BK, BK) if is_transpose else x.view(K // BK, BK, N // BN, BN)
    x = x.permute(0, 2, 1, 3)
    return x.contiguous()
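

# Worked example (illustrative sizes): a transposed int8 B of shape (N, K) = (128, 256)
# with layout=(16, 32) and k_pack=1 is viewed as (8, 16, 8, 32) and permuted to
# (8, 8, 16, 32), i.e. (N // 16, K // 32, 16, 32), which is exactly the B_shape that
# tl_matmul expects when b_preshuffle=True and b_transposed=True.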


def assert_tl_matmul_correctness(M,
                                 N,
                                 K,
                                 in_dtype,
                                 out_dtype,
                                 accum_dtype="float32",
                                 a_transposed=False,
                                 b_transposed=True,
                                 k_pack=1,
                                 b_preshuffle=False):
    matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed,
                       k_pack, b_preshuffle)
    print(matmul)
    kernel = tilelang.compile(matmul)
    src_code = kernel.get_kernel_source()
    # src_code is the generated kernel source (HIP for this ROCm test)
    assert src_code is not None
    A_shape = (K, M) if a_transposed else (M, K)
    B_shape = (N, K) if b_transposed else (K, N)
    if in_dtype == "int8":
        A = torch.randint(-128, 127, A_shape, device="cuda", dtype=torch.int8)
        B = torch.randint(-128, 127, B_shape, device="cuda", dtype=torch.int8)
    else:
        A = torch.rand(A_shape, device="cuda", dtype=getattr(torch, in_dtype))
        B = torch.rand(B_shape, device="cuda", dtype=getattr(torch, in_dtype))

    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))

    B_preshuffle = B
    if b_preshuffle:
        B_preshuffle = shuffle_weight(B_preshuffle, k_pack=k_pack, is_transpose=b_transposed)
        kernel(A, B_preshuffle, C)
    else:
        kernel(A, B, C)

    print(kernel.get_kernel_source())

    profiler = kernel.get_profiler()

    latency = profiler.do_bench()

    # Ensure that the latency is not None
    assert latency is not None

    if a_transposed and b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.T.to(torch.float32),
                             B.T.to(torch.float32)).to(getattr(torch, out_dtype))
    elif a_transposed and not b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.T.to(torch.float32),
                             B.to(torch.float32)).to(getattr(torch, out_dtype))
    elif not a_transposed and b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.to(torch.float32),
                             B.T.to(torch.float32)).to(getattr(torch, out_dtype))
    else:
        # Get Reference Result
        ref_c = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype))

    print(C)
    print(ref_c)
    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)


@tilelang.testing.requires_rocm
def test_assert_tl_matmul():
    assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", accum_dtype="int32")
    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32")
    assert_tl_matmul_correctness(
        128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32")
    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2)

    assert_tl_matmul_correctness(
        128, 128, 128, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
    assert_tl_matmul_correctness(
        128, 256, 256, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
    assert_tl_matmul_correctness(
        128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True)

    assert_tl_matmul_correctness(
        128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2, b_preshuffle=True)
    assert_tl_matmul_correctness(
        128,
        256,
        256,
        "int8",
        "int32",
        b_transposed=False,
        accum_dtype="int32",
        k_pack=2,
        b_preshuffle=True)


if __name__ == "__main__":
    tilelang.testing.main()