-
Notifications
You must be signed in to change notification settings - Fork 333
[AMD] refactor MatrixCoreIntrinEmitter #860
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
WalkthroughAdds two new AMD MFMA intrinsic tests, introduces a preshuffle-aware matmul path with a new b_g2l_load option in tests, and extends MFMA intrinsics with preshuffle-enabled emitters and signatures, including a new MatrixCorePreshuffleIntrinEmitter and modified ldmatrix load paths and loop structure. Changes
Sequence Diagram(s)sequenceDiagram
autonumber
participant T as Test
participant TL as tl_matmul
participant EM as MatrixCorePreshuffleIntrinEmitter
participant A as A Loader (S2L)
participant B as B Loader (S2L/G2L)
participant P as Pipelined(num_ko,num_ki)
participant C as Store (stmatrix)
T->>TL: call tl_matmul(..., b_g2l_load)
TL->>EM: construct emitter (a_preshuffle,b_preshuffle,k_pack,...)
TL->>P: iterate over num_ko (outer) and num_ki (inner)
loop per ki
P->>A: ldmatrix_a (shared→local)
alt b_g2l_load == true
P->>B: ldmatrix_b (global→local)
else
P->>B: ldmatrix_b (shared→local)
end
P->>EM: mma (accumulate)
end
TL->>C: stmatrix with pid_m/pid_n
C-->>T: completion
Estimated code review effort🎯 4 (Complex) | ⏱️ ~60 minutes Possibly related PRs
Suggested reviewers
Poem
Pre-merge checks and finishing touches❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
✨ Finishing touches
🧪 Generate unit tests
Tip 👮 Agentic pre-merge checks are now available in preview!Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.
Please see the documentation for more information. Example: reviews:
pre_merge_checks:
custom_checks:
- name: "Undocumented Breaking Changes"
mode: "warning"
instructions: |
Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal).Please share your feedback with us on this Discord post. Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out. Comment |
|
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀 |
Summary of ChangesHello @Paran0idy, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request introduces a specialized intrinsic emitter, Highlights
Using Gemini Code AssistThe full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips. Invoking Gemini You can request assistance from Gemini at any point by creating a comment using either
Customization To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a Limitations & Feedback Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here. You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension. Footnotes
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request refactors the MatrixCoreIntrinEmitter to support weight preshuffling by introducing a new MatrixCorePreshuffleIntrinEmitter class. The changes are well-structured, moving specialized logic into a subclass. I've provided a few suggestions to improve code quality, such as removing duplicated code by using inheritance properly, removing leftover debug statements, and cleaning up test cases. I also noted the removal of some explicit type checks which could be reconsidered for better error reporting.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Actionable comments posted: 4
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py (1)
214-214: Fix the missing dot in the method callThere's a typo in the
A.Ttomethod call - it should beA.T.to.- ref_c = torch.matmul(A.Tto(torch.float32), + ref_c = torch.matmul(A.T.to(torch.float32),testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py (1)
248-248: Fix the missing dot in method callSame typo as in the other test file -
A.Ttoshould beA.T.to.- ref_c = torch.matmul(A.Tto(torch.float32), + ref_c = torch.matmul(A.T.to(torch.float32),
🧹 Nitpick comments (5)
tilelang/intrinsics/mfma_macro_generator.py (5)
56-56: Consider better naming for the preshuffle parameterThe parameter
b_preshufflein the base class constructor might cause confusion since there's alsoa_preshufflein the derived class. Consider either addinga_preshuffleto the base class or documenting why only B preshuffle is needed here.
146-148: Initialize preshuffle attributes before useThe
_initialize_b_preshufflemethod referencesself.b_preshufflebut doesn't handle the case where the attribute might not exist. Consider initializing it first.def _initialize_b_preshuffle(self, b_preshuffle: Optional[bool] = False): if b_preshuffle is not None: + self.b_preshuffle = False # Initialize with default self.b_preshuffle = b_preshuffleOr better yet:
def _initialize_b_preshuffle(self, b_preshuffle: Optional[bool] = False): - if b_preshuffle is not None: - self.b_preshuffle = b_preshuffle + self.b_preshuffle = b_preshuffle if b_preshuffle is not None else False
300-300: Consider renaming ambiguous variable namesVariables
landrcould be confused with numbers (1) or be unclear in meaning. Consider more descriptive names likeleft_idx, right_idxorrow_base, col_base.- l, r = ( + row_base, col_base = ( warp_n * warp_col_tiles + j * micro_size_y, rk * chunk + ki * (k_pack * micro_size_k), ) - B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row, - r + col] + B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[row_base + row, + col_base + col]Also applies to: 311-311
459-609: Consider extracting common patterns in ldmatrix methodsThe global and shared variants of ldmatrix_a and ldmatrix_b have very similar patterns. Consider extracting common logic to reduce code duplication.
Consider creating a helper method for the common pattern:
def _ldmatrix_helper(self, local_buf, buf, ki, thread_binding, rk, is_b=False, is_global=False, is_transposed=False): # Common extraction and indexing logic ...This would reduce duplication and make the code more maintainable.
487-487: Variable naming consistencyThe static analysis correctly flags the use of
las an ambiguous variable name throughout the preshuffle methods. For consistency with the parent class and better readability, consider using more descriptive names.Also applies to: 496-496, 515-515, 526-526, 562-562, 571-571, 590-590, 600-600
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py(1 hunks)testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py(10 hunks)tilelang/intrinsics/mfma_macro_generator.py(2 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py (1)
testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py (1)
assert_tl_matmul_correctness(197-261)
testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py (1)
tilelang/intrinsics/mfma_macro_generator.py (7)
MatrixCorePreshuffleIntrinEmitter(405-609)ldmatrix_a(233-272)ldmatrix_a(459-532)ldmatrix_b(274-318)ldmatrix_b(534-609)mfma(320-352)stmatrix(354-402)
tilelang/intrinsics/mfma_macro_generator.py (2)
tilelang/intrinsics/mma_macro_generator.py (6)
ldmatrix_a(199-261)ldmatrix_a(753-857)get_thread_binding(151-157)extract_thread_binding(167-197)ldmatrix_b(263-333)ldmatrix_b(859-971)tilelang/language/kernel.py (4)
KernelLaunchFrame(72-206)Current(115-121)get_thread_binding(151-156)get_thread_binding(261-265)
🪛 Ruff (0.13.1)
tilelang/intrinsics/mfma_macro_generator.py
300-300: Ambiguous variable name: l
(E741)
311-311: Ambiguous variable name: l
(E741)
487-487: Ambiguous variable name: l
(E741)
496-496: Ambiguous variable name: l
(E741)
515-515: Ambiguous variable name: l
(E741)
526-526: Ambiguous variable name: l
(E741)
562-562: Ambiguous variable name: l
(E741)
571-571: Ambiguous variable name: l
(E741)
590-590: Ambiguous variable name: l
(E741)
600-600: Ambiguous variable name: l
(E741)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: format-check
🔇 Additional comments (8)
testing/python/amd/test_tilelang_gemm_mfma_intrinsic.py (1)
237-240: LGTM!The new test cases add valuable coverage for non-transposed B matrix configurations.
tilelang/intrinsics/mfma_macro_generator.py (2)
296-317: Refactoring simplified the ldmatrix_b logic nicelyThe removal of complex preshuffle-specific logic and the cleaner implementation with consistent l, r mapping improves code maintainability.
405-452: Well-structured preshuffle emitter classThe new
MatrixCorePreshuffleIntrinEmitterclass appropriately extends the base emitter with preshuffle-specific functionality. The constructor properly initializes all necessary attributes.testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py (5)
6-6: Good adoption of the new preshuffle emitterThe import change correctly switches to using
MatrixCorePreshuffleIntrinEmitterfor preshuffle-aware testing.
24-24: Well-integrated b_g2l_load parameterThe addition of
b_g2l_loadparameter enables testing of both global-to-local and shared-to-local B matrix loading paths, providing comprehensive test coverage.Also applies to: 207-207, 209-209
41-44: Verify the preshuffle configuration changesThe warp tile dimensions change from 32x32 to 64x16 for preshuffle mode, and chunk size increases significantly. Please verify these are the intended optimizations for preshuffle mode.
Are these specific tile and chunk sizes validated for optimal performance with preshuffle on AMD hardware?
153-162: Clean separation of B loading pathsThe conditional logic for G2L vs S2L B matrix loading is well-structured and provides clear path selection based on the
b_g2l_loadflag.
267-284: Good test coverage for preshuffle scenariosThe test cases comprehensively cover different configurations including transposed/non-transposed B matrices and different k_pack values with preshuffle enabled.
| def _initialize_preshuffle(self, a_preshuffle: bool, b_preshuffle: bool): | ||
| if a_preshuffle is not None: | ||
| self.a_preshuffle = a_preshuffle | ||
| if b_preshuffle is not None: | ||
| self.b_preshuffle = b_preshuffle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Improve preshuffle initialization logic
The condition if a_preshuffle is not None will be True even when a_preshuffle=False is passed. This could lead to unexpected behavior.
def _initialize_preshuffle(self, a_preshuffle: bool, b_preshuffle: bool):
- if a_preshuffle is not None:
- self.a_preshuffle = a_preshuffle
- if b_preshuffle is not None:
- self.b_preshuffle = b_preshuffle
+ self.a_preshuffle = a_preshuffle if a_preshuffle is not None else False
+ self.b_preshuffle = b_preshuffle if b_preshuffle is not None else False📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| def _initialize_preshuffle(self, a_preshuffle: bool, b_preshuffle: bool): | |
| if a_preshuffle is not None: | |
| self.a_preshuffle = a_preshuffle | |
| if b_preshuffle is not None: | |
| self.b_preshuffle = b_preshuffle | |
| def _initialize_preshuffle(self, a_preshuffle: bool, b_preshuffle: bool): | |
| self.a_preshuffle = a_preshuffle if a_preshuffle is not None else False | |
| self.b_preshuffle = b_preshuffle if b_preshuffle is not None else False |
🤖 Prompt for AI Agents
In tilelang/intrinsics/mfma_macro_generator.py around lines 453-457, change the
preshuffle initialization to explicitly treat None as "no-op" and accept False
as a valid value: check "if a_preshuffle is None: pass else: self.a_preshuffle =
a_preshuffle" (and the same for b_preshuffle) so that passing False correctly
sets the attribute rather than being misinterpreted as omitted.
| if self.a_preshuffle is False: | ||
| return super().ldmatrix_a(A_local_buf, A_buf, ki, rk) | ||
|
|
||
| def _warp_ldmatrix_a_global( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Missing @T.macro decorator for _warp_ldmatrix_a_global
The _warp_ldmatrix_a_global function is missing the @T.macro decorator, unlike its shared counterpart and other similar functions in the codebase.
+ @T.macro
def _warp_ldmatrix_a_global(
A_local_buf,
A_buf,📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| def _warp_ldmatrix_a_global( | |
| @T.macro | |
| def _warp_ldmatrix_a_global( |
🤖 Prompt for AI Agents
In tilelang/intrinsics/mfma_macro_generator.py around line 475, the helper
function _warp_ldmatrix_a_global is missing the @T.macro decorator; add the
@T.macro decorator immediately above its def line to match its shared
counterpart and other similar functions so it is registered as a TIR macro,
ensuring consistent behavior and enabling the macro to be expanded where
expected.
| A_local_buf[i * k_pack * local_size_a + local_id] = A_shared_buf[l, r, row, | ||
| col] | ||
| else: | ||
| print(self.a_preshuffle) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove debug print statement
There's a print statement that appears to be for debugging purposes and should be removed before merging.
else:
- print(self.a_preshuffle)
for i in T.serial(warp_rows):📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| print(self.a_preshuffle) | |
| else: | |
| for i in T.serial(warp_rows): |
🤖 Prompt for AI Agents
In tilelang/intrinsics/mfma_macro_generator.py around line 522, there is a
leftover debug print statement "print(self.a_preshuffle)"; remove this line so
the function no longer emits debug output, and run tests/lint to ensure no other
debug prints remain.
| is_global = pid_m is not None and pid_n is not None | ||
|
|
||
| if self.b_preshuffle is False: | ||
| return super().ldmatrix_b(B_local_buf, B_buf, ki, rk, pid_m, pid_n) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fix incorrect method signature in super() call
The super().ldmatrix_b() call includes extra parameters (pid_m, pid_n) that don't exist in the parent class method signature.
if self.b_preshuffle is False:
- return super().ldmatrix_b(B_local_buf, B_buf, ki, rk, pid_m, pid_n)
+ return super().ldmatrix_b(B_local_buf, B_buf, ki, rk)📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
| return super().ldmatrix_b(B_local_buf, B_buf, ki, rk, pid_m, pid_n) | |
| return super().ldmatrix_b(B_local_buf, B_buf, ki, rk) |
🤖 Prompt for AI Agents
In tilelang/intrinsics/mfma_macro_generator.py around line 547, the
super().ldmatrix_b(...) call passes two extra arguments (pid_m, pid_n) that the
parent method doesn't accept; remove those extra parameters and call
super().ldmatrix_b with the expected parameters (i.e., B_local_buf, B_buf, ki,
rk) so the call matches the parent signature.
Summary by CodeRabbit
New Features
Tests