
Commit 9fd6bb3

[AMD] support mfma i32_16x16x32_i8 (#800)
Co-authored-by: Jiaxing Ding <jiaxing.ding@bytedance.com>
1 parent 54aaec9 commit 9fd6bb3

File tree (4 files changed, +30 -9 lines changed):

src/target/codegen_hip.cc
src/tl_templates/hip/gemm.h
testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
tilelang/intrinsics/mfma_macro_generator.py

src/target/codegen_hip.cc

Lines changed: 8 additions & 7 deletions
@@ -880,7 +880,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
       os << "]" << ((i < 3) ? ", " : ")");
     }
   } else if (op->op.same_as(tl::tvm_mfma())) {
-    // arg 0: prefix: {otype}_16x16x16{itype}
+    // arg 0: prefix: {otype}_{intrM}x{intrN}x{intrK}_{itype}
     // arg 1: A layout: row/col
     // arg 2: B layout: row/col
     // arg 3: A precision: float16, float32, ...
@@ -914,6 +914,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
         {"int8", "char"},
         {"int32", "int"},
         {"int8x4", "int32_t"},
+        {"int8x8", "int64_t"},
         {"int32x4", "int32x4"},
         {"float16", "half"},
         {"float32", "float"},
@@ -925,17 +926,17 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
         {"float8_e4m3fnuzx8", "long"},
         {"float32x16", "float32x16"}};
     std::string call_mfma_code = R"({
-    *((({C_dytpe}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dytpe}*){a_ref}) + {a_bias}),
-                                                         *((({B_dytpe}*){b_ref}) + {b_bias}),
-                                                         *((({C_dytpe}*){c_ref}) + {c_bias}), 0, 0, 0);
+    *((({C_dtype}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dtype}*){a_ref}) + {a_bias}),
+                                                         *((({B_dtype}*){b_ref}) + {b_bias}),
+                                                         *((({C_dtype}*){c_ref}) + {c_bias}), 0, 0, 0);
     })";
     std::string mfma_buildin = "__builtin_amdgcn_mfma_" + prefix;
     Replacer replacer;

     replacer.register_rule("{mfma_buildin}", mfma_buildin);
-    replacer.register_rule("{A_dytpe}", dtype_map[A_dtype]);
-    replacer.register_rule("{B_dytpe}", dtype_map[B_dtype]);
-    replacer.register_rule("{C_dytpe}", dtype_map[C_dtype]);
+    replacer.register_rule("{A_dtype}", dtype_map[A_dtype]);
+    replacer.register_rule("{B_dtype}", dtype_map[B_dtype]);
+    replacer.register_rule("{C_dtype}", dtype_map[C_dtype]);
     replacer.register_rule("{a_ref}", a_ref);
     replacer.register_rule("{a_bias}", a_bias);
     replacer.register_rule("{b_ref}", b_ref);

src/tl_templates/hip/gemm.h

Lines changed: 12 additions & 0 deletions
@@ -8,6 +8,18 @@ namespace tl {
 // Trait to determine the MFMA instruction to use based on data type
 template <typename T> struct MfmaTraits;

+// Specialization for int8
+template <> struct MfmaTraits<int8_t> {
+  template <typename AccType>
+  static TL_DEVICE void mfma_op(const int8_t *b, const int8_t *a, AccType *c) {
+    int64_t *b_packed = reinterpret_cast<int64_t *>(const_cast<int8_t *>(b));
+    int64_t *a_packed = reinterpret_cast<int64_t *>(const_cast<int8_t *>(a));
+
+    *c = __builtin_amdgcn_mfma_i32_16x16x32_i8(*b_packed, *a_packed, *c, 0, 0,
+                                               0);
+  }
+};
+
 // Specialization for half/float16
 template <> struct MfmaTraits<half> {
   template <typename AccType>
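
A minimal usage sketch of the new specialization (assumes this header and its TL_DEVICE macro are available; the fragment names and the accumulator typedef below are illustrative, not part of the header):

    // Assumed 4 x int32 accumulator fragment type (clang ext_vector_type).
    typedef int int32x4_acc __attribute__((ext_vector_type(4)));

    TL_DEVICE void mfma_i8_example(const int8_t *a_frag, const int8_t *b_frag,
                                   int32x4_acc *c_frag) {
      // Each pointer covers 8 packed int8 values (one K=32 slice per lane).
      // Note the (b, a, c) argument order used by MfmaTraits::mfma_op.
      tl::MfmaTraits<int8_t>::mfma_op(b_frag, a_frag, c_frag);
    }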

testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py

Lines changed: 7 additions & 1 deletion
@@ -41,7 +41,9 @@ def tl_matmul(
     block_col_warps = 2
     warp_row_tiles = 32
     warp_col_tiles = 32
-    chunk = 32
+
+    chunk = 32 * k_pack
+
     shared_scope = "shared"
     cache_write_shared = False

@@ -193,6 +195,7 @@ def assert_tl_matmul_correctness(M,
     C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))

     kernel(A, B, C)
+    print(kernel.get_kernel_source())

     profiler = kernel.get_profiler()

@@ -227,6 +230,9 @@ def test_assert_tl_matmul():
     assert_tl_matmul_correctness(128, 128, 128, "float16", "float16")
     assert_tl_matmul_correctness(128, 256, 256, "float16", "float32")
     assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", k_pack=2)
+    assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", accum_dtype="int32")
+    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32")
+    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2)


 if __name__ == "__main__":

tilelang/intrinsics/mfma_macro_generator.py

Lines changed: 3 additions & 1 deletion
@@ -81,7 +81,7 @@ def __init__(

     def _initialize_k_dim(self, a_dtype="float16"):
         if isinstance(a_dtype, str):
-            if a_dtype in ["float8_e4m3fnuz"]:
+            if a_dtype in ["float8_e4m3fnuz", "int8"]:
                 self.k_dim = 32
                 return
             a_dtype = DataType(a_dtype)
@@ -123,6 +123,8 @@ def _initialize_mfma_prefix(self, k_dim=16):

         if in_dtype_abbrv == "fp8":
             self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}_fp8_fp8"
+        elif in_dtype_abbrv == "i8":
+            self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}_i8"
         else:
             self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}{in_dtype_abbrv}"
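
For int8 inputs this path should produce the suffix "i32_16x16x32_i8" (out_dtype_abbrv "i32", M_DIM = N_DIM = 16, and k_dim = 32 from _initialize_k_dim above), which matches the __builtin_amdgcn_mfma_i32_16x16x32_i8 builtin emitted in src/tl_templates/hip/gemm.h and src/target/codegen_hip.cc.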

0 commit comments
