Commit 1f46d89

[AMD] support preshuffle weight mfma
1 parent 409ab83 commit 1f46d89

File tree

2 files changed: +371 −19 lines changed

Lines changed: 321 additions & 0 deletions
@@ -0,0 +1,321 @@
import torch
import tilelang.testing
from tilelang import tvm as tvm
import tilelang.language as T
from tilelang.intrinsics import make_mfma_swizzle_layout as make_swizzle_layout
from tilelang.intrinsics.mfma_macro_generator import (
    MatrixCoreIntrinEmitter,)
from tilelang.transform import simplify_prim_func

tilelang.testing.set_random_seed(0)


@simplify_prim_func
def tl_matmul(
    M,
    N,
    K,
    in_dtype,
    out_dtype,
    accum_dtype,
    a_transposed=False,
    b_transposed=True,
    k_pack=1,
    b_preshuffle=False,
):
    assert in_dtype in [
        "float16",
        "int8",
    ], "Currently only float16 and int8 are supported"
    assert out_dtype in [
        "float16",
        "float32",
        "int32",
    ], "Currently only float16, float32 and int32 are supported"

    micro_size_x = micro_size_y = micro_size_k = 16

    if in_dtype in {"float8_e4m3fnuz", "int8"}:
        micro_size_k = 32

    block_row_warps = 2
    block_col_warps = 2
    warp_row_tiles = 32
    warp_col_tiles = 32

    # for preshuffle_b, warp_layout = {1, 4}
    if b_preshuffle:
        block_row_warps = 1
        block_col_warps = 4
        warp_row_tiles = 128
        warp_col_tiles = 32
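
    # Resulting tile sizes (derived from the constants above): the preshuffle path uses
    # block_M = 1 * 128 = 128 and block_N = 4 * 32 = 128 with 4 warps (256 threads);
    # the default path uses block_M = block_N = 2 * 32 = 64, also with 4 warps.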

    chunk = 32 * k_pack

    pack_size_k = micro_size_k * k_pack

    shared_scope = "shared"
    cache_write_shared = False

    block_M = block_row_warps * warp_row_tiles
    block_N = block_col_warps * warp_col_tiles
    block_K = chunk

    A_shape = (K, M) if a_transposed else (M, K)
    if b_preshuffle:
        B_shape = (N // micro_size_y, K // pack_size_k, micro_size_y,
                   pack_size_k) if b_transposed else (K // pack_size_k, N // micro_size_y,
                                                      pack_size_k, micro_size_y)
    else:
        B_shape = (N, K) if b_transposed else (K, N)
    A_shared_shape = (block_K, block_M) if a_transposed else (block_M, block_K)
    if b_preshuffle:
        B_shared_shape = (block_N // micro_size_y, block_K // pack_size_k, micro_size_y,
                          pack_size_k) if b_transposed else (block_K // pack_size_k,
                                                             block_N // micro_size_y, pack_size_k,
                                                             micro_size_y)
    else:
        B_shared_shape = (block_N, block_K) if b_transposed else (block_K, block_N)
    C_shared_shape = (
        block_M // micro_size_x,
        block_N // micro_size_y,
        micro_size_x,
        micro_size_y,
    )
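
    # Preshuffled B layout, for example with int8 and k_pack = 1: pack_size_k = 32, so a
    # transposed B of logical shape (N, K) is stored as (N // 16, K // 32, 16, 32); each
    # trailing (16, 32) block is one contiguous MFMA operand tile, and B_shared mirrors
    # the same blocking at (block_N, block_K) granularity.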

    warp_size = 64
    threads = warp_size * (block_row_warps * block_col_warps)
    local_size_a = (k_pack * micro_size_x * micro_size_k) // warp_size
    local_size_b = (k_pack * micro_size_y * micro_size_k) // warp_size
    local_size_c = (micro_size_x * micro_size_y) // warp_size
    warp_rows = warp_row_tiles // micro_size_x
    warp_cols = warp_col_tiles // micro_size_y
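
    # Per-lane register footprint for the int8 path with k_pack = 1:
    # local_size_a = local_size_b = (1 * 16 * 32) // 64 = 8 elements and
    # local_size_c = (16 * 16) // 64 = 4 accumulators per lane.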

    # MMA Wrapper to Auto Generate Code for MMA
    mfma_emitter = MatrixCoreIntrinEmitter(
        a_dtype=in_dtype,
        b_dtype=in_dtype,
        accum_dtype=accum_dtype,
        a_transposed=a_transposed,
        b_transposed=b_transposed,
        block_row_warps=block_row_warps,
        block_col_warps=block_col_warps,
        warp_row_tiles=warp_row_tiles,
        warp_col_tiles=warp_col_tiles,
        chunk=chunk,
        k_pack=k_pack,
        b_preshuffle=b_preshuffle,
    )

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, in_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):

            A_shared = T.alloc_shared(A_shared_shape, in_dtype, scope=shared_scope)
            B_shared = T.alloc_shared(B_shared_shape, in_dtype, scope=shared_scope)
            C_shared = T.alloc_shared(C_shared_shape, out_dtype, scope=shared_scope)
            A_local = T.alloc_local((warp_rows * local_size_a), in_dtype)
            B_local = T.alloc_local((warp_cols * local_size_b), in_dtype)
            C_local = T.alloc_local((warp_rows * warp_cols * local_size_c), accum_dtype)

            T.annotate_layout({
                A_shared: make_swizzle_layout(A_shared),
            })

            # Improve L2 Cache
            T.use_swizzle(panel_size=10)

            T.clear(C_local)

            for ko in T.Pipelined((K // block_K), num_stages=0):

                # Load A into shared memory
                if a_transposed:
                    T.copy(A[ko * block_K, by * block_M], A_shared)
                else:
                    T.copy(A[by * block_M, ko * block_K], A_shared)

                # Load B into shared memory
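                # The preshuffled copy below indexes B at tile granularity: each trailing
                # (micro_size_y, pack_size_k) (or (pack_size_k, micro_size_y)) block is
                # contiguous in global memory after shuffle_weight, so it is moved into
                # shared memory in its original element order.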
                if b_preshuffle:
                    if b_transposed:
                        for j, k, jj, kk in T.Parallel(block_N // micro_size_y,
                                                       block_K // pack_size_k, micro_size_y,
                                                       pack_size_k):
                            B_shared[j, k, jj, kk] = B[bx * block_N // micro_size_y + j,
                                                       ko * block_K // pack_size_k + k, jj, kk]
                    else:
                        for k, j, kk, jj in T.Parallel(block_K // pack_size_k,
                                                       block_N // micro_size_y, pack_size_k,
                                                       micro_size_y):
                            B_shared[k, j, kk, jj] = B[ko * block_K // pack_size_k + k,
                                                       bx * block_N // micro_size_y + j, kk, jj]
                else:
                    if b_transposed:
                        T.copy(B[bx * block_N, ko * block_K], B_shared)
                    else:
                        T.copy(B[ko * block_K, bx * block_N], B_shared)

                for ki in T.serial(0, (block_K // (k_pack * micro_size_k))):

                    # Load A into fragment
                    mfma_emitter.ldmatrix_a(
                        A_local,
                        A_shared,
                        ki,
                    )

                    # Load B into fragment
                    mfma_emitter.ldmatrix_b(
                        B_local,
                        B_shared,
                        ki,
                    )

                    # Perform Matrix Multiplication
                    mfma_emitter.mfma(A_local, B_local, C_local)

            # Perform STMatrix
            if cache_write_shared:
                mfma_emitter.stmatrix(
                    C_local,
                    C_shared,
                )

                # Store shared into global
                for i, j in T.Parallel(block_M, block_N):
                    C[by * block_M + i, bx * block_N + j] = C_shared[
                        i // micro_size_x,
                        j // micro_size_y,
                        i % micro_size_x,
                        j % micro_size_y,
                    ]
            else:
                mfma_emitter.stmatrix(
                    C_local,
                    C,
                    pid_m=by,
                    pid_n=bx,
                )

    return main

def shuffle_weight(
    x: torch.Tensor,
    layout=(16, 32),
    k_pack=1,
    is_transpose=False,
) -> torch.Tensor:
    IN, IK = layout
    BK = IK * k_pack
    BN = IN

    N, K = (x.shape[-2], x.shape[-1]) if is_transpose else (x.shape[-1], x.shape[-2])
    assert N % BN == 0
    assert K % BK == 0

    x = x.view(N // BN, BN, K // BK, BK) if is_transpose else x.view(K // BK, BK, N // BN, BN)
    x = x.permute(0, 2, 1, 3)
    return x.contiguous()
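

# Example of the resulting layout (derived from the reshape above): for a transposed
# int8 weight of shape (N, K) = (128, 256) with k_pack = 1, shuffle_weight views it as
# (8, 16, 8, 32), permutes to (8, 8, 16, 32), and returns a contiguous tensor whose
# trailing (16, 32) blocks line up with the kernel's preshuffled B_shape tiles.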


def assert_tl_matmul_correctness(M,
                                 N,
                                 K,
                                 in_dtype,
                                 out_dtype,
                                 accum_dtype="float32",
                                 a_transposed=False,
                                 b_transposed=True,
                                 k_pack=1,
                                 b_preshuffle=False):
    matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed,
                       k_pack, b_preshuffle)
    print(matmul)
    kernel = tilelang.compile(matmul)
    src_code = kernel.get_kernel_source()
    # src_code is the generated kernel source
    assert src_code is not None
    A_shape = (K, M) if a_transposed else (M, K)
    B_shape = (N, K) if b_transposed else (K, N)
    if in_dtype == "int8":
        A = torch.randint(-128, 127, A_shape, device="cuda", dtype=torch.int8)
        B = torch.randint(-128, 127, B_shape, device="cuda", dtype=torch.int8)
    else:
        A = torch.rand(A_shape, device="cuda", dtype=getattr(torch, in_dtype))
        B = torch.rand(B_shape, device="cuda", dtype=getattr(torch, in_dtype))

    C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))

    B_preshuffle = B
    if b_preshuffle:
        B_preshuffle = shuffle_weight(B_preshuffle, k_pack=k_pack, is_transpose=b_transposed)
        kernel(A, B_preshuffle, C)
    else:
        kernel(A, B, C)
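
    # Note: only the kernel input is preshuffled; the reference result below is always
    # computed from the original, un-shuffled B.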

    print(kernel.get_kernel_source())

    profiler = kernel.get_profiler()

    latency = profiler.do_bench()

    # Ensure that the latency is not None
    assert latency is not None

    if a_transposed and b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.T.to(torch.float32),
                             B.T.to(torch.float32)).to(getattr(torch, out_dtype))
    elif a_transposed and not b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.T.to(torch.float32),
                             B.to(torch.float32)).to(getattr(torch, out_dtype))
    elif not a_transposed and b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.to(torch.float32),
                             B.T.to(torch.float32)).to(getattr(torch, out_dtype))
    else:
        # Get Reference Result
        ref_c = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype))

    print(C)
    print(ref_c)
    torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)


@tilelang.testing.requires_rocm
def test_assert_tl_matmul():
    assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", accum_dtype="int32")
    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32")
    assert_tl_matmul_correctness(
        128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32")
    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2)

    assert_tl_matmul_correctness(
        128, 128, 128, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
    assert_tl_matmul_correctness(
        128, 256, 256, "int8", "int32", accum_dtype="int32", b_preshuffle=True)
    assert_tl_matmul_correctness(
        128, 256, 256, "int8", "int32", b_transposed=False, accum_dtype="int32", b_preshuffle=True)

    assert_tl_matmul_correctness(
        128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2, b_preshuffle=True)
    assert_tl_matmul_correctness(
        128,
        256,
        256,
        "int8",
        "int32",
        b_transposed=False,
        accum_dtype="int32",
        k_pack=2,
        b_preshuffle=True)


if __name__ == "__main__":
    tilelang.testing.main()
