Skip to content

Commit 492a0c4

Browse files
committed
[AMD] Supoort T.gemm_v2 for AMD Backend
1 parent 6e1dc6a commit 492a0c4

File tree

5 files changed

+1135
-35
lines changed

5 files changed

+1135
-35
lines changed
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
import tilelang.language as T
2+
from typing import Literal, Callable
3+
from tvm.tir import IndexMap
4+
from tilelang.intrinsics.utils import get_mma_micro_size
5+
6+
from tilelang.intrinsics.mfma_layout import (
7+
shared_16x4_to_local_64x1_layout_A,
8+
shared_16x16_to_local_64x4_layout_A,
9+
shared_16x32_to_local_64x8_layout_A,
10+
shared_16x64_to_local_64x16_layout_A,
11+
)
12+
13+
14+
def make_mfma_load_base_layout(dtype: str = "float16",
15+
matrix: Literal["A", "B"] = "A",
16+
k_dim: int = 16,
17+
transposed: bool = False) -> T.Fragment:
18+
"""
19+
Create a layout function for storing MFMA results into a fragment buffer.
20+
This layout is used in conjunction with `inverse_mfma_store_layout` to
21+
map fragment indices to threads and local indices.
22+
23+
Parameters
24+
----------
25+
dtype : str
26+
The data type of the matrix.
27+
matrix : Literal["A", "B"]
28+
The mfma operand to be loaded.
29+
k_dim : int
30+
The k dimension of the mfma.
31+
transposed : bool
32+
Whether the matrix is transposed, by default False.
33+
34+
Returns
35+
-------
36+
T.Fragment
37+
Describes how threads and indices in fragment are laid out.
38+
39+
"""
40+
41+
assert matrix in ["A", "B"], "matrix should be either A or B"
42+
# s represents spatial axis
43+
# r represents reduction axis
44+
# sr represents the two dims are spatial + reduction
45+
# rs represents the two dims are reduction + spatial
46+
transform_func_sr_a: Callable = None
47+
transform_func_sr_b: Callable = None
48+
49+
if k_dim == 4:
50+
transform_func_sr_a = shared_16x4_to_local_64x1_layout_A
51+
transform_func_sr_b = shared_16x4_to_local_64x1_layout_A
52+
elif k_dim == 16:
53+
transform_func_sr_a = shared_16x16_to_local_64x4_layout_A
54+
transform_func_sr_b = shared_16x16_to_local_64x4_layout_A
55+
elif k_dim == 32:
56+
transform_func_sr_a = shared_16x32_to_local_64x8_layout_A
57+
transform_func_sr_b = shared_16x32_to_local_64x8_layout_A
58+
elif k_dim == 64:
59+
transform_func_sr_a = shared_16x64_to_local_64x16_layout_A
60+
transform_func_sr_b = shared_16x64_to_local_64x16_layout_A
61+
else:
62+
raise ValueError("k_dim must be 4 or 16 or 32 or 64 currently")
63+
64+
is_sr_conditions = [False]
65+
is_sr_conditions.append(matrix == "A" and not transposed)
66+
is_sr_conditions.append(matrix == "B" and transposed)
67+
is_sr_axis_order = any(is_sr_conditions)
68+
69+
micro_size_x, micro_size_y, micro_size_k = get_mma_micro_size(dtype)
70+
71+
# the layout of mma.sync is row.col.
72+
# so the b matrix expected a transposed basic layout
73+
transform_func: Callable = None
74+
if matrix == "A":
75+
transform_func = transform_func_sr_a if is_sr_axis_order else lambda i, j: transform_func_sr_a(
76+
j, i)
77+
micro_size_s, micro_size_r = micro_size_x, micro_size_k
78+
elif matrix == "B":
79+
transform_func = transform_func_sr_b if is_sr_axis_order else lambda i, j: transform_func_sr_b(
80+
j, i)
81+
micro_size_s, micro_size_r = micro_size_k, micro_size_y
82+
else:
83+
raise ValueError(f"Unsupported matrix {matrix}")
84+
85+
inverse_mma_load_layout = IndexMap.from_func(transform_func, index_dtype="int32")
86+
87+
def forward_thread(i: int, j: int) -> int:
88+
"""
89+
Given the row index `i` and column index `j` in the fragment,
90+
"""
91+
lane_id, _ = inverse_mma_load_layout.map_indices([i, j])
92+
return lane_id
93+
94+
def forward_index(i: int, j: int) -> int:
95+
"""
96+
Given the row index `i` and column index `j` in the fragment,
97+
"""
98+
_, local_id = inverse_mma_load_layout.map_indices([i, j])
99+
return local_id
100+
101+
base_fragment = T.Fragment(
102+
[micro_size_s, micro_size_r] if is_sr_axis_order else [micro_size_r, micro_size_s],
103+
forward_thread_fn=forward_thread,
104+
forward_index_fn=forward_index,
105+
)
106+
return base_fragment
107+
108+
109+
block_rows = 2
110+
block_cols = 2
111+
warp_rows = 2
112+
warp_cols = 2
113+
chunk = 2
114+
115+
from tilelang.tools import plot_layout
116+
117+
# ldmatrix layout 16x16
118+
base_layout = make_mfma_load_base_layout(dtype="float16", matrix="A", transposed=False)
119+
print(base_layout)
120+
plot_layout(base_layout, name="base_layout")
121+
122+
# warp layout 32x32
123+
warp_layout = base_layout.repeat([warp_rows, warp_cols],
124+
repeat_on_thread=False,
125+
lower_dim_first=False)
126+
print(warp_layout)
127+
plot_layout(warp_layout, name="warp_layout")
128+
129+
# block layout 64x32
130+
block_layout = warp_layout.repeat([block_rows, 1], repeat_on_thread=True,
131+
lower_dim_first=True).replicate(block_cols)
132+
print(block_layout)
133+
plot_layout(block_layout, name="block_layout")

0 commit comments

Comments
 (0)