
Conversation

@txs19991
Contributor

@txs19991 txs19991 commented Sep 11, 2025

Summary by CodeRabbit

  • New Features

    • FP8 GEMM example for AMD with autotuning, two kernel variants, and end-to-end validation.
    • FP8 support added to HIP MFMA path with widened vector handling.
  • Performance

    • k-pack parameterization and vector-width changes improve GEMM throughput and data movement.
  • Refactor

    • GEMM layout and swizzle APIs extended to accept k_pack; CDNA binding behavior for B adjusted.
  • Documentation

    • Added example illustrating FP8 GEMM usage and correctness checks.

@coderabbitai
Contributor

coderabbitai bot commented Sep 11, 2025

Note

Other AI code review bot(s) detected

CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds an FP8 GEMM example for AMD using TileLang, introduces FP8 MFMA support and vectorization in HIP templates, threads k_pack through CDNA GEMM layouts and fragment builders, updates public layout APIs to accept k_pack, and adjusts GemmNode layout inference to pass kPack and remove B local.fragment binding.

Changes

Cohort / File(s) | Summary

  • FP8 GEMM example (AMD) | examples/gemm_fp8/example_tilelang_gemm_amd.py
    New example: FP8 GEMM with TileLang including autotuned kernels (ss/rs), k_pack-aware tiling, reference implementation, input supply, and an end-to-end test.
  • CDNA GEMM layout parameterization | src/layout/gemm_layouts.cc, src/layout/layout.h
    Thread k_pack through AB/A fragment builders; add AB16x16/16x32 CDNA (and transposed) variants; matrix-core swizzle layout accepts kPack (default=1); public makeGemmFragmentACDNA signature updated to include k_pack.
  • GEMM op layout inference | src/op/gemm.cc, src/op/gemm.h
    GemmNode::InferLayout (CDNA) now passes kPack to makeGemmFragmentACDNA; removed handling that bound B as local.fragment in CDNA; updated public header signature to include kPack.
  • HIP templates: MFMA, FP8, and vectorization | src/tl_templates/hip/gemm.h, src/tl_templates/hip/hip_fp8.h
    Add MfmaTraits specialization for fp8_e4_t (HIP_FP8_ENABLED); generalize micro_size_k and introduce vec_size based on element type; replace hard-coded 4-wide vector logic with vec_size; add fp8_e4_4_t wrapper struct (replacing alias) with member access and conversions.
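
For orientation, the sketch below shows roughly how a CDNA layout-inference call site passes the new k_pack argument; the tile and warp values are illustrative and not taken from the PR.

// Illustrative sketch only: the k_pack-aware CDNA A-fragment builder described above.
// Tile/warp sizes are made-up values; element_size is in bits (8 for FP8).
Fragment frag_a = makeGemmFragmentACDNA(
    /*block_m=*/128, /*block_n=*/128, /*block_k=*/64,
    /*warp_m=*/64, /*warp_n=*/64,
    /*element_size=*/8,
    /*k_pack=*/2,
    /*transposed=*/false);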

Sequence Diagram(s)

sequenceDiagram
  autonumber
  actor User
  participant Example as example_tilelang_gemm_amd.py
  participant TL as TileLang (autotune/jit)
  participant Kernel as gemm_fp8_{ss|rs}
  participant GPU as AMD CDNA MFMA

  User->>Example: test_gemm_fp8(M,N,K)
  Example->>TL: fp8_matmul(..., gemm_type, k_pack)
  TL->>Kernel: select & compile kernel (tiling, k_pack)
  Kernel->>GPU: launch with FP8 inputs (mfma fp8 path)
  GPU-->>Kernel: compute results
  Kernel-->>Example: return C
  Example->>Example: compare with ref_program -> pass/fail
sequenceDiagram
  autonumber
  participant GemmNode as GemmNode::InferLayout (CDNA)
  participant Layout as Fragment/Layout builders

  GemmNode->>Layout: makeGemmFragmentACDNA(..., dtype_bits, kPack, trans_A)
  note right of Layout: select k_pack-aware AB/CDNA fragment variant<br/>matrix-core swizzle uses kPack
  GemmNode-->>GemmNode: B local.fragment binding for CDNA removed

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs

Suggested reviewers

  • LeiWang1999

Poem

I nibble tiles and pack the K,
FP8 hops in, and vectors play.
CDNA hums a rhythmic beat,
MFMA sings with tiny feet.
"passed~" I whisper — kernels complete. 🐇✨

Tip

👮 Agentic pre-merge checks are now available in preview!

Pro plan users can now enable pre-merge checks in their settings to enforce checklists before merging PRs.

  • Built-in checks – Quickly apply ready-made checks to enforce title conventions, require pull request descriptions that follow templates, validate linked issues for compliance, and more.
  • Custom agentic checks – Define your own rules using CodeRabbit’s advanced agentic capabilities to enforce organization-specific policies and workflows. For example, you can instruct CodeRabbit’s agent to verify that API documentation is updated whenever API schema files are modified in a PR. Note: Up to 5 custom checks are currently allowed during the preview period. Pricing for this feature will be announced in a few weeks.

Please see the documentation for more information.

Example:

reviews:
  pre_merge_checks:
    custom_checks:
      - name: "Undocumented Breaking Changes"
        mode: "warning"
        instructions: |
          Pass/fail criteria: All breaking changes to public APIs, CLI flags, environment variables, configuration keys, database schemas, or HTTP/GraphQL endpoints must be documented in the "Breaking Change" section of the PR description and in CHANGELOG.md. Exclude purely internal or private changes (e.g., code not exported from package entry points or explicitly marked as internal).

Please share your feedback with us on this Discord post.

Pre-merge checks (3 passed)

✅ Passed checks (3 passed)
Check name | Status | Explanation
Description Check | ✅ Passed | Check skipped: CodeRabbit’s high-level summary is enabled.
Title Check | ✅ Passed | The title "[AMD] support fp8 T.gemm" concisely identifies the primary change (adding FP8 GEMM support for AMD in the TileLang GEMM codepath) and directly maps to the changes in FP8 kernels, layout k_pack support, and HIP FP8 types. It is short, focused, and immediately conveys the main intent to reviewers familiar with the project.
Docstring Coverage | ✅ Passed | No functions found in the changes. Docstring coverage check skipped.

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Summary of Changes

Hello @txs19991, I'm Gemini Code Assist¹! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enables 8-bit floating point (FP8) matrix multiplication on AMD GPUs by integrating specific hardware instructions and updating the underlying memory layout and fragment generation logic. It also includes a new example to demonstrate and validate the FP8 GEMM functionality within the tilelang framework.

Highlights

  • FP8 GEMM Support: Introduced support for FP8 (8-bit floating point) General Matrix Multiply (GEMM) operations, specifically targeting AMD GPUs. This includes the integration of the __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8 instruction for efficient FP8 computations.
  • New FP8 GEMM Example: Added a new Python example (examples/gemm_fp8/example_tilelang_gemm_amd.py) demonstrating FP8 GEMM with autotuning capabilities and reference program validation, showcasing both 'ss' and 'rs' GEMM types.
  • Generalized Layout and Fragment Generation: Modified existing GEMM fragment and layout generation functions (makeGemmFragmentAB16x16CDNA, makeGemmFragmentACDNA, etc.) to be more flexible with k_pack and element sizes, and introduced new 16x32 fragments to support FP8 specific MFMA instructions. Simplified the makeGemmABLayoutCDNA function by removing conditional padding logic.
  • FP8 Type Utility: Defined HIP_FP8_ENABLED and introduced a utility wrapper fp8_e4_4_t for __hip_fp8x4_e4m3_fnuz to provide convenient member access and improve usability of FP8 types in the codebase.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page; here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request by creating an issue comment (i.e. a comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature | Command | Description
Code Review | /gemini review | Performs a code review for the current pull request in its current state.
Pull Request Summary | /gemini summary | Provides a summary of the current pull request in its current state.
Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments.
Help | /gemini help | Displays a list of available commands.

Customization

To customize the Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counterproductive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving feedback about your experience with Gemini Code Assist for GitHub and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request adds support for fp8 GEMM on AMD GPUs, introducing a new example, updating layout definitions, and adding template specializations for HIP. The changes appear to correctly enable the new functionality. My review includes a few suggestions to enhance code quality and safety. Specifically, I've recommended refactoring in the new Python example to improve its structure and reduce code duplication. Additionally, I've identified a potential strict-aliasing issue in a C++ template and proposed a safer implementation.

Comment on lines 30 to 42
    valid_configs = []

    for m, n, k, stages, t, kp, gemm_type in itertools.product(block_Ms, block_Ns, block_Ks, num_stages, num_threads, k_packs, gemm_types):
        valid_configs.append({
            "block_M": m,
            "block_N": n,
            "block_K": k,
            "num_stages": stages,
            "num_threads": t,
            "k_pack": kp,
            "gemm_type": gemm_type,
        })
    return valid_configs

medium

The for-loop used to build valid_configs can be written more concisely and idiomatically as a list comprehension. This improves readability and is a more common pattern in Python.

  return [{
      "block_M": m,
      "block_N": n,
      "block_K": k,
      "num_stages": stages,
      "num_threads": t,
      "k_pack": kp,
      "gemm_type": gemm_type,
  } for m, n, k, stages, t, kp, gemm_type in itertools.product(
      block_Ms, block_Ns, block_Ks, num_stages, num_threads, k_packs, gemm_types
  )]

Comment on lines 49 to 92
    @T.prim_func
    def gemm_fp8_rs(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
            A_local = T.alloc_fragment((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_local)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(A_local, B_shared, C_local, transpose_B=True, k_pack=k_pack, policy=T.GemmWarpPolicy.FullRow)

            T.copy(C_local, C[by * block_M, bx * block_N])

    @T.prim_func
    def gemm_fp8_ss(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(A_shared, B_shared, C_local, transpose_B=True, k_pack=k_pack, policy=T.GemmWarpPolicy.FullRow)

            T.copy(C_local, C[by * block_M, bx * block_N])

    if gemm_type == "ss":
        return gemm_fp8_ss
    elif gemm_type == "rs":
        return gemm_fp8_rs
    else:
        raise ValueError(f"Invalid gemm_type: {gemm_type}")

medium

The functions gemm_fp8_rs and gemm_fp8_ss are almost identical, differing only in whether matrix A is stored in a local fragment (A_local) or shared memory (A_shared). You can avoid this code duplication by defining a single T.prim_func and using a conditional statement to allocate memory for A based on the gemm_type. This will make the code more maintainable.

    if gemm_type not in ["ss", "rs"]:
        raise ValueError(f"Invalid gemm_type: {gemm_type}")

    @T.prim_func
    def gemm_fp8(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
            if gemm_type == "rs":
                A_mem = T.alloc_fragment((block_M, block_K), dtype)
            else:  # ss
                A_mem = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_mem)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(A_mem, B_shared, C_local, transpose_B=True, k_pack=k_pack, policy=T.GemmWarpPolicy.FullRow)

            T.copy(C_local, C[by * block_M, bx * block_N])

    return gemm_fp8

Comment on lines +60 to +61
    int64_t a_val = *reinterpret_cast<const int64_t *>(a);
    int64_t b_val = *reinterpret_cast<const int64_t *>(b);

medium

Using reinterpret_cast followed by dereferencing can lead to strict aliasing violations, which is undefined behavior in C++. Although it might work with current compilers, it's safer to use memcpy to copy the bytes, which is guaranteed to be safe and well-defined.

    int64_t a_val, b_val;
    memcpy(&a_val, a, sizeof(int64_t));
    memcpy(&b_val, b, sizeof(int64_t));

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (4)
src/tl_templates/hip/hip_fp8.h (2)

46-55: Bit-packing uses signed shifts and strict-aliasing—make it well-defined.

Shifting signed char is UB; reinterpreting an int as fp8_e4_4_t is also UB. Use uint8_t/uint32_t and byte-wise stores into the underlying vector.

-  // reinterpret the 4 fp8_e4_t values to signed char value and shift
-  signed char x_char = *reinterpret_cast<signed char *>(&x);
-  signed char y_char = *reinterpret_cast<signed char *>(&y);
-  signed char z_char = *reinterpret_cast<signed char *>(&z);
-  signed char w_char = *reinterpret_cast<signed char *>(&w);
-  int res = (w_char << 24) | (z_char << 16) | (y_char << 8) | x_char;
-  return *reinterpret_cast<fp8_e4_4_t *>(&res);
+  // Extract raw bytes safely
+  const unsigned char x_u8 = *reinterpret_cast<const unsigned char*>(&x);
+  const unsigned char y_u8 = *reinterpret_cast<const unsigned char*>(&y);
+  const unsigned char z_u8 = *reinterpret_cast<const unsigned char*>(&z);
+  const unsigned char w_u8 = *reinterpret_cast<const unsigned char*>(&w);
+  fp8_e4_4_t out;
+  unsigned char* dst = reinterpret_cast<unsigned char*>(&out.data);
+  dst[0] = x_u8; dst[1] = y_u8; dst[2] = z_u8; dst[3] = w_u8;
+  return out;

57-74: Same UB in fp8_e4_8_t builder; construct via per-byte stores.

Avoid signed shifts and type-punning across unrelated types.

-  signed char x_char = *reinterpret_cast<signed char *>(&x);
-  signed char y_char = *reinterpret_cast<signed char *>(&y);
-  signed char z_char = *reinterpret_cast<signed char *>(&z);
-  signed char w_char = *reinterpret_cast<signed char *>(&w);
-  signed char v_char = *reinterpret_cast<signed char *>(&v);
-  signed char u_char = *reinterpret_cast<signed char *>(&u);
-  signed char t_char = *reinterpret_cast<signed char *>(&t);
-  signed char s_char = *reinterpret_cast<signed char *>(&s);
-  int a = (w_char << 24) | (z_char << 16) | (y_char << 8) | x_char;
-  int b = (s_char << 24) | (t_char << 16) | (u_char << 8) | v_char;
-  fp8_e4_8_t res;
-  res.x = *reinterpret_cast<fp8_e4_4_t *>(&a);
-  res.y = *reinterpret_cast<fp8_e4_4_t *>(&b);
-  return res;
+  const unsigned char xs[8] = {
+    *reinterpret_cast<const unsigned char*>(&x),
+    *reinterpret_cast<const unsigned char*>(&y),
+    *reinterpret_cast<const unsigned char*>(&z),
+    *reinterpret_cast<const unsigned char*>(&w),
+    *reinterpret_cast<const unsigned char*>(&v),
+    *reinterpret_cast<const unsigned char*>(&u),
+    *reinterpret_cast<const unsigned char*>(&t),
+    *reinterpret_cast<const unsigned char*>(&s),
+  };
+  fp8_e4_8_t res;
+  unsigned char* a = reinterpret_cast<unsigned char*>(&res.x.data);
+  unsigned char* b = reinterpret_cast<unsigned char*>(&res.y.data);
+  a[0]=xs[0]; a[1]=xs[1]; a[2]=xs[2]; a[3]=xs[3];
+  b[0]=xs[4]; b[1]=xs[5]; b[2]=xs[6]; b[3]=xs[7];
+  return res;
src/op/gemm.cc (1)

590-599: CDNA now forbids B.local.fragment—emit a clearer error.

Replace ICHECK(0) with a message that instructs users to place B in shared for CDNA.

-    } else {
-      ICHECK(0);
-    }
+    } else {
+      ICHECK(false) << "CDNA GEMM requires B in shared/shared.dyn; got scope: "
+                    << B.scope();
+    }
src/layout/gemm_layouts.cc (1)

418-441: Guard against maxPhase=0; add include.

If continuous < vecSize, innerDimLength/vecSize == 0 ⇒ modulo by 0 in phase. Add checks and include for std::min/max.

Apply:

 #include <cmath>
+#include <algorithm>

And within makeMatrixCoreSwizzleLayout:

   const int vecSize = (64 / element_size) * kPack;
   const int innerDimLength = continuous;
   const int typeWidthInBit = element_size;
+  ICHECK_GT(vecSize, 0);
+  ICHECK_GE(innerDimLength, vecSize)
+      << "continuous (" << innerDimLength << ") must be >= vecSize (" << vecSize << ")";
+  ICHECK_EQ(innerDimLength % vecSize, 0)
+      << "continuous must be a multiple of vecSize for swizzle";
🧹 Nitpick comments (15)
src/tl_templates/hip/hip_fp8.h (1)

3-3: Don't hard-define feature macros in public headers.

Defining HIP_FP8_ENABLED in a header can cause ODR/config skew across TUs. Prefer guarding or moving to build flags.

-#define HIP_FP8_ENABLED 1
+#ifndef HIP_FP8_ENABLED
+#define HIP_FP8_ENABLED 1
+#endif
examples/gemm_fp8/example_tilelang_gemm_amd.py (5)

10-12: Strengthen manual check beyond a single element.

Sampling more positions (e.g., a few random indices or a small tile) reduces false positives during autotune.

-def manual_check_prog(C, C_ref):
-  torch_assert_close(C[0], C_ref[0], rtol=0.01, atol=0.1)
+def manual_check_prog(C, C_ref):
+  # Check a few elements to guard against localized errors
+  torch_assert_close(C[:1, :4], C_ref[:1, :4], rtol=0.01, atol=0.1)

13-20: Bind device from params to avoid device mismatches.

Using the incoming param device makes the example portable across CUDA/ROCm installs and multi-GPU setups.

-  a = (torch.randn(M, K, dtype=torch.float16, device='cuda') * 0.01).to(dtype=torch.float8_e4m3fnuz)
-  b = (torch.randn(N, K, dtype=torch.float16, device='cuda') * 0.01).to(dtype=torch.float8_e4m3fnuz)
+  dev = a_param.device
+  a = (torch.randn(M, K, dtype=torch.float16, device=dev) * 0.01).to(dtype=torch.float8_e4m3fnuz)
+  b = (torch.randn(N, K, dtype=torch.float16, device=dev) * 0.01).to(dtype=torch.float8_e4m3fnuz)

21-43: Config space looks sane; consider pruning incompatible tuples if needed.

If MFMA FP8 requires K divisible by 32 (or k_pack×16), you can filter here to cut invalid runs.
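
For instance, a possible pruning step (a minimal sketch; the exact divisibility requirement is an assumption to verify against the kernel's micro_size_k):

    valid_configs = [
        cfg for cfg in valid_configs
        # Assumed constraint: the FP8 MFMA path consumes k_pack * 32 K-elements per tile.
        if cfg["block_K"] % (cfg["k_pack"] * 32) == 0
    ]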


87-92: Minor: tighten error.

Consider listing valid options in the message for quicker debugging.

-        raise ValueError(f"Invalid gemm_type: {gemm_type}")
+        raise ValueError(f"Invalid gemm_type: {gemm_type}. Expected one of: 'ss', 'rs'.")

95-103: Leverage param devices and match tolerances with utils defaults.

Mirror the supply_prog device binding; utils default atol=1e-2, rtol=1e-2—OK.

-    a = (torch.randn(M, K, dtype=torch.float16, device='cuda') * 0.01).to(dtype=torch.float8_e4m3fnuz)
-    b = (torch.randn(N, K, dtype=torch.float16, device='cuda') * 0.01).to(dtype=torch.float8_e4m3fnuz)
+    dev = torch.device('cuda')
+    a = (torch.randn(M, K, dtype=torch.float16, device=dev) * 0.01).to(dtype=torch.float8_e4m3fnuz)
+    b = (torch.randn(N, K, dtype=torch.float16, device=dev) * 0.01).to(dtype=torch.float8_e4m3fnuz)
src/layout/layout.h (1)

166-168: Parameter naming is inconsistent across files (kfactor vs kPack/k_pack).

In this header it’s kfactor; in gemm_layouts.cc it’s kPack; in other places it’s k_pack. Pick one (suggest k_pack) for consistency.

src/tl_templates/hip/gemm.h (4)

77-79: New micro_size_k/vec_size are correct; add compile-time guards.

Assert integrality assumptions to prevent silent shape bugs.

Apply:

   static constexpr int micro_size_x = 16;
   static constexpr int micro_size_y = 16;
   static constexpr int micro_size_k = 32 / sizeof(A_type);
   static constexpr int vec_size = 8 / sizeof(A_type);
+  static_assert((micro_size_x * micro_size_k) % warp_size == 0,
+                "local_size_a must be integral");
+  static_assert((micro_size_y * micro_size_k) % warp_size == 0,
+                "local_size_b must be integral");

Also applies to: 88-89


175-177: Ensure 8/16B alignment for vectorized packing.

Local arrays feed 64-bit (fp8/i8) or 8-byte (f16x4/bf16x4) loads. Add alignment to avoid misaligned vector loads.

Apply:

-    A_type A_local[warp_rows * kPack * local_size_a];
-    B_type B_local[warp_cols * kPack * local_size_b];
+    alignas(16) A_type A_local[warp_rows * kPack * local_size_a];
+    alignas(16) B_type B_local[warp_cols * kPack * local_size_b];

241-252: Align B_local in rs path as well.

Apply:

-    B_type B_local[warp_cols * kPack * local_size_b];
+    alignas(16) B_type B_local[warp_cols * kPack * local_size_b];

166-167: Remove unused variable tx.

tx is set but never used in both bodies.

Apply:

-    auto tx = lane_id;
src/layout/gemm_layouts.cc (4)

62-69: Validate k_pack at entry.

Add a sanity check to prevent undefined layouts on k_pack <= 0.

Apply:

 Fragment makeGemmFragmentAB16x16CDNA(const int k_pack) {
+  ICHECK(k_pack >= 1) << "k_pack must be >= 1";

71-78: Same k_pack validation for transposed variant.

Apply:

 Fragment makeGemmFragmentAB16x16CDNATransposed(const int k_pack) {
+  ICHECK(k_pack >= 1) << "k_pack must be >= 1";

80-87: Add k_pack check for 16x32 path.

Apply:

 Fragment makeGemmFragmentAB16x32CDNA(const int k_pack) {
+  ICHECK(k_pack >= 1) << "k_pack must be >= 1";

89-96: Add k_pack check for 16x32 transposed path.

Apply:

 Fragment makeGemmFragmentAB16x32CDNATransposed(const int k_pack) {
+  ICHECK(k_pack >= 1) << "k_pack must be >= 1";
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b62a0b4 and 83a6157.

📒 Files selected for processing (6)
  • examples/gemm_fp8/example_tilelang_gemm_amd.py (1 hunks)
  • src/layout/gemm_layouts.cc (4 hunks)
  • src/layout/layout.h (1 hunks)
  • src/op/gemm.cc (1 hunks)
  • src/tl_templates/hip/gemm.h (8 hunks)
  • src/tl_templates/hip/hip_fp8.h (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (6)
src/tl_templates/hip/hip_fp8.h (1)
src/tl_templates/cuda/cuda_fp8.h (1)
  • fp8_e4_2_t (8-11)
examples/gemm_fp8/example_tilelang_gemm_amd.py (7)
tilelang/utils/tensor.py (1)
  • torch_assert_close (220-312)
tilelang/autotuner/tuner.py (1)
  • autotune (692-785)
tilelang/jit/__init__.py (1)
  • jit (237-310)
tilelang/language/allocate.py (2)
  • alloc_fragment (53-64)
  • alloc_shared (21-36)
tilelang/language/fill.py (1)
  • clear (24-48)
tilelang/language/pipeline.py (1)
  • Pipelined (9-46)
tilelang/language/copy.py (1)
  • copy (84-152)
src/layout/layout.h (1)
tilelang/tileop/gemm/gemm_base.py (1)
  • k_pack (110-111)
src/tl_templates/hip/gemm.h (1)
tilelang/intrinsics/mfma_layout.py (1)
  • make_mfma_swizzle_layout (130-152)
src/layout/gemm_layouts.cc (2)
tilelang/tileop/gemm/gemm_base.py (1)
  • k_pack (110-111)
src/layout/layout.cc (2)
  • Fragment (274-296)
  • Fragment (298-308)
src/op/gemm.cc (1)
tilelang/tileop/gemm/gemm_base.py (2)
  • A (66-67)
  • trans_A (45-46)
🪛 Ruff (0.12.2)
examples/gemm_fp8/example_tilelang_gemm_amd.py

92-92: Avoid specifying long messages outside the exception class

(TRY003)

🪛 GitHub Actions: CI
examples/gemm_fp8/example_tilelang_gemm_amd.py

[error] 1-1: format.sh reformatted files and exited with code 1. Please review and stage the changes in 'examples/gemm_fp8/example_tilelang_gemm_amd.py'.

🔇 Additional comments (15)
src/tl_templates/hip/hip_fp8.h (1)

23-25: Verify float4→__hip_fp8x4_e4m3_fnuz conversion availability.

The ctor relies on an implicit conversion from float4. Confirm the toolchain provides this; otherwise, gate it or implement explicit per-lane conversion.

Would you like me to add a guarded implementation that converts per-lane with the available HIP FP8 intrinsics?
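
For reference, a guarded per-lane fallback could look like the sketch below; it assumes a 1-byte scalar __hip_fp8_e4m3_fnuz with a float constructor, which should be verified against the ROCm toolchain in use.

// Hedged sketch, not the PR's code: explicit per-lane float4 -> fp8_e4_4_t
// conversion via byte-wise stores into the wrapper's data member.
__device__ inline fp8_e4_4_t fp8x4_from_float4(const float4 &v) {
  const __hip_fp8_e4m3_fnuz lanes[4] = {
      __hip_fp8_e4m3_fnuz(v.x), __hip_fp8_e4m3_fnuz(v.y),
      __hip_fp8_e4m3_fnuz(v.z), __hip_fp8_e4m3_fnuz(v.w)};
  fp8_e4_4_t out;
  unsigned char *dst = reinterpret_cast<unsigned char *>(&out.data);
  for (int i = 0; i < 4; ++i)
    dst[i] = *reinterpret_cast<const unsigned char *>(&lanes[i]);
  return out;
}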

examples/gemm_fp8/example_tilelang_gemm_amd.py (5)

7-9: Reference impl LGTM.

FP16 matmul with FP32 accumulate is a reasonable baseline for FP8.


44-46: Decorator order check.

autotune outside jit is fine (autotune(jit(func))). Just confirm this matches your expected compilation/execution path.


55-67: CDNA: rs path matches InferLayout constraints.

A in local.fragment + B in shared aligns with the CDNA path that disallows B.local.fragment.


69-86: CDNA: ss path OK.

Both A and B in shared are supported; transpose_B=True + k_pack threading aligns with layout updates.


1-106: CI formatting failed — run format.sh locally and commit the changes

Verification in the sandbox failed: ./format.sh exited with "pip: command not found". Run ./format.sh locally and commit the formatting changes to unblock CI; if it still errors, paste the stdout/stderr from ./format.sh and the output of:
git diff -- examples/gemm_fp8/example_tilelang_gemm_amd.py

src/op/gemm.cc (3)

585-586: kPack threading into A fragment (CDNA) looks correct.

Signature change is respected: makeGemmFragmentACDNA(..., dtype_bits, kPack, trans_A).


426-434: Lowering: kPack emission for CDNA path LGTM; keep wg_wait optional.

No action needed.


461-605: Sanity scan — B local.fragment with CDNA: inconclusive; verify tests/examples

Found B allocated as local.fragment in tests/examples (e.g. testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py:301, testing/python/kernel/test_tilelang_kernel_gemm.py:330, examples/dequantize_gemm/*). Quick pass did not reveal an obvious src/ kernel that combines TargetIsCDNA with B being local.fragment. Run a targeted check (e.g. grep for "TargetIsCDNA.*alloc_fragment.*B" or review CI/target matrix) to confirm none of these kernels are built/run for CDNA.

src/layout/layout.h (1)

154-158: Confirmed: all call sites updated to pass k_pack.
Only usage found: src/op/gemm.cc — calls makeGemmFragmentACDNA(..., kPack, ...); declaration (src/layout/layout.h) and definition (src/layout/gemm_layouts.cc) include the k_pack parameter.

src/tl_templates/hip/gemm.h (3)

105-112: Index maps now scale with vec_size·kPack — LGTM.


125-126: Swizzle layout generalized to vec_size — matches Python impl.

The C++ mirrors tilelang/intrinsics/mfma_layout.py. Looks good.

Also applies to: 151-154


54-66: Gate FP8 path by arch if needed.

Confirm __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8 is available on targeted gfx IPs; otherwise guard by __HIP_ARCH_HAS_FP8 or similar.
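
A minimal sketch of such a gate (the gfx macro list is an assumption; adjust it to the targets that actually provide the builtin):

// Sketch only: restrict the FP8 MFMA specialization to CDNA3-class targets.
#if defined(HIP_FP8_ENABLED) && \
    (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
// ... MfmaTraits<fp8_e4_t> specialization using
//     __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8 ...
#endif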

src/layout/gemm_layouts.cc (2)

245-276: ACDNA fragment: mfma_k logic and tiling look correct.

Using mfma_k = k_pack*(16|32) and switching AB16x16 vs AB16x32 by element_size matches HIP MFMA shapes. Repeat/divisor math is consistent.


642-645: CDNA AB layout now always uses matrix-core swizzle — LGTM.

This unifies behavior and leverages kPack; good.

Comment on lines 154 to 158
 Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
                                const int block_k, const int warp_m,
                                const int warp_n, const int element_size,
-                               bool transposed = false);
+                               const int k_pack, bool transposed = false);


🛠️ Refactor suggestion

Public API now requires k_pack; consider defaulting to 1 to minimize breakage.

Adding a mandatory k_pack breaks all external callers. If backward-compat isn’t a concern, fine; otherwise, default it to 1 so existing FP16/BF16 paths keep compiling.

Apply:

-Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
-                               const int block_k, const int warp_m,
-                               const int warp_n, const int element_size,
-                               const int k_pack, bool transposed = false);
+Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
+                               const int block_k, const int warp_m,
+                               const int warp_n, const int element_size,
+                               const int k_pack = 1, bool transposed = false);
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
-                               const int block_k, const int warp_m,
-                               const int warp_n, const int element_size,
-                               const int k_pack, bool transposed = false);
+Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
+                               const int block_k, const int warp_m,
+                               const int warp_n, const int element_size,
+                               const int k_pack = 1, bool transposed = false);
🤖 Prompt for AI Agents
In src/layout/layout.h around lines 154-158, the newly added mandatory k_pack
parameter breaks existing external callers; make k_pack optional by giving it a
default value of 1 in the function declaration so previous FP16/BF16 call sites
continue to compile. Update the prototype to set k_pack = 1 and ensure the
corresponding implementation/definition (and any inline/overload declarations)
uses the same default or handles the absent argument consistently.

Comment on lines +54 to +66
#if defined(HIP_FP8_ENABLED)
// Specialization for fp8_e4_t
template <> struct MfmaTraits<fp8_e4_t> {
  template <typename AccType>
  static TL_DEVICE void mfma_op(const fp8_e4_t *b, const fp8_e4_t *a,
                                AccType *c) {
    int64_t a_val = *reinterpret_cast<const int64_t *>(a);
    int64_t b_val = *reinterpret_cast<const int64_t *>(b);
    *c = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(b_val, a_val, *c, 0, 0, 0);
  }
};
#endif


🛠️ Refactor suggestion

FP8 MFMA: avoid potential misaligned 64-bit loads and aliasing UB.

Reinterpreting fp8_e4_t* to int64_t* may misalign and violates strict aliasing. Use memcpy to scalars or ensure 8B alignment on sources.

Apply:

-  static TL_DEVICE void mfma_op(const fp8_e4_t *b, const fp8_e4_t *a,
-                                AccType *c) {
-    int64_t a_val = *reinterpret_cast<const int64_t *>(a);
-    int64_t b_val = *reinterpret_cast<const int64_t *>(b);
-    *c = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(b_val, a_val, *c, 0, 0, 0);
-  }
+  static TL_DEVICE void mfma_op(const fp8_e4_t *b, const fp8_e4_t *a,
+                                AccType *c) {
+    int64_t a_val, b_val;
+    __builtin_memcpy(&a_val, a, sizeof(a_val));
+    __builtin_memcpy(&b_val, b, sizeof(b_val));
+    *c = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(b_val, a_val, *c, 0, 0, 0);
+  }

And include cstring once per TU:

 #include "common.h"
+#include <cstring>
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
 #include "common.h"
+#include <cstring>

Suggested change
 #if defined(HIP_FP8_ENABLED)
 // Specialization for fp8_e4_t
 template <> struct MfmaTraits<fp8_e4_t> {
   template <typename AccType>
   static TL_DEVICE void mfma_op(const fp8_e4_t *b, const fp8_e4_t *a,
                                 AccType *c) {
-    int64_t a_val = *reinterpret_cast<const int64_t *>(a);
-    int64_t b_val = *reinterpret_cast<const int64_t *>(b);
+    int64_t a_val, b_val;
+    __builtin_memcpy(&a_val, a, sizeof(a_val));
+    __builtin_memcpy(&b_val, b, sizeof(b_val));
     *c = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(b_val, a_val, *c, 0, 0, 0);
   }
 };
 #endif

Comment on lines +9 to +15
struct fp8_e4_4_t {
  union {
    __hip_fp8x4_e4m3_fnuz data;
    struct {
      fp8_e4_t x, y, z, w;
    };
  };

🛠️ Refactor suggestion

Union-based type punning risks UB; add layout guards or switch to a safer representation.

Reading a different union member than the last-written one is UB in C++. Since generated code will access x/y/z/w and conversions read/write data, add size/alignment guards and use byte-wise copies to/from the underlying type to avoid aliasing issues.

-struct fp8_e4_4_t {
-  union {
-    __hip_fp8x4_e4m3_fnuz data;
-    struct {
-      fp8_e4_t x, y, z, w;
-    };
-  };
+struct fp8_e4_4_t {
+  // Storage
+  __hip_fp8x4_e4m3_fnuz data;
+  // Named accessors
+  __device__ fp8_e4_t& x() { return *reinterpret_cast<fp8_e4_t*>(reinterpret_cast<unsigned char*>(&data) + 0); }
+  __device__ fp8_e4_t& y() { return *reinterpret_cast<fp8_e4_t*>(reinterpret_cast<unsigned char*>(&data) + 1); }
+  __device__ fp8_e4_t& z() { return *reinterpret_cast<fp8_e4_t*>(reinterpret_cast<unsigned char*>(&data) + 2); }
+  __device__ fp8_e4_t& w() { return *reinterpret_cast<fp8_e4_t*>(reinterpret_cast<unsigned char*>(&data) + 3); }
+  __device__ const fp8_e4_t& x() const { return *reinterpret_cast<const fp8_e4_t*>(reinterpret_cast<const unsigned char*>(&data) + 0); }
+  __device__ const fp8_e4_t& y() const { return *reinterpret_cast<const fp8_e4_t*>(reinterpret_cast<const unsigned char*>(&data) + 1); }
+  __device__ const fp8_e4_t& z() const { return *reinterpret_cast<const fp8_e4_t*>(reinterpret_cast<const unsigned char*>(&data) + 2); }
+  __device__ const fp8_e4_t& w() const { return *reinterpret_cast<const fp8_e4_t*>(reinterpret_cast<const unsigned char*>(&data) + 3); }
 
   // Default constructor
   __device__ fp8_e4_4_t() = default;
 
   // Constructor from __hip_fp8x4_e4m3_fnuz
   __device__ fp8_e4_4_t(const __hip_fp8x4_e4m3_fnuz &val) : data(val) {}
 
-  // Constructor from float4
-  __device__ fp8_e4_4_t(const float4 &val) : data(val) {}
+  // Note: float4 ctor requires a well-defined conversion; see comment below.
 
   // Conversion operator to __hip_fp8x4_e4m3_fnuz
   __device__ operator __hip_fp8x4_e4m3_fnuz() const { return data; }
 
   // Assignment operator
   __device__ fp8_e4_4_t &operator=(const __hip_fp8x4_e4m3_fnuz &val) {
     data = val;
     return *this;
   }
 };

Follow-up:

  • If you must keep direct field syntax (obj.x), we can wrap with proxy references but that’s more invasive. The accessor approach above keeps aliasing-safe semantics.
  • Add guards near the type definitions:
    • static_assert(sizeof(fp8_e4_t) == 1, "fp8_e4_t must be 1 byte");
    • static_assert(sizeof(__hip_fp8x4_e4m3_fnuz) == 4, "__hip_fp8x4_e4m3_fnuz must be 4 bytes");

Also applies to: 18-34

🤖 Prompt for AI Agents
In src/tl_templates/hip/hip_fp8.h around lines 9-15 (and similarly 18-34), the
union-based type punning between __hip_fp8x4_e4m3_fnuz and the fp8_e4_t fields
is UB; replace unsafe aliasing by adding layout guards and using aliasing-safe
copies or accessors: add static_asserts that sizeof(fp8_e4_t)==1 and
sizeof(__hip_fp8x4_e4m3_fnuz)==4 and matching alignas checks, then remove direct
cross-member reads/writes and implement small inline accessor functions (or use
std::memcpy to copy bytes to/from the underlying __hip_fp8x4_e4m3_fnuz storage)
so all conversions use byte-wise copies instead of reading a different union
member; if you must keep field-like accessors, implement proxy getter/setter
functions that perform the memcpy to maintain aliasing-safe semantics.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 0

♻️ Duplicate comments (2)
examples/gemm_fp8/example_tilelang_gemm_amd.py (2)

27-50: Use a filtered comprehension and prune invalid k_pack combos

Also include a pipelined option in num_stages and skip configs where block_K % k_pack != 0.

 def get_configs():
     block_Ms = [32, 64, 128]
     block_Ns = [32, 64, 128]
     block_Ks = [64, 128]
-    num_stages = [0]
+    num_stages = [0, 2]
     num_threads = [256]
     k_packs = [1, 2]
     gemm_types = ["ss", "rs"]
-
-    valid_configs = []
-
-    for m, n, k, stages, t, kp, gemm_type in itertools.product(block_Ms, block_Ns, block_Ks,
-                                                               num_stages, num_threads, k_packs,
-                                                               gemm_types):
-        valid_configs.append({
-            "block_M": m,
-            "block_N": n,
-            "block_K": k,
-            "num_stages": stages,
-            "num_threads": t,
-            "k_pack": kp,
-            "gemm_type": gemm_type,
-        })
-    return valid_configs
+    return [
+        {
+            "block_M": m,
+            "block_N": n,
+            "block_K": k,
+            "num_stages": stages,
+            "num_threads": t,
+            "k_pack": kp,
+            "gemm_type": gemm_type,
+        }
+        for m, n, k, stages, t, kp, gemm_type in itertools.product(
+            block_Ms, block_Ns, block_Ks, num_stages, num_threads, k_packs, gemm_types
+        )
+        if k % kp == 0
+    ]

64-115: Deduplicate rs/ss kernels by selecting A’s storage based on gemm_type

Both bodies are identical aside from A_local vs A_shared. Consolidate to one prim_func for maintainability and to avoid drift.

-    @T.prim_func
-    def gemm_fp8_rs(
-            A: T.Tensor((M, K), dtype),
-            B: T.Tensor((N, K), dtype),
-            C: T.Tensor((M, N), accum_dtype),
-    ):
-        with T.Kernel(
-                T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
-            A_local = T.alloc_fragment((block_M, block_K), dtype)
-            B_shared = T.alloc_shared((block_N, block_K), dtype)
-            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
-
-            T.clear(C_local)
-            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
-                T.copy(A[by * block_M, k * block_K], A_local)
-                T.copy(B[bx * block_N, k * block_K], B_shared)
-                T.gemm(
-                    A_local,
-                    B_shared,
-                    C_local,
-                    transpose_B=True,
-                    k_pack=k_pack,
-                    policy=T.GemmWarpPolicy.FullRow)
-
-            T.copy(C_local, C[by * block_M, bx * block_N])
-
-    @T.prim_func
-    def gemm_fp8_ss(
-            A: T.Tensor((M, K), dtype),
-            B: T.Tensor((N, K), dtype),
-            C: T.Tensor((M, N), accum_dtype),
-    ):
-        with T.Kernel(
-                T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
-            A_shared = T.alloc_shared((block_M, block_K), dtype)
-            B_shared = T.alloc_shared((block_N, block_K), dtype)
-            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
-
-            T.clear(C_local)
-            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
-                T.copy(A[by * block_M, k * block_K], A_shared)
-                T.copy(B[bx * block_N, k * block_K], B_shared)
-                T.gemm(
-                    A_shared,
-                    B_shared,
-                    C_local,
-                    transpose_B=True,
-                    k_pack=k_pack,
-                    policy=T.GemmWarpPolicy.FullRow)
-
-            T.copy(C_local, C[by * block_M, bx * block_N])
+    @T.prim_func
+    def gemm_fp8(
+            A: T.Tensor((M, K), dtype),
+            B: T.Tensor((N, K), dtype),
+            C: T.Tensor((M, N), accum_dtype),
+    ):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
+            if gemm_type == "rs":
+                A_mem = T.alloc_fragment((block_M, block_K), dtype)
+            elif gemm_type == "ss":
+                A_mem = T.alloc_shared((block_M, block_K), dtype)
+            else:
+                raise ValueError("Invalid gemm_type")
+            B_shared = T.alloc_shared((block_N, block_K), dtype)
+            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+
+            T.clear(C_local)
+            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
+                T.copy(A[by * block_M, k * block_K], A_mem, coalesced_width=16)
+                T.copy(B[bx * block_N, k * block_K], B_shared, coalesced_width=16)
+                T.gemm(A_mem, B_shared, C_local, transpose_B=True, k_pack=k_pack, policy=T.GemmWarpPolicy.FullRow)
+            T.copy(C_local, C[by * block_M, bx * block_N])

And below:

-    if gemm_type == "ss":
-        return gemm_fp8_ss
-    elif gemm_type == "rs":
-        return gemm_fp8_rs
-    else:
-        raise ValueError(f"Invalid gemm_type: {gemm_type}")
+    return gemm_fp8
🧹 Nitpick comments (5)
examples/gemm_fp8/example_tilelang_gemm_amd.py (5)

12-14: Check the whole tensor during tuning, not just the first row

Limiting validation to C[0] risks missing layout-wide bugs. Consider checking the full tensor (or a stratified sample) to keep autotuning trustworthy.

-def manual_check_prog(C, C_ref):
-    torch_assert_close(C[0], C_ref[0], rtol=0.01, atol=0.1)
+def manual_check_prog(C, C_ref):
+    # Full-tensor check; relax mismatch ratio to keep tuning fast.
+    torch_assert_close(C, C_ref, rtol=0.01, atol=0.1, max_mismatched_ratio=0.01)

60-63: Be explicit: use float32 for accumulators

"float" may alias to float32, but "float32" removes ambiguity and matches other APIs.

-    accum_dtype = "float"
+    accum_dtype = "float32"

78-80: Hint vector width for better global-memory coalescing

Explicit coalesced_width (e.g., 16 or 32) can help the HIP path choose better vectorized LD/ST on CDNA.

If you keep separate kernels, minimally:

-                T.copy(A[by * block_M, k * block_K], A_local)
-                T.copy(B[bx * block_N, k * block_K], B_shared)
+                T.copy(A[by * block_M, k * block_K], A_local, coalesced_width=16)
+                T.copy(B[bx * block_N, k * block_K], B_shared, coalesced_width=16)

Also applies to: 104-106


116-121: Minor: TRY003 and clearer error path

Ruff flagged TRY003 here. If you adopt the single-prim_func pattern above, this line disappears; otherwise, consider a constant message or pre-validating gemm_type before decoration.


124-134: Stabilize the test and broaden coverage slightly

Seed RNG and drop the trailing tilde in the message.

 def test_gemm_fp8(M, N, K):
-    kernel = fp8_matmul(M, N, K)
+    torch.manual_seed(0)
+    kernel = fp8_matmul(M, N, K)
@@
-    print("passed~")
+    print("passed")

Optionally add an odd-dimension smoke test under __main__:

 if __name__ == "__main__":
     test_gemm_fp8(512, 512, 512)
+    test_gemm_fp8(513, 517, 509)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 83a6157 and e5c6642.

📒 Files selected for processing (1)
  • examples/gemm_fp8/example_tilelang_gemm_amd.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/gemm_fp8/example_tilelang_gemm_amd.py (7)
tilelang/utils/tensor.py (1)
  • torch_assert_close (220-312)
tilelang/autotuner/tuner.py (1)
  • autotune (692-785)
tilelang/jit/__init__.py (1)
  • jit (237-310)
tilelang/language/allocate.py (2)
  • alloc_fragment (53-64)
  • alloc_shared (21-36)
tilelang/language/fill.py (1)
  • clear (24-48)
tilelang/language/pipeline.py (1)
  • Pipelined (9-46)
tilelang/language/copy.py (1)
  • copy (84-152)
🪛 Ruff (0.12.2)
examples/gemm_fp8/example_tilelang_gemm_amd.py

121-121: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
  • GitHub Check: build-test-nvidia
  • GitHub Check: bot-task
🔇 Additional comments (2)
examples/gemm_fp8/example_tilelang_gemm_amd.py (2)

16-24: Verify device selection on ROCm runners

This example targets AMD; PyTorch on ROCm typically still uses cuda device type, but environments differ. Please confirm this runs on your ROCm CI/hosts; otherwise consider parameterizing the device.
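
A minimal sketch of device-parameterized input supply (assumes the first param corresponds to A and that M, N, K are closed over, as in the current example; ROCm builds of PyTorch still expose the GPU under the "cuda" device type):

    def supply_prog(params):
        # Reuse the incoming param's device instead of hard-coding 'cuda'.
        dev = params[0].device
        a = (torch.randn(M, K, dtype=torch.float16, device=dev) * 0.01).to(dtype=torch.float8_e4m3fnuz)
        b = (torch.randn(N, K, dtype=torch.float16, device=dev) * 0.01).to(dtype=torch.float8_e4m3fnuz)
        return [a, b]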


70-89: Edge tiles: please confirm OOB safety on non-multiple shapes

The unguarded T.copy of full tiles relies on TileLang to handle partial tiles safely. Please run a smoke test with non-multiples (e.g., M=513, N=517, K=509) to confirm no OOB or correctness regressions.

Also applies to: 96-115

@LeiWang1999
Member

Awesome Contribution! Merged :)

