15 changes: 8 additions & 7 deletions src/target/codegen_hip.cc
@@ -880,7 +880,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
os << "]" << ((i < 3) ? ", " : ")");
}
} else if (op->op.same_as(tl::tvm_mfma())) {
-// arg 0: prefix: {otype}_16x16x16{itype}
+// arg 0: prefix: {otype}_{intrM}x{intrN}x{intrK}_{itype}
// arg 1: A layout: row/col
// arg 2: B layout: row/col
// arg 3: A precision: float16, float32, ...
@@ -914,6 +914,7 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
{"int8", "char"},
{"int32", "int"},
{"int8x4", "int32_t"},
{"int8x8", "int64_t"},
{"int32x4", "int32x4"},
{"float16", "half"},
{"float32", "float"},
@@ -925,17 +926,17 @@ void CodeGenTileLangHIP::VisitExpr_(const CallNode *op, std::ostream &os) {
{"float8_e4m3fnuzx8", "long"},
{"float32x16", "float32x16"}};
std::string call_mfma_code = R"({
-*((({C_dytpe}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dytpe}*){a_ref}) + {a_bias}),
-*((({B_dytpe}*){b_ref}) + {b_bias}),
-*((({C_dytpe}*){c_ref}) + {c_bias}), 0, 0, 0);
+*((({C_dtype}*){c_ref}) + {c_bias}) = {mfma_buildin}(*((({A_dtype}*){a_ref}) + {a_bias}),
+*((({B_dtype}*){b_ref}) + {b_bias}),
+*((({C_dtype}*){c_ref}) + {c_bias}), 0, 0, 0);
})";
std::string mfma_buildin = "__builtin_amdgcn_mfma_" + prefix;
Replacer replacer;

replacer.register_rule("{mfma_buildin}", mfma_buildin);
replacer.register_rule("{A_dytpe}", dtype_map[A_dtype]);
replacer.register_rule("{B_dytpe}", dtype_map[B_dtype]);
replacer.register_rule("{C_dytpe}", dtype_map[C_dtype]);
replacer.register_rule("{A_dtype}", dtype_map[A_dtype]);
replacer.register_rule("{B_dtype}", dtype_map[B_dtype]);
replacer.register_rule("{C_dtype}", dtype_map[C_dtype]);
replacer.register_rule("{a_ref}", a_ref);
replacer.register_rule("{a_bias}", a_bias);
replacer.register_rule("{b_ref}", b_ref);
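For context on the template above: the Replacer substitutes the dtype placeholders with their dtype_map entries, {mfma_buildin} with "__builtin_amdgcn_mfma_" plus the prefix produced by mfma_macro_generator.py, and the ref/bias placeholders with the operand expressions computed by the codegen. Below is a rough, hand-written sketch of what the expansion could look like for the new int8 path (prefix i32_16x16x32_i8, int8x8 operands, int32x4 accumulator); the buffer names A_local, B_local, C_local and the zero offsets are placeholders for this example, not output of the real codegen.

// Illustrative expansion only: {A_dtype}/{B_dtype} map to int64_t (the
// "int8x8" entry) and {C_dtype} maps to int32x4, so each lane feeds eight
// packed int8 values per operand into one MFMA.
{
  *(((int32x4*)C_local) + 0) = __builtin_amdgcn_mfma_i32_16x16x32_i8(
      *(((int64_t*)A_local) + 0),
      *(((int64_t*)B_local) + 0),
      *(((int32x4*)C_local) + 0), 0, 0, 0);
}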
12 changes: 12 additions & 0 deletions src/tl_templates/hip/gemm.h
@@ -8,6 +8,18 @@ namespace tl {
// Trait to determine the MFMA instruction to use based on data type
template <typename T> struct MfmaTraits;

+// Specialization for int8
+template <> struct MfmaTraits<int8_t> {
+template <typename AccType>
+static TL_DEVICE void mfma_op(const int8_t *b, const int8_t *a, AccType *c) {
+int64_t *b_packed = reinterpret_cast<int64_t *>(const_cast<int8_t *>(b));
+int64_t *a_packed = reinterpret_cast<int64_t *>(const_cast<int8_t *>(a));
Comment on lines +15 to +16 (Contributor, severity: medium):

The use of const_cast here is unnecessary and not considered best practice. Since the data pointed to by a and b is not modified, you can directly reinterpret_cast to a const pointer type. This preserves const correctness and improves code safety.

    const int64_t *b_packed = reinterpret_cast<const int64_t *>(b);
    const int64_t *a_packed = reinterpret_cast<const int64_t *>(a);

Comment on lines +15 to +16 (Contributor, ⚠️ Potential issue):

Potential undefined behavior: const_cast removes const qualifier inappropriately.

Using const_cast to remove the const qualifier from const int8_t* parameters and then modifying the memory through reinterpret_cast could lead to undefined behavior if the underlying memory is actually const. The pointers should be cast directly without removing const.

Apply this diff to fix the const-correctness issue:

-    int64_t *b_packed = reinterpret_cast<int64_t *>(const_cast<int8_t *>(b));
-    int64_t *a_packed = reinterpret_cast<int64_t *>(const_cast<int8_t *>(a));
+    const int64_t *b_packed = reinterpret_cast<const int64_t *>(b);
+    const int64_t *a_packed = reinterpret_cast<const int64_t *>(a);


+*c = __builtin_amdgcn_mfma_i32_16x16x32_i8(*b_packed, *a_packed, *c, 0, 0,
+0);
+}
+};

// Specialization for half/float16
template <> struct MfmaTraits<half> {
template <typename AccType>
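As a usage note, here is a minimal sketch of how a warp-level loop could dispatch through this trait; the helper name, template parameters, and fragment layout are assumptions for illustration, not the actual tl::gemm loop. For int8_t the per-slice operand stride would be 8 elements, because mfma_op reads each operand as one packed int64_t.

// Hypothetical helper (not part of this header): one MFMA per K slice, with
// MfmaTraits<T> selecting the builtin; for int8_t that is
// __builtin_amdgcn_mfma_i32_16x16x32_i8 via the specialization above.
template <typename T, typename AccType, int KSlices, int ElemsPerSlice>
TL_DEVICE void mma_dispatch(const T *a_frag, const T *b_frag, AccType *c_frag) {
  for (int k = 0; k < KSlices; ++k) {
    // Note the (b, a, c) argument order expected by mfma_op.
    MfmaTraits<T>::mfma_op(b_frag + k * ElemsPerSlice,
                           a_frag + k * ElemsPerSlice, c_frag);
  }
}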
8 changes: 7 additions & 1 deletion testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py
@@ -41,7 +41,9 @@ def tl_matmul(
block_col_warps = 2
warp_row_tiles = 32
warp_col_tiles = 32
-chunk = 32
+
+chunk = 32 * k_pack
+
shared_scope = "shared"
cache_write_shared = False

@@ -193,6 +195,7 @@ def assert_tl_matmul_correctness(M,
C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))

kernel(A, B, C)
+print(kernel.get_kernel_source())
Comment on the added print call (Contributor, severity: medium):

This print statement appears to be for debugging. It should be removed before merging to keep the test output clean.


profiler = kernel.get_profiler()

@@ -227,6 +230,9 @@ def test_assert_tl_matmul():
assert_tl_matmul_correctness(128, 128, 128, "float16", "float16")
assert_tl_matmul_correctness(128, 256, 256, "float16", "float32")
assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", k_pack=2)
assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", accum_dtype="int32")
assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32")
assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2)


if __name__ == "__main__":
4 changes: 3 additions & 1 deletion tilelang/intrinsics/mfma_macro_generator.py
@@ -81,7 +81,7 @@ def __init__(

def _initialize_k_dim(self, a_dtype="float16"):
if isinstance(a_dtype, str):
if a_dtype in ["float8_e4m3fnuz"]:
if a_dtype in ["float8_e4m3fnuz", "int8"]:
self.k_dim = 32
return
a_dtype = DataType(a_dtype)
@@ -123,6 +123,8 @@ def _initialize_mfma_prefix(self, k_dim=16):

if in_dtype_abbrv == "fp8":
self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}_fp8_fp8"
+elif in_dtype_abbrv == "i8":
+self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}_i8"
else:
self.mfma_suffix = f"{out_dtype_abbrv}_{M_DIM}x{N_DIM}x{k_dim}{in_dtype_abbrv}"
