
Conversation

@Paran0idy
Contributor

@Paran0idy Paran0idy commented Sep 11, 2025

  • Support preshuffle weight for AMD MFMA.

Summary by CodeRabbit

  • New Features

    • Added optional preshuffled-weight support for AMD MFMA-based GEMM kernels, working with transposed and non-transposed layouts. Improves B-loading and shared-memory handling with a preshuffle-aware path, configurable via kernel options.
  • Tests

    • Introduced a ROCm-only test suite that validates correctness and benchmarks GEMM across varied sizes, dtypes, transposition flags, k-pack values, and preshuffle modes. Includes utilities to shuffle weights for reference comparisons.

@coderabbitai
Contributor

coderabbitai bot commented Sep 11, 2025

Walkthrough

Adds a new ROCm test module that validates an AMD MFMA GEMM path supporting preshuffled-B layouts, and integrates a b_preshuffle code path into MatrixCoreIntrinEmitter, modifying its initialization and B-tile load logic.

Changes

Cohort / File(s) | Summary
MFMA intrinsics emitter: tilelang/intrinsics/mfma_macro_generator.py | Extends MatrixCoreIntrinEmitter with a b_preshuffle flag and initializer. Updates ldmatrix_b to add a 4D B load path when preshuffle is enabled; the existing 2D load path is preserved otherwise (see the usage sketch below).
AMD MFMA GEMM tests: testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py | Introduces a tl_matmul kernel generator using MatrixCoreIntrinEmitter with optional B preshuffle and k_pack. Adds shuffle_weight, a correctness/benchmark harness, and a ROCm-gated parameterized test.
Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant Test as PyTest (ROCm)
  participant TL as tl_matmul Builder
  participant Emitter as MatrixCoreIntrinEmitter
  participant Kernel as GEMM Kernel
  participant Torch as PyTorch Ref

  Test->>TL: build tl_matmul(M,N,K, dtypes, flags: b_preshuffle, k_pack, transposes)
  TL->>Emitter: __init__(..., b_preshuffle)
  Emitter-->>TL: configured emitter
  Note over Emitter,Kernel: Emitter selects B load path<br/>(4D preshuffle vs 2D standard)

  Test->>Test: prepare tensors (optional shuffle_weight on B)
  Test->>Kernel: run kernel(A, B[, preshuffled], C)
  Kernel->>Kernel: load A/B tiles (B uses preshuffle path if enabled)
  Kernel->>Kernel: MFMA accumulate and store C
  Test->>Torch: compute reference gemm (with transposes)
  Test->>Test: compare outputs, benchmark

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • LeiWang1999

Pre-merge checks (2 passed, 1 warning)

❌ Failed checks (1 warning)
Check name | Status | Explanation | Resolution
Docstring Coverage | ⚠️ Warning | Docstring coverage is 0.00%, which is below the required threshold of 80.00%. | Run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name | Status | Explanation
Description Check | ✅ Passed | Check skipped: CodeRabbit's high-level summary is enabled.
Title Check | ✅ Passed | The title "[AMD] support preshuffle weight mfma" accurately and concisely captures the primary change in this PR: adding preshuffled-weight support for AMD MFMA paths and tests. It is a short, single-line summary directly related to the modified MFMA generator and the new AMD test, contains no extraneous noise, and is clear to a teammate scanning the history. The phrasing is specific enough to identify the feature change without listing file-level details.

Poem

A hop, a skip, a preshuffle twist,
My matrix snacks can’t be dismissed.
I stash B tiles in tidy rows,
MFMA hums, the throughput grows.
With ROCm breeze and kernels bright—
I thump my paws: the math is right. 🐇✨

Tip

👮 Agentic pre-merge checks are now available in preview!

Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.

  • Built-in checks – Quickly apply ready-made checks to enforce title conventions, require pull request descriptions that follow templates, validate linked issues for compliance, and more.
  • Custom agentic checks – Define your own rules using CodeRabbit’s advanced agentic capabilities to enforce organization-specific policies and workflows. For example, you can instruct CodeRabbit’s agent to verify that API documentation is updated whenever API schema files are modified in a PR. Note: Up to 5 custom checks are currently allowed during the preview period. Pricing for this feature will be announced in a few weeks.

Please see the documentation for more information.

Example:

reviews:
  pre_merge_checks:
    custom_checks:
      - name: "Undocumented Breaking Changes"
        mode: "warning"
        instructions: |
          Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal).

Please share your feedback with us on this Discord post.

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @Paran0idy, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces support for preshuffled weights within AMD's Matrix Fused Multiply-Add (MFMA) operations. This enhancement allows for optimized handling of specific weight layouts, improving performance for relevant matrix multiplication workloads.

Highlights

  • Preshuffle Weight Support: Implemented the capability to handle preshuffled weights for AMD MFMA operations.
  • MFMA Emitter Update: The MatrixCoreIntrinEmitter now includes a b_preshuffle parameter to control the new behavior.
  • Optimized B Tensor Loading: Modified the _warp_ldmatrix_b function to correctly load 4-dimensional B tensors when preshuffling is enabled.
  • New Test Coverage: A new test file test_tilelang_gemm_mfma_preshuffle.py has been added to thoroughly validate the correctness and functionality of the preshuffle weight feature.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e., a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature | Command | Description
Code Review | /gemini review | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state.
Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments.
Help | /gemini help | Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code-review style guide (such as PEP-8 for Python) by adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in sharing feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request introduces support for pre-shuffled weights for AMD MFMA, a valuable performance optimization. The changes include a new test file for validation and updates to the MFMA macro generator. My review has identified a critical bug within the new test code, specifically a typo that affects correctness checking for transposed matrices. I have also provided suggestions to improve code maintainability by refactoring duplicated logic in the macro generator and by decoupling a helper function from hardcoded kernel parameters. Addressing these points will enhance the quality and robustness of the implementation.

Comment on lines +270 to +284
    if a_transposed and b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.T.to(torch.float32),
                             B.T.to(torch.float32)).to(getattr(torch, out_dtype))
    elif a_transposed and not b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.Tto(torch.float32),
                             B.to(torch.float32)).to(getattr(torch, out_dtype))
    elif not a_transposed and b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.to(torch.float32),
                             B.T.to(torch.float32)).to(getattr(torch, out_dtype))
    else:
        # Get Reference Result
        ref_c = torch.matmul(A.to(torch.float32), B.to(torch.float32)).to(getattr(torch, out_dtype))
Contributor


critical

There's a typo on line 276: A.Tto(torch.float32) should be A.T.to(torch.float32).

Although this code path is not exercised by the current tests (no test case sets a_transposed=True), this is a critical bug that should be fixed.

Additionally, this entire conditional block can be simplified to improve readability and maintainability, which would also help prevent such typos.

    A_ref = A.T if a_transposed else A
    B_ref = B.T if b_transposed else B
    ref_c = torch.matmul(A_ref.to(torch.float32), B_ref.to(torch.float32)).to(getattr(torch, out_dtype))

Comment on lines +207 to +213
def shuffle_weight(
        x: torch.Tensor,
        layout=(16, 32),
        k_pack=1,
        is_transpose=False,
) -> torch.Tensor:
    IN, IK = layout
Contributor


medium

The shuffle_weight function uses a hardcoded default layout=(16, 32). These values are tightly coupled with micro_size_y and micro_size_k from the tl_matmul kernel definition. This makes the function less flexible and could lead to errors if the kernel parameters change.

To improve modularity and reduce this coupling, consider passing micro_size_n and micro_size_k directly to the function instead of the layout tuple.

You would then update the call site in assert_tl_matmul_correctness like this:

if b_preshuffle:
    micro_size_k = 32 if in_dtype == "int8" else 16
    micro_size_y = 16
    B_preshuffle = shuffle_weight(
        B_preshuffle, 
        micro_size_n=micro_size_y, 
        micro_size_k=micro_size_k, 
        k_pack=k_pack, 
        is_transpose=b_transposed
    )
    kernel(A, B_preshuffle, C)
Suggested change
-def shuffle_weight(
-        x: torch.Tensor,
-        layout=(16, 32),
-        k_pack=1,
-        is_transpose=False,
-) -> torch.Tensor:
-    IN, IK = layout
+def shuffle_weight(
+        x: torch.Tensor,
+        micro_size_n: int,
+        micro_size_k: int,
+        k_pack: int = 1,
+        is_transpose: bool = False,
+) -> torch.Tensor:
+    IN, IK = micro_size_n, micro_size_k

Comment on lines +298 to +341
            if self.b_preshuffle:
                if is_transposed:
                    for j in T.serial(warp_cols):
                        for local_id in T.vectorized(k_pack * local_size_b):
                            row, col = T.meta_var(reverse_index_map(tx, local_id))
                            l, r = (
                                warp_n * warp_cols + j,
                                rk * (chunk // micro_size_k) + ki,
                            )
                            B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row, col]
                else:
                    for j in T.serial(warp_cols):
                        for local_id in T.vectorized(k_pack * local_size_b):
                            row, col = T.meta_var(reverse_index_map(tx, local_id))
                            l, r = (
                                rk * (chunk // micro_size_k) + ki,
                                warp_n * warp_cols + j,
                            )
                            B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row, col]
            else:
                if is_transposed:
                    for j in T.serial(warp_cols):
                        for local_id in T.vectorized(k_pack * local_size_b):
                            row, col = T.meta_var(reverse_index_map(tx, local_id))
                            l, r = (
                                warp_n * warp_col_tiles + j * micro_size_y,
                                rk * chunk + ki * (k_pack * micro_size_k),
                            )
                            B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row, r + col]
                else:
                    for j in T.serial(warp_cols):
                        for local_id in T.vectorized(k_pack * local_size_b):
                            row, col = T.meta_var(reverse_index_map(tx, local_id))
                            l, r = (
                                rk * chunk + ki * (k_pack * micro_size_k),
                                warp_n * warp_col_tiles + j * micro_size_y,
                            )
                            B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row, r + col]
Contributor


medium

The logic in _warp_ldmatrix_b is split into four branches based on self.b_preshuffle and is_transposed, leading to significant code duplication. The loop structure, reverse_index_map call, and buffer assignment are repeated in each branch.

This duplication makes the code harder to read and maintain. Please consider refactoring to consolidate the common logic and isolate the parts that truly differ. For example, you could move the conditional logic for index calculation inside the loops.

Here's a possible refactoring approach:

@T.macro
def _warp_ldmatrix_b(...):
    tx, warp_n, _ = self.extract_thread_binding(thread_binding)

    for j in T.serial(warp_cols):
        for local_id in T.vectorized(k_pack * local_size_b):
            row, col = T.meta_var(reverse_index_map(tx, local_id))
            
            if self.b_preshuffle:
                if is_transposed:
                    l, r = warp_n * warp_cols + j, rk * (chunk // micro_size_k) + ki
                else:
                    l, r = rk * (chunk // micro_size_k) + ki, warp_n * warp_cols + j
                B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r, row, col]
            else:
                if is_transposed:
                    l, r = warp_n * warp_col_tiles + j * micro_size_y, rk * chunk + ki * (k_pack * micro_size_k)
                else:
                    l, r = rk * chunk + ki * (k_pack * micro_size_k), warp_n * warp_col_tiles + j * micro_size_y
                B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l + row, r + col]

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 2

🧹 Nitpick comments (4)
tilelang/intrinsics/mfma_macro_generator.py (2)

297-320: Consider renaming ambiguous variables for better readability.

While l and r are functional variable names, they can be visually ambiguous (especially l which can look like 1 in some fonts). Consider using more descriptive names like row_idx and col_idx or dim1 and dim2 to improve code readability.

Apply this diff to improve variable naming:

            # 4 dim
            if self.b_preshuffle:
                if is_transposed:
                    for j in T.serial(warp_cols):
                        for local_id in T.vectorized(k_pack * local_size_b):
                            row, col = T.meta_var(reverse_index_map(tx, local_id))
-                            l, r = (
+                            row_idx, col_idx = (
                                warp_n * warp_cols + j,
                                rk * (chunk // micro_size_k) + ki,
                            )
-                            B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r,
+                            B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[row_idx, col_idx,
                                                                                             row,
                                                                                             col]
                else:
                    for j in T.serial(warp_cols):
                        for local_id in T.vectorized(k_pack * local_size_b):
                            row, col = T.meta_var(reverse_index_map(tx, local_id))
-                            l, r = (
+                            row_idx, col_idx = (
                                rk * (chunk // micro_size_k) + ki,
                                warp_n * warp_cols + j,
                            )
-                            B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[l, r,
+                            B_local_buf[j * k_pack * local_size_b + local_id] = B_shared_buf[row_idx, col_idx,
                                                                                             row,
                                                                                             col]

297-298: Consider clarifying the "4 dim" comment.

The comment "# 4 dim" is somewhat vague. Consider expanding it to be more descriptive about what the 4 dimensions represent in the preshuffle context.

Apply this diff to improve the comment:

-            # 4 dim
+            # 4D preshuffle layout: [blocks_n, blocks_k, micro_size, pack_size]
            if self.b_preshuffle:
testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py (2)

239-239: Remove duplicate kernel source printing.

The kernel source is printed twice (lines 239 and 261), which seems unnecessary and could clutter the output during test execution.

Apply this diff to remove the duplicate print statement:

     matmul = tl_matmul(M, N, K, in_dtype, out_dtype, accum_dtype, a_transposed, b_transposed,
                        k_pack, b_preshuffle)
-    print(matmul)
     kernel = tilelang.compile(matmul)
     src_code = kernel.get_kernel_source()
     # src_code is the generated cuda source
     assert src_code is not None

Also applies to: 261-261


286-287: Consider removing or conditionalizing debug print statements.

The print statements for C and ref_c (lines 286-287) appear to be debug outputs that should be removed or made conditional for production test code.

Apply this diff to remove debug prints:

-    print(C)
-    print(ref_c)
     torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)

Or if you want to keep them for debugging purposes, make them conditional:

-    print(C)
-    print(ref_c)
+    if os.environ.get("DEBUG_TESTS"):
+        print("Computed:", C)
+        print("Reference:", ref_c)
     torch.testing.assert_close(C, ref_c, rtol=1e-2, atol=1e-2)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 409ab83 and 1f46d89.

📒 Files selected for processing (2)
  • testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py (1 hunks)
  • tilelang/intrinsics/mfma_macro_generator.py (4 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py (6)
tilelang/intrinsics/mfma_layout.py (1)
  • make_mfma_swizzle_layout (130-152)
tilelang/intrinsics/mfma_macro_generator.py (5)
  • MatrixCoreIntrinEmitter (14-427)
  • ldmatrix_a (233-272)
  • ldmatrix_b (274-343)
  • mfma (345-377)
  • stmatrix (379-427)
tilelang/transform/simplify.py (1)
  • simplify_prim_func (53-59)
tilelang/testing/__init__.py (1)
  • set_random_seed (20-25)
tilelang/language/__init__.py (2)
  • annotate_layout (103-141)
  • use_swizzle (94-100)
tilelang/jit/__init__.py (1)
  • compile (33-86)
tilelang/intrinsics/mfma_macro_generator.py (3)
tilelang/tileop/gemm/gemm_base.py (2)
  • k_pack (110-111)
  • chunk (62-63)
tilelang/language/ast/ir.py (1)
  • meta_var (1731-1750)
src/tl_templates/hip/gemm.h (1)
  • reverse_index_map (102-106)
🪛 Ruff (0.12.2)
tilelang/intrinsics/mfma_macro_generator.py

303-303: Ambiguous variable name: l

(E741)


314-314: Ambiguous variable name: l

(E741)


326-326: Ambiguous variable name: l

(E741)


336-336: Ambiguous variable name: l

(E741)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: bot-task
  • GitHub Check: format-check
🔇 Additional comments (3)
tilelang/intrinsics/mfma_macro_generator.py (2)

56-56: LGTM! Clean addition of the preshuffle parameter.

The new b_preshuffle parameter is properly initialized with a sensible default value of False, ensuring backward compatibility. The initialization follows the established pattern for optional parameters in the class.

Also applies to: 76-77, 146-149
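
For readers skimming the diff, the pattern being described is roughly the following. This is a minimal sketch only; the helper name and the surrounding parameters are assumptions, not the exact code.

# Minimal sketch of the initialization pattern described above.
# The helper name and other parameters are assumptions, not the exact code.
class MatrixCoreIntrinEmitterSketch:

    def __init__(self, k_pack: int = 1, b_preshuffle: bool = False):
        self.k_pack = k_pack
        self._initialize_b_preshuffle(b_preshuffle)

    def _initialize_b_preshuffle(self, b_preshuffle: bool = False) -> None:
        # Default False keeps existing callers on the 2D load path.
        self.b_preshuffle = b_preshuffle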


303-306: Verify preshuffle 4D indexing ([l, r, row, col])

Preshuffle sets l,r = (warp_n * warp_cols + j, rk * (chunk // micro_size_k) + ki) and then indexes B_shared_buf[l, r, row, col]; confirm B_shared_shape / B_shape use the same axis order for both b_transposed cases by inspecting tilelang/intrinsics/mfma_macro_generator.py (≈303–320) and testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py (≈64–74). Compare with mma_macro_generator.py usages (≈916–967) and reorder indices (or B_shared_shape) if the axis order does not match.
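
One quick way to convince yourself of the axis order is a small host-side check. This is a sketch only; the 16x32 tile sizes match the int8 test configuration and are assumptions for other dtypes.

# Sanity-check sketch for the [l, r, row, col] axis order on a transposed (N, K) weight.
import torch

N, K, BN, BK = 64, 128, 16, 32
B = torch.arange(N * K, dtype=torch.int32).reshape(N, K)
# Same reshape/permute as the preshuffle layout: (N//BN, K//BK, BN, BK)
B4d = B.view(N // BN, BN, K // BK, BK).permute(0, 2, 1, 3).contiguous()

l, r, row, col = 2, 1, 5, 7
assert B4d[l, r, row, col] == B[l * BN + row, r * BK + col]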

testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py (1)

66-68: Resolved — shuffle_weight matches kernel B_shape

shuffle_weight permutes to (N//BN, K//BK, BN, BK) when is_transpose (and (K//BK, N//BN, BK, BN) otherwise), which matches the kernel's B_shape / B_shared_shape ordering used for b_preshuffle. Default layout=(16,32) aligns with micro_size_y=16 and micro_size_k=32 used in the int8 tests; if you enable preshuffle for other dtypes, pass an appropriate layout.
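
The shapes stated above can be reproduced directly with the same view/permute steps the test uses. A sketch, assuming layout=(16, 32) and k_pack=1:

import torch

IN, IK = 16, 32  # layout=(16, 32), k_pack=1 -> BN=16, BK=32

# Transposed B of shape (N, K): tiles become (N//BN, K//BK, BN, BK)
Bt = torch.randn(64, 128)
Bt4d = Bt.view(64 // IN, IN, 128 // IK, IK).permute(0, 2, 1, 3).contiguous()
assert Bt4d.shape == (4, 4, 16, 32)

# Non-transposed B of shape (K, N): tiles become (K//BK, N//BN, BK, BN)
Bn = torch.randn(128, 64)
Bn4d = Bn.view(128 // IK, IK, 64 // IN, IN).permute(0, 2, 1, 3).contiguous()
assert Bn4d.shape == (4, 4, 32, 16)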

Comment on lines +217 to +223
    N, K = (x.shape[-2], x.shape[-1]) if is_transpose else (x.shape[-1], x.shape[-2])
    assert N % BN == 0
    assert K % BK == 0

    x = x.view(N // BN, BN, K // BK, BK) if is_transpose else x.view(K // BK, BK, N // BN, BN)
    x = x.permute(0, 2, 1, 3)
    return x.contiguous()
Contributor


🛠️ Refactor suggestion

Add validation for tensor dimensions in shuffle_weight.

The function assumes 2D input tensors but doesn't validate this assumption. Consider adding a check to ensure the input tensor has exactly 2 dimensions.

Apply this diff to add dimension validation:

 def shuffle_weight(
         x: torch.Tensor,
         layout=(16, 32),
         k_pack=1,
         is_transpose=False,
 ) -> torch.Tensor:
+    if x.ndim != 2:
+        raise ValueError(f"Expected 2D tensor, got {x.ndim}D tensor")
     IN, IK = layout
     BK = IK * k_pack
     BN = IN
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-    N, K = (x.shape[-2], x.shape[-1]) if is_transpose else (x.shape[-1], x.shape[-2])
-    assert N % BN == 0
-    assert K % BK == 0
-    x = x.view(N // BN, BN, K // BK, BK) if is_transpose else x.view(K // BK, BK, N // BN, BN)
-    x = x.permute(0, 2, 1, 3)
-    return x.contiguous()
+def shuffle_weight(
+        x: torch.Tensor,
+        layout=(16, 32),
+        k_pack=1,
+        is_transpose=False,
+) -> torch.Tensor:
+    if x.ndim != 2:
+        raise ValueError(f"Expected 2D tensor, got {x.ndim}D tensor")
+    IN, IK = layout
+    BK = IK * k_pack
+    BN = IN
+    N, K = (x.shape[-2], x.shape[-1]) if is_transpose else (x.shape[-1], x.shape[-2])
+    assert N % BN == 0
+    assert K % BK == 0
+    x = x.view(N // BN, BN, K // BK, BK) if is_transpose else x.view(K // BK, BK, N // BN, BN)
+    x = x.permute(0, 2, 1, 3)
+    return x.contiguous()
🤖 Prompt for AI Agents
In testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py around lines 217 to
223, the shuffle_weight code assumes a 2D input but does not validate that; add
an explicit check at the start of this block to ensure x.dim() == 2 (or x.ndim
== 2) and raise a clear ValueError if not (include actual ndim in the message),
then proceed with the existing asserts and reshaping — this prevents confusing
errors later when non-2D tensors are passed.

                             B.T.to(torch.float32)).to(getattr(torch, out_dtype))
    elif a_transposed and not b_transposed:
        # Get Reference Result
        ref_c = torch.matmul(A.Tto(torch.float32),
Contributor


⚠️ Potential issue

Fix error: missing dot operator.

There's an error on line 276 where the dot operator is missing between A.T and to.

Apply this diff to fix the syntax error:

-        ref_c = torch.matmul(A.Tto(torch.float32),
+        ref_c = torch.matmul(A.T.to(torch.float32),
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-        ref_c = torch.matmul(A.Tto(torch.float32),
+        ref_c = torch.matmul(A.T.to(torch.float32),
🤖 Prompt for AI Agents
In testing/python/amd/test_tilelang_gemm_mfma_preshuffle.py around line 276,
there's a syntax error where the dot operator is missing between A.T and to;
replace the incorrect call `A.Tto(torch.float32,` with the correct chained
attribute call `A.T.to(torch.float32` so the transpose is followed by
`.to(...)`.

