[AMD] support mfma i32_16x16x32_i8 #800
Changes from all commits
@@ -8,6 +8,18 @@ namespace tl {
 // Trait to determine the MFMA instruction to use based on data type
 template <typename T> struct MfmaTraits;
 
+// Specialization for int8
+template <> struct MfmaTraits<int8_t> {
+  template <typename AccType>
+  static TL_DEVICE void mfma_op(const int8_t *b, const int8_t *a, AccType *c) {
+    int64_t *b_packed = reinterpret_cast<int64_t *>(const_cast<int8_t *>(b));
+    int64_t *a_packed = reinterpret_cast<int64_t *>(const_cast<int8_t *>(a));
Comment on lines +15 to +16:

Potential undefined behavior: const_cast removes the const qualifier inappropriately. Since the data behind a and b is only read, the casts can go directly to const pointers. Apply this diff to fix the const-correctness issue:

-    int64_t *b_packed = reinterpret_cast<int64_t *>(const_cast<int8_t *>(b));
-    int64_t *a_packed = reinterpret_cast<int64_t *>(const_cast<int8_t *>(a));
+    const int64_t *b_packed = reinterpret_cast<const int64_t *>(b);
+    const int64_t *a_packed = reinterpret_cast<const int64_t *>(a);
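Note (illustration only, not part of the PR): a minimal host-side sketch of two ways to view eight int8 lanes as the 64-bit MFMA operand. `pack_i8x8` is a hypothetical helper; `std::memcpy` avoids the cast and keeps const-correctness, while the reviewer's `reinterpret_cast<const int64_t *>` form keeps the pointer const but still assumes the bytes are suitably aligned, as the device-side fragments are expected to be.

```cpp
#include <cstdint>
#include <cstdio>
#include <cstring>

// Hypothetical helper (illustration only): view eight contiguous int8 lanes
// as the single 64-bit operand shape the 16x16x32 i8 MFMA expects.
static int64_t pack_i8x8(const int8_t *lanes) {
  int64_t packed = 0;
  std::memcpy(&packed, lanes, sizeof(packed)); // well-defined, const preserved
  return packed;
}

int main() {
  alignas(int64_t) const int8_t lanes[8] = {1, -2, 3, -4, 5, -6, 7, -8};

  // memcpy-based packing: no casts at all.
  const int64_t via_memcpy = pack_i8x8(lanes);

  // Reviewer's suggested form: cast to a const int64_t*, no const_cast needed
  // (technically type-punning; shown for comparison with the device idiom).
  const int64_t *via_cast = reinterpret_cast<const int64_t *>(lanes);

  std::printf("memcpy: %016llx\ncast:   %016llx\n",
              static_cast<unsigned long long>(via_memcpy),
              static_cast<unsigned long long>(*via_cast));
  return 0;
}
```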
+
+    *c = __builtin_amdgcn_mfma_i32_16x16x32_i8(*b_packed, *a_packed, *c, 0, 0,
+                                               0);
+  }
+};
+
 // Specialization for half/float16
 template <> struct MfmaTraits<half> {
   template <typename AccType>
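Note (background, not part of the PR): the builtin's name encodes the tile shape — i32_16x16x32_i8 is a 16x16x32 matrix multiply-accumulate with int8 inputs and int32 accumulation. Below is a host-side reference of the arithmetic it performs, for orientation only; the real instruction distributes these elements across the lanes of a wavefront and is not called like this.

```cpp
#include <cstdint>
#include <cstdio>

// Reference semantics of one i32_16x16x32_i8 MFMA: C (int32, 16x16) +=
// A (int8, 16x32) x B (int8, 32x16). This loop only models the math, not
// the per-lane register layout the hardware uses.
void mfma_i32_16x16x32_i8_ref(const int8_t *A, const int8_t *B, int32_t *C) {
  for (int m = 0; m < 16; ++m)
    for (int n = 0; n < 16; ++n)
      for (int k = 0; k < 32; ++k)
        C[m * 16 + n] += int32_t(A[m * 32 + k]) * int32_t(B[k * 16 + n]);
}

int main() {
  int8_t A[16 * 32], B[32 * 16];
  int32_t C[16 * 16] = {};
  for (int i = 0; i < 16 * 32; ++i) A[i] = int8_t(i % 7 - 3);
  for (int i = 0; i < 32 * 16; ++i) B[i] = int8_t(i % 5 - 2);
  mfma_i32_16x16x32_i8_ref(A, B, C);
  std::printf("C[0][0] = %d\n", C[0]);
  return 0;
}
```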
@@ -41,7 +41,9 @@ def tl_matmul(
     block_col_warps = 2
     warp_row_tiles = 32
     warp_col_tiles = 32
-    chunk = 32
+
+    chunk = 32 * k_pack
+
     shared_scope = "shared"
     cache_write_shared = False
 
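Note (editorial, not from the PR discussion): chunk appears to be the K extent of the shared-memory tile. Each i32_16x16x32_i8 MFMA consumes K = 32 elements per issue, so packing k_pack operands per fragment presumably requires the tile to cover 32 * k_pack elements along K (64 when k_pack=2), which is what this change encodes.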
@@ -193,6 +195,7 @@ def assert_tl_matmul_correctness(M,
     C = torch.zeros(M, N, device="cuda", dtype=getattr(torch, out_dtype))
 
     kernel(A, B, C)
+    print(kernel.get_kernel_source())
 
     profiler = kernel.get_profiler()
 
@@ -227,6 +230,9 @@ def test_assert_tl_matmul():
     assert_tl_matmul_correctness(128, 128, 128, "float16", "float16")
     assert_tl_matmul_correctness(128, 256, 256, "float16", "float32")
     assert_tl_matmul_correctness(128, 256, 256, "float16", "float32", k_pack=2)
+    assert_tl_matmul_correctness(128, 128, 128, "int8", "int32", accum_dtype="int32")
+    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32")
+    assert_tl_matmul_correctness(128, 256, 256, "int8", "int32", accum_dtype="int32", k_pack=2)
 
 
 if __name__ == "__main__":
A second review comment on the same const_cast usage:

The use of const_cast here is unnecessary and not considered best practice. Since the data pointed to by a and b is not modified, you can directly reinterpret_cast to a const pointer type. This preserves const correctness and improves code safety.