[AMD] support preshuffle weight mfma #806
Conversation
Walkthrough
Adds a new ROCm test module to validate an AMD MFMA GEMM path supporting B-preshuffle layouts and integrates a b_preshuffle code path into MatrixCoreIntrinEmitter, modifying its initialization and B-tile load logic.
Changes
Sequence Diagram(s)
sequenceDiagram
autonumber
participant Test as PyTest (ROCm)
participant TL as tl_matmul Builder
participant Emitter as MatrixCoreIntrinEmitter
participant Kernel as GEMM Kernel
participant Torch as PyTorch Ref
Test->>TL: build tl_matmul(M,N,K, dtypes, flags: b_preshuffle, k_pack, transposes)
TL->>Emitter: __init__(..., b_preshuffle)
Emitter-->>TL: configured emitter
Note over Emitter,Kernel: Emitter selects B load path<br/>(4D preshuffle vs 2D standard)
Test->>Test: prepare tensors (optional shuffle_weight on B)
Test->>Kernel: run kernel(A, B[, preshuffled], C)
Kernel->>Kernel: load A/B tiles (B uses preshuffle path if enabled)
Kernel->>Kernel: MFMA accumulate and store C
Test->>Torch: compute reference gemm (with transposes)
Test->>Test: compare outputs, benchmark
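
For readers who prefer code to diagrams, here is a condensed sketch of one iteration of this flow. It is only a sketch: the helper names tl_matmul and shuffle_weight and the call pattern are taken from the test code quoted later in this review, while the shapes, dtypes, and the import path of the helpers are assumptions, not values from the PR.

```python
# Hedged sketch of the flow above; shapes, dtypes, and the helper import path are assumptions.
import torch
import tilelang
# Assumed import path: the helpers live in the new ROCm test module added by this PR.
from test_tilelang_gemm_mfma_preshuffle import tl_matmul, shuffle_weight

M = N = K = 256
in_dtype, out_dtype, accum_dtype = "int8", "int32", "int32"
a_transposed, b_transposed = False, True
k_pack, b_preshuffle = 1, True

# Build and compile the TileLang GEMM with the B-preshuffle path enabled.
matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype,
                   a_transposed, b_transposed, k_pack, b_preshuffle)
kernel = tilelang.compile(matmul)

A = torch.randint(-8, 8, (M, K), dtype=torch.int8, device="cuda")  # ROCm exposes the "cuda" device
B = torch.randint(-8, 8, (N, K), dtype=torch.int8, device="cuda")  # b_transposed layout: (N, K)
C = torch.zeros((M, N), dtype=torch.int32, device="cuda")

# Reorder B into the 4D tile layout the kernel expects when b_preshuffle is set.
B_preshuffle = shuffle_weight(B, layout=(16, 32), k_pack=k_pack, is_transpose=b_transposed)
kernel(A, B_preshuffle, C)

# Reference GEMM on the original, unshuffled B.
ref_c = torch.matmul(A.to(torch.float32), B.T.to(torch.float32)).to(torch.int32)
torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
```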
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Possibly related PRs
Suggested reviewers
Pre-merge checks (2 passed, 1 warning)
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Poem
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the required checks. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Summary of Changes
Hello @Paran0idy, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request introduces support for preshuffled weights within AMD's Matrix Fused Multiply-Add (MFMA) operations. This enhancement allows for optimized handling of specific weight layouts, improving performance for relevant matrix multiplication workloads.
Highlights
- Preshuffle Weight Support: Implemented the capability to handle preshuffled weights for AMD MFMA operations.
- MFMA Emitter Update: The MatrixCoreIntrinEmitter now includes a b_preshuffle parameter to control the new behavior.
- Optimized B Tensor Loading: Modified the _warp_ldmatrix_b function to correctly load 4-dimensional B tensors when preshuffling is enabled.
- New Test Coverage: A new test file test_tilelang_gemm_mfma_preshuffle.py has been added to thoroughly validate the correctness and functionality of the preshuffle weight feature.
Code Review
This pull request introduces support for pre-shuffled weights for AMD MFMA, a valuable performance optimization. The changes include a new test file for validation and updates to the MFMA macro generator. My review has identified a critical bug within the new test code, specifically a typo that affects correctness checking for transposed matrices. I have also provided suggestions to improve code maintainability by refactoring duplicated logic in the macro generator and by decoupling a helper function from hardcoded kernel parameters. Addressing these points will enhance the quality and robustness of the implementation.
if a_transposed and b_transposed:
    # Get Reference Result
    ref_c = torch.matmul(A.T.to(torch.float32),
                         B.T.to(torch.float32)).to(getattr(torch, out_dtype))
elif a_transposed and not b_transposed:
    # Get Reference Result
    ref_c = torch.matmul(A.Tto(torch.float32),
                         B.to(torch.float32)).to(getattr(torch, out_dtype))
elif not a_transposed and b_transposed:
    # Get Reference Result
    ref_c = torch.matmul(A.to(torch.float32),
                         B.T.to(torch.float32)).to(getattr(torch, out_dtype))
else:
    # Get Reference Result
    ref_c = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype))
There's a typo on line 276: A.Tto(torch.float32) should be A.T.to(torch.float32).
Although this code path is not exercised by the current tests (no test case sets a_transposed=True), this is a critical bug that should be fixed.
Additionally, this entire conditional block can be simplified to improve readability and maintainability, which would also help prevent such typos.
A_ref = A.T if a_transposed else A
B_ref = B.T if b_transposed else B
ref_c = torch.matmul(A_ref.to(torch.float32), B_ref.to(torch.float32)).to(getattr(torch, out_dtype))

def shuffle_weight(
    x: torch.Tensor,
    layout=(16, 32),
    k_pack=1,
    is_transpose=False,
) -> torch.Tensor:
    IN, IK = layout
The shuffle_weight function uses a hardcoded default layout=(16, 32). These values are tightly coupled with micro_size_y and micro_size_k from the tl_matmul kernel definition. This makes the function less flexible and could lead to errors if the kernel parameters change.
To improve modularity and reduce this coupling, consider passing micro_size_n and micro_size_k directly to the function instead of the layout tuple.
You would then update the call site in assert_tl_matmul_correctness like this:
if b_preshuffle:
    micro_size_k = 32 if in_dtype == "int8" else 16
    micro_size_y = 16
    B_preshuffle = shuffle_weight(
        B_preshuffle,
        micro_size_n=micro_size_y,
        micro_size_k=micro_size_k,
        k_pack=k_pack,
        is_transpose=b_transposed
    )
kernel(A, B_preshuffle, C)

-def shuffle_weight(
-    x: torch.Tensor,
-    layout=(16, 32),
-    k_pack=1,
-    is_transpose=False,
-) -> torch.Tensor:
-    IN, IK = layout
+def shuffle_weight(
+    x: torch.Tensor,
+    micro_size_n: int,
+    micro_size_k: int,
+    k_pack: int = 1,
+    is_transpose: bool = False,
+) -> torch.Tensor:
+    IN, IK = micro_size_n, micro_size_k
if self.b_preshuffle:
    if is_transposed:
        for j in T.serial(warp_cols):
            for local_id in T.vectorized(k_pack * local_size_b):
                row, col = T.meta_var(reverse_index_map(tx, local_id))
                l, r = (
                    warp_n * warp_cols + j,
                    rk * (chunk // micro_size_k) + ki,
                )
                B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r,
                                                                                 row, col]
    else:
        for j in T.serial(warp_cols):
            for local_id in T.vectorized(k_pack * local_size_b):
                row, col = T.meta_var(reverse_index_map(tx, local_id))
                l, r = (
                    rk * (chunk // micro_size_k) + ki,
                    warp_n * warp_cols + j,
                )
                B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r,
                                                                                 row, col]
else:
    if is_transposed:
        for j in T.serial(warp_cols):
            for local_id in T.vectorized(k_pack * local_size_b):
                row, col = T.meta_var(reverse_index_map(tx, local_id))
                l, r = (
                    warp_n * warp_col_tiles + j * micro_size_y,
                    rk * chunk + ki * (k_pack * micro_size_k),
                )
                B_local_buf[j * k_pack * local_size_b +
                            local_id] = B_shared_buf[l + row, r + col]
    else:
        for j in T.serial(warp_cols):
            for local_id in T.vectorized(k_pack * local_size_b):
                row, col = T.meta_var(reverse_index_map(tx, local_id))
                l, r = (
                    rk * chunk + ki * (k_pack * micro_size_k),
                    warp_n * warp_col_tiles + j * micro_size_y,
                )
                B_local_buf[j * k_pack * local_size_b +
                            local_id] = B_shared_buf[l + row, r + col]
The logic in _warp_ldmatrix_b is split into four branches based on self.b_preshuffle and is_transposed, leading to significant code duplication. The loop structure, reverse_index_map call, and buffer assignment are repeated in each branch.
This duplication makes the code harder to read and maintain. Please consider refactoring to consolidate the common logic and isolate the parts that truly differ. For example, you could move the conditional logic for index calculation inside the loops.
Here's a possible refactoring approach:
@T.macro
def _warp_ldmatrix_b(...):
    tx, warp_n, _ = self.extract_thread_binding(thread_binding)
    for j in T.serial(warp_cols):
        for local_id in T.vectorized(k_pack * local_size_b):
            row, col = T.meta_var(reverse_index_map(tx, local_id))
            if self.b_preshuffle:
                if is_transposed:
                    l, r = warp_n * warp_cols + j, rk * (chunk // micro_size_k) + ki
                else:
                    l, r = rk * (chunk // micro_size_k) + ki, warp_n * warp_cols + j
                B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row, col]
            else:
                if is_transposed:
                    l, r = warp_n * warp_col_tiles + j * micro_size_y, rk * chunk + ki * (k_pack * micro_size_k)
                else:
                    l, r = rk * chunk + ki * (k_pack * micro_size_k), warp_n * warp_col_tiles + j * micro_size_y
                B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row, r + col]
Actionable comments posted: 2
🧹 Nitpick comments (4)
tilelang/intrinsics/mfma_macro_generator.py (2)
297-320: Consider renaming ambiguous variables for better readability.

While `l` and `r` are functional variable names, they can be visually ambiguous (especially `l`, which can look like `1` in some fonts). Consider using more descriptive names like `row_idx` and `col_idx` or `dim1` and `dim2` to improve code readability.

Apply this diff to improve variable naming:

 # 4 dim
 if self.b_preshuffle:
     if is_transposed:
         for j in T.serial(warp_cols):
             for local_id in T.vectorized(k_pack * local_size_b):
                 row, col = T.meta_var(reverse_index_map(tx, local_id))
-                l, r = (
+                row_idx, col_idx = (
                     warp_n * warp_cols + j,
                     rk * (chunk // micro_size_k) + ki,
                 )
-                B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r,
+                B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[row_idx, col_idx,
                                                                                  row, col]
     else:
         for j in T.serial(warp_cols):
             for local_id in T.vectorized(k_pack * local_size_b):
                 row, col = T.meta_var(reverse_index_map(tx, local_id))
-                l, r = (
+                row_idx, col_idx = (
                     rk * (chunk // micro_size_k) + ki,
                     warp_n * warp_cols + j,
                 )
-                B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r,
+                B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[row_idx, col_idx,
                                                                                  row, col]
297-298: Consider clarifying the "4 dim" comment.

The comment "# 4 dim" is somewhat vague. Consider expanding it to be more descriptive about what the 4 dimensions represent in the preshuffle context.

Apply this diff to improve the comment:

-        # 4 dim
+        # 4D preshuffle layout: [blocks_n, blocks_k, micro_size, pack_size]
         if self.b_preshuffle:

testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py (2)
239-239: Remove duplicate kernel source printing.

The kernel source is printed twice (lines 239 and 261), which seems unnecessary and could clutter the output during test execution.

Apply this diff to remove the duplicate print statement:

     matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed,
                        k_pack, b_preshuffle)
-    print(matmul)
     kernel = tilelang.compile(matmul)
     src_code = kernel.get_kernel_source()
     # src_code is the generated cuda source
     assert src_code is not None

Also applies to: 261-261
286-287: Consider removing or conditionalizing debug print statements.

The print statements for `C` and `ref_c` (lines 286-287) appear to be debug outputs that should be removed or made conditional for production test code.

Apply this diff to remove debug prints:

-    print(C)
-    print(ref_c)
     torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)

Or if you want to keep them for debugging purposes, make them conditional:

-    print(C)
-    print(ref_c)
+    if os.environ.get("DEBUG_TESTS"):
+        print("Computed:", C)
+        print("Reference:", ref_c)
     torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py (1 hunks)
- tilelang/intrinsics/mfma_macro_generator.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py (6)
- tilelang/intrinsics/mfma_layout.py (1)
  - make_mfma_swizzle_layout (130-152)
- tilelang/intrinsics/mfma_macro_generator.py (5)
  - MatrixCoreIntrinEmitter (14-427)
  - ldmatrix_a (233-272)
  - ldmatrix_b (274-343)
  - mfma (345-377)
  - stmatrix (379-427)
- tilelang/transform/simplify.py (1)
  - simplify_prim_func (53-59)
- tilelang/testing/__init__.py (1)
  - set_random_seed (20-25)
- tilelang/language/__init__.py (2)
  - annotate_layout (103-141)
  - use_swizzle (94-100)
- tilelang/jit/__init__.py (1)
  - compile (33-86)

tilelang/intrinsics/mfma_macro_generator.py (3)
- tilelang/tileop/gemm/gemm_base.py (2)
  - k_pack (110-111)
  - chunk (62-63)
- tilelang/language/ast/ir.py (1)
  - meta_var (1731-1750)
- src/tl_templates/hip/gemm.h (1)
  - reverse_index_map (102-106)
🪛 Ruff (0.12.2)
tilelang/intrinsics/mfma_macro_generator.py
303-303: Ambiguous variable name: l
(E741)
314-314: Ambiguous variable name: l
(E741)
326-326: Ambiguous variable name: l
(E741)
336-336: Ambiguous variable name: l
(E741)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: bot-task
- GitHub Check: format-check
🔇 Additional comments (3)
tilelang/intrinsics/mfma_macro_generator.py (2)
56-56: LGTM! Clean addition of the preshuffle parameter.

The new `b_preshuffle` parameter is properly initialized with a sensible default value of `False`, ensuring backward compatibility. The initialization follows the established pattern for optional parameters in the class.

Also applies to: 76-77, 146-149
303-306: Verify preshuffle 4D indexing ([l, r, row, col])

Preshuffle sets l, r = (warp_n * warp_cols + j, rk * (chunk // micro_size_k) + ki) and then indexes B_shared_buf[l, r, row, col]; confirm B_shared_shape / B_shape use the same axis order for both b_transposed cases by inspecting tilelang/intrinsics/mfma_macro_generator.py (≈303–320) and testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py (≈64–74). Compare with mma_macro_generator.py usages (≈916–967) and reorder indices (or B_shared_shape) if the axis order does not match.
testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py (1)
66-68: Resolved — shuffle_weight matches kernel B_shape

shuffle_weight permutes to (N//BN, K//BK, BN, BK) when is_transpose (and (K//BK, N//BN, BK, BN) otherwise), which matches the kernel's B_shape / B_shared_shape ordering used for b_preshuffle. Default layout=(16, 32) aligns with micro_size_y=16 and micro_size_k=32 used in the int8 tests; if you enable preshuffle for other dtypes, pass an appropriate layout.
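
For readers double-checking that ordering, below is a minimal standalone sketch (plain PyTorch, no TileLang required) of the same view/permute that the quoted shuffle_weight lines perform, using the int8 defaults IN=16, IK=32 and k_pack=1; the tensor sizes and element values are purely illustrative.

```python
# Standalone illustration of the preshuffle reshape; sizes are illustrative.
import torch

IN, IK, k_pack = 16, 32, 1          # int8 defaults from the test
BN, BK = IN, IK * k_pack
N, K = 128, 256                     # must be multiples of BN and BK

# b_transposed case: B is stored as (N, K)
B_t = torch.arange(N * K, dtype=torch.int32).reshape(N, K)
out_t = B_t.view(N // BN, BN, K // BK, BK).permute(0, 2, 1, 3).contiguous()
assert out_t.shape == (N // BN, K // BK, BN, BK)                 # (blocks_n, blocks_k, BN, BK)
assert torch.equal(out_t[1, 2], B_t[BN:2 * BN, 2 * BK:3 * BK])   # each [i, j] slice is one micro-tile

# non-transposed case: B is stored as (K, N)
B_n = torch.arange(K * N, dtype=torch.int32).reshape(K, N)
out_n = B_n.view(K // BK, BK, N // BN, BN).permute(0, 2, 1, 3).contiguous()
assert out_n.shape == (K // BK, N // BN, BK, BN)                 # (blocks_k, blocks_n, BK, BN)

# These 4D shapes are the kind of layout the emitter's preshuffle path
# indexes as B_shared_buf[l, r, row, col].
```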
N, K = (x.shape[-2], x.shape[-1]) if is_transpose else (x.shape[-1], x.shape[-2])
assert N % BN == 0
assert K % BK == 0

x = x.view(N // BN, BN, K // BK, BK) if is_transpose else x.view(K // BK, BK, N // BN, BN)
x = x.permute(0, 2, 1, 3)
return x.contiguous()
🛠️ Refactor suggestion
Add validation for tensor dimensions in shuffle_weight.
The function assumes 2D input tensors but doesn't validate this assumption. Consider adding a check to ensure the input tensor has exactly 2 dimensions.
Apply this diff to add dimension validation:
 def shuffle_weight(
     x: torch.Tensor,
     layout=(16, 32),
     k_pack=1,
     is_transpose=False,
 ) -> torch.Tensor:
+    if x.ndim != 2:
+        raise ValueError(f"Expected 2D tensor, got {x.ndim}D tensor")
     IN, IK = layout
     BK = IK * k_pack
     BN = IN

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
def shuffle_weight(
    x: torch.Tensor,
    layout=(16, 32),
    k_pack=1,
    is_transpose=False,
) -> torch.Tensor:
    if x.ndim != 2:
        raise ValueError(f"Expected 2D tensor, got {x.ndim}D tensor")
    IN, IK = layout
    BK = IK * k_pack
    BN = IN
    N, K = (x.shape[-2], x.shape[-1]) if is_transpose else (x.shape[-1], x.shape[-2])
    assert N % BN == 0
    assert K % BK == 0
    x = x.view(N // BN, BN, K // BK, BK) if is_transpose else x.view(K // BK, BK, N // BN, BN)
    x = x.permute(0, 2, 1, 3)
    return x.contiguous()
🤖 Prompt for AI Agents
In testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py around lines 217 to
223, the shuffle_weight code assumes a 2D input but does not validate that; add
an explicit check at the start of this block to ensure x.dim() == 2 (or x.ndim
== 2) and raise a clear ValueError if not (include actual ndim in the message),
then proceed with the existing asserts and reshaping — this prevents confusing
errors later when non-2D tensors are passed.
                         B.T.to(torch.float32)).to(getattr(torch, out_dtype))
elif a_transposed and not b_transposed:
    # Get Reference Result
    ref_c = torch.matmul(A.Tto(torch.float32),
Fix syntax error: missing dot operator.
There's a syntax error on line 276 where the dot operator is missing between A.T and to.
Apply this diff to fix the syntax error:
-    ref_c = torch.matmul(A.Tto(torch.float32),
+    ref_c = torch.matmul(A.T.to(torch.float32),

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
-    ref_c = torch.matmul(A.Tto(torch.float32),
+    ref_c = torch.matmul(A.T.to(torch.float32),
🤖 Prompt for AI Agents
In testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py around line 276,
there's a syntax error where the dot operator is missing between A.T and to;
replace the incorrect call `A.Tto(torch.float32,` with the correct chained
attribute call `A.T.to(torch.float32` so the transpose is followed by
`.to(...)`.
Co-authored-by: Jiaxing Ding <jiaxing.ding@bytedance.com>
Summary by CodeRabbit
New Features
Tests