[AMD] support fp8 T.gemm #804
Conversation
Note: Other AI code review bot(s) detected. CodeRabbit has detected other AI code review bot(s) in this pull request and will avoid duplicating their findings in the review comments. This may lead to a less comprehensive review.

Walkthrough

Adds an FP8 GEMM example for AMD using TileLang, introduces FP8 MFMA support and vectorization in HIP templates, threads k_pack through CDNA GEMM layouts and fragment builders, updates public layout APIs to accept k_pack, and adjusts GemmNode layout inference to pass kPack and remove the B local.fragment binding.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    actor User
    participant Example as example_tilelang_gemm_amd.py
    participant TL as TileLang (autotune/jit)
    participant Kernel as gemm_fp8_{ss|rs}
    participant GPU as AMD CDNA MFMA
    User->>Example: test_gemm_fp8(M,N,K)
    Example->>TL: fp8_matmul(..., gemm_type, k_pack)
    TL->>Kernel: select & compile kernel (tiling, k_pack)
    Kernel->>GPU: launch with FP8 inputs (mfma fp8 path)
    GPU-->>Kernel: compute results
    Kernel-->>Example: return C
    Example->>Example: compare with ref_program -> pass/fail
```

```mermaid
sequenceDiagram
    autonumber
    participant GemmNode as GemmNode::InferLayout (CDNA)
    participant Layout as Fragment/Layout builders
    GemmNode->>Layout: makeGemmFragmentACDNA(..., dtype_bits, kPack, trans_A)
    note right of Layout: select k_pack-aware AB/CDNA fragment variant<br/>matrix-core swizzle uses kPack
    GemmNode-->>GemmNode: B local.fragment binding for CDNA removed
```
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run `./format.sh` before committing. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Summary of Changes
Hello @txs19991, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request enables 8-bit floating point (FP8) matrix multiplication on AMD GPUs by integrating specific hardware instructions and updating the underlying memory layout and fragment generation logic. It also includes a new example to demonstrate and validate the FP8 GEMM functionality within the tilelang framework.
Highlights

- FP8 GEMM Support: Introduced support for FP8 (8-bit floating point) General Matrix Multiply (GEMM) operations, specifically targeting AMD GPUs. This includes the integration of the `__builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8` instruction for efficient FP8 computations.
- New FP8 GEMM Example: Added a new Python example (`examples/gemm_fp8/example_tilelang_gemm_amd.py`) demonstrating FP8 GEMM with autotuning capabilities and reference program validation, showcasing both 'ss' and 'rs' GEMM types.
- Generalized Layout and Fragment Generation: Modified existing GEMM fragment and layout generation functions (`makeGemmFragmentAB16x16CDNA`, `makeGemmFragmentACDNA`, etc.) to be more flexible with `k_pack` and element sizes, and introduced new 16x32 fragments to support FP8-specific MFMA instructions. Simplified the `makeGemmABLayoutCDNA` function by removing conditional padding logic.
- FP8 Type Utility: Defined `HIP_FP8_ENABLED` and introduced a utility wrapper `fp8_e4_4_t` for `__hip_fp8x4_e4m3_fnuz` to provide convenient member access and improve usability of FP8 types in the codebase (a short usage sketch follows below).
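As a rough illustration of the wrapper described above, the sketch below shows the kind of member access it is meant to enable. It assumes the union layout quoted later in this review (a `data` member of type `__hip_fp8x4_e4m3_fnuz` plus per-lane `x/y/z/w` members) and the project's `hip_fp8.h` being included; it is not part of the PR itself.

```cpp
// Hedged usage sketch of the fp8_e4_4_t wrapper (names follow the struct
// quoted later in this review; the surrounding function is illustrative only).
__device__ void demo_fp8_pack_access(fp8_e4_4_t &v) {
  fp8_e4_t lane0 = v.x;            // per-lane access via the union members
  __hip_fp8x4_e4m3_fnuz raw = v;   // implicit conversion back to the HIP type
  v = raw;                         // assignment from the underlying HIP type
  (void)lane0;
}
```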
Code Review
This pull request adds support for fp8 GEMM on AMD GPUs, introducing a new example, updating layout definitions, and adding template specializations for HIP. The changes appear to correctly enable the new functionality. My review includes a few suggestions to enhance code quality and safety. Specifically, I've recommended refactoring in the new Python example to improve its structure and reduce code duplication. Additionally, I've identified a potential strict-aliasing issue in a C++ template and proposed a safer implementation.
Code under review (examples/gemm_fp8/example_tilelang_gemm_amd.py):

```python
    valid_configs = []

    for m, n, k, stages, t, kp, gemm_type in itertools.product(block_Ms, block_Ns, block_Ks,
                                                               num_stages, num_threads, k_packs,
                                                               gemm_types):
        valid_configs.append({
            "block_M": m,
            "block_N": n,
            "block_K": k,
            "num_stages": stages,
            "num_threads": t,
            "k_pack": kp,
            "gemm_type": gemm_type,
        })
    return valid_configs
```
The for-loop used to build valid_configs can be written more concisely and idiomatically as a list comprehension. This improves readability and is a more common pattern in Python.
```python
    return [{
        "block_M": m,
        "block_N": n,
        "block_K": k,
        "num_stages": stages,
        "num_threads": t,
        "k_pack": kp,
        "gemm_type": gemm_type,
    } for m, n, k, stages, t, kp, gemm_type in itertools.product(
        block_Ms, block_Ns, block_Ks, num_stages, num_threads, k_packs, gemm_types
    )]
```
Code under review (examples/gemm_fp8/example_tilelang_gemm_amd.py):

```python
    @T.prim_func
    def gemm_fp8_rs(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
            A_local = T.alloc_fragment((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_local)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(A_local, B_shared, C_local, transpose_B=True, k_pack=k_pack, policy=T.GemmWarpPolicy.FullRow)

            T.copy(C_local, C[by * block_M, bx * block_N])

    @T.prim_func
    def gemm_fp8_ss(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
            A_shared = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(A_shared, B_shared, C_local, transpose_B=True, k_pack=k_pack, policy=T.GemmWarpPolicy.FullRow)

            T.copy(C_local, C[by * block_M, bx * block_N])

    if gemm_type == "ss":
        return gemm_fp8_ss
    elif gemm_type == "rs":
        return gemm_fp8_rs
    else:
        raise ValueError(f"Invalid gemm_type: {gemm_type}")
```
The functions gemm_fp8_rs and gemm_fp8_ss are almost identical, differing only in whether matrix A is stored in a local fragment (A_local) or shared memory (A_shared). You can avoid this code duplication by defining a single T.prim_func and using a conditional statement to allocate memory for A based on the gemm_type. This will make the code more maintainable.
```python
    if gemm_type not in ["ss", "rs"]:
        raise ValueError(f"Invalid gemm_type: {gemm_type}")

    @T.prim_func
    def gemm_fp8(
            A: T.Tensor((M, K), dtype),
            B: T.Tensor((N, K), dtype),
            C: T.Tensor((M, N), accum_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
            if gemm_type == "rs":
                A_mem = T.alloc_fragment((block_M, block_K), dtype)
            else:  # ss
                A_mem = T.alloc_shared((block_M, block_K), dtype)
            B_shared = T.alloc_shared((block_N, block_K), dtype)
            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)

            T.clear(C_local)
            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_mem)
                T.copy(B[bx * block_N, k * block_K], B_shared)
                T.gemm(A_mem, B_shared, C_local, transpose_B=True, k_pack=k_pack, policy=T.GemmWarpPolicy.FullRow)

            T.copy(C_local, C[by * block_M, bx * block_N])

    return gemm_fp8
```
Code under review:

```cpp
    int64_t a_val = *reinterpret_cast<const int64_t *>(a);
    int64_t b_val = *reinterpret_cast<const int64_t *>(b);
```
Using reinterpret_cast followed by dereferencing can lead to strict aliasing violations, which is undefined behavior in C++. Although it might work with current compilers, it's safer to use memcpy to copy the bytes, which is guaranteed to be safe and well-defined.
```cpp
    int64_t a_val, b_val;
    memcpy(&a_val, a, sizeof(int64_t));
    memcpy(&b_val, b, sizeof(int64_t));
```
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (4)
src/tl_templates/hip/hip_fp8.h (2)
46-55: Bit-packing uses signed shifts and strict aliasing; make it well-defined. Shifting signed char is UB; reinterpreting an int as fp8_e4_4_t is also UB. Use uint8_t/uint32_t and byte-wise stores into the underlying vector.

```diff
-  // reinterpret the 4 fp8_e4_t values to signed char value and shift
-  signed char x_char = *reinterpret_cast<signed char *>(&x);
-  signed char y_char = *reinterpret_cast<signed char *>(&y);
-  signed char z_char = *reinterpret_cast<signed char *>(&z);
-  signed char w_char = *reinterpret_cast<signed char *>(&w);
-  int res = (w_char << 24) | (z_char << 16) | (y_char << 8) | x_char;
-  return *reinterpret_cast<fp8_e4_4_t *>(&res);
+  // Extract raw bytes safely
+  const unsigned char x_u8 = *reinterpret_cast<const unsigned char*>(&x);
+  const unsigned char y_u8 = *reinterpret_cast<const unsigned char*>(&y);
+  const unsigned char z_u8 = *reinterpret_cast<const unsigned char*>(&z);
+  const unsigned char w_u8 = *reinterpret_cast<const unsigned char*>(&w);
+  fp8_e4_4_t out;
+  unsigned char* dst = reinterpret_cast<unsigned char*>(&out.data);
+  dst[0] = x_u8; dst[1] = y_u8; dst[2] = z_u8; dst[3] = w_u8;
+  return out;
```
57-74: Same UB in the fp8_e4_8_t builder; construct via per-byte stores. Avoid signed shifts and type-punning across unrelated types.

```diff
-  signed char x_char = *reinterpret_cast<signed char *>(&x);
-  signed char y_char = *reinterpret_cast<signed char *>(&y);
-  signed char z_char = *reinterpret_cast<signed char *>(&z);
-  signed char w_char = *reinterpret_cast<signed char *>(&w);
-  signed char v_char = *reinterpret_cast<signed char *>(&v);
-  signed char u_char = *reinterpret_cast<signed char *>(&u);
-  signed char t_char = *reinterpret_cast<signed char *>(&t);
-  signed char s_char = *reinterpret_cast<signed char *>(&s);
-  int a = (w_char << 24) | (z_char << 16) | (y_char << 8) | x_char;
-  int b = (s_char << 24) | (t_char << 16) | (u_char << 8) | v_char;
-  fp8_e4_8_t res;
-  res.x = *reinterpret_cast<fp8_e4_4_t *>(&a);
-  res.y = *reinterpret_cast<fp8_e4_4_t *>(&b);
-  return res;
+  const unsigned char xs[8] = {
+      *reinterpret_cast<const unsigned char*>(&x),
+      *reinterpret_cast<const unsigned char*>(&y),
+      *reinterpret_cast<const unsigned char*>(&z),
+      *reinterpret_cast<const unsigned char*>(&w),
+      *reinterpret_cast<const unsigned char*>(&v),
+      *reinterpret_cast<const unsigned char*>(&u),
+      *reinterpret_cast<const unsigned char*>(&t),
+      *reinterpret_cast<const unsigned char*>(&s),
+  };
+  fp8_e4_8_t res;
+  unsigned char* a = reinterpret_cast<unsigned char*>(&res.x.data);
+  unsigned char* b = reinterpret_cast<unsigned char*>(&res.y.data);
+  a[0]=xs[0]; a[1]=xs[1]; a[2]=xs[2]; a[3]=xs[3];
+  b[0]=xs[4]; b[1]=xs[5]; b[2]=xs[6]; b[3]=xs[7];
+  return res;
```

src/op/gemm.cc (1)
590-599: CDNA now forbids B.local.fragment; emit a clearer error. Replace ICHECK(0) with a message that instructs users to place B in shared for CDNA.

```diff
-  } else {
-    ICHECK(0);
-  }
+  } else {
+    ICHECK(false) << "CDNA GEMM requires B in shared/shared.dyn; got scope: "
+                  << B.scope();
+  }
```

src/layout/gemm_layouts.cc (1)
418-441: Guard against maxPhase=0; add the missing include. If continuous < vecSize, innerDimLength/vecSize == 0 ⇒ modulo by 0 in phase. Add checks and an include for std::min/std::max.

Apply:

```diff
 #include <cmath>
+#include <algorithm>
```

And within makeMatrixCoreSwizzleLayout:

```diff
   const int vecSize = (64 / element_size) * kPack;
   const int innerDimLength = continuous;
   const int typeWidthInBit = element_size;
+  ICHECK_GT(vecSize, 0);
+  ICHECK_GE(innerDimLength, vecSize)
+      << "continuous (" << innerDimLength << ") must be >= vecSize (" << vecSize << ")";
+  ICHECK_EQ(innerDimLength % vecSize, 0)
+      << "continuous must be a multiple of vecSize for swizzle";
```
🧹 Nitpick comments (15)
src/tl_templates/hip/hip_fp8.h (1)
3-3: Don't hard-define feature macros in public headers. Defining HIP_FP8_ENABLED in a header can cause ODR/config skew across TUs. Prefer guarding or moving it to build flags.

```diff
-#define HIP_FP8_ENABLED 1
+#ifndef HIP_FP8_ENABLED
+#define HIP_FP8_ENABLED 1
+#endif
```

examples/gemm_fp8/example_tilelang_gemm_amd.py (5)
10-12: Strengthen the manual check beyond a single element. Sampling more positions (e.g., a few random indices or a small tile) reduces false positives during autotune.

```diff
-def manual_check_prog(C, C_ref):
-    torch_assert_close(C[0], C_ref[0], rtol=0.01, atol=0.1)
+def manual_check_prog(C, C_ref):
+    # Check a few elements to guard against localized errors
+    torch_assert_close(C[:1, :4], C_ref[:1, :4], rtol=0.01, atol=0.1)
```
13-20: Bind the device from params to avoid device mismatches. Using the incoming param device makes the example portable across CUDA/ROCm installs and multi-GPU setups.

```diff
-    a = (torch.randn(M, K, dtype=torch.float16, device='cuda') * 0.01).to(dtype=torch.float8_e4m3fnuz)
-    b = (torch.randn(N, K, dtype=torch.float16, device='cuda') * 0.01).to(dtype=torch.float8_e4m3fnuz)
+    dev = a_param.device
+    a = (torch.randn(M, K, dtype=torch.float16, device=dev) * 0.01).to(dtype=torch.float8_e4m3fnuz)
+    b = (torch.randn(N, K, dtype=torch.float16, device=dev) * 0.01).to(dtype=torch.float8_e4m3fnuz)
```
21-43: Config space looks sane; consider pruning incompatible tuples if needed. If MFMA FP8 requires K divisible by 32 (or k_pack×16), you can filter here to cut invalid runs.

87-92: Minor: tighten the error. Consider listing the valid options in the message for quicker debugging.

```diff
-    raise ValueError(f"Invalid gemm_type: {gemm_type}")
+    raise ValueError(f"Invalid gemm_type: {gemm_type}. Expected one of: 'ss', 'rs'.")
```

95-103: Leverage param devices and match tolerances with the utils defaults. Mirror the supply_prog device binding; the utils default atol=1e-2, rtol=1e-2 is OK.

```diff
-    a = (torch.randn(M, K, dtype=torch.float16, device='cuda') * 0.01).to(dtype=torch.float8_e4m3fnuz)
-    b = (torch.randn(N, K, dtype=torch.float16, device='cuda') * 0.01).to(dtype=torch.float8_e4m3fnuz)
+    dev = torch.device('cuda')
+    a = (torch.randn(M, K, dtype=torch.float16, device=dev) * 0.01).to(dtype=torch.float8_e4m3fnuz)
+    b = (torch.randn(N, K, dtype=torch.float16, device=dev) * 0.01).to(dtype=torch.float8_e4m3fnuz)
```

src/layout/layout.h (1)
166-168: Parameter naming is inconsistent across files (kfactor vs kPack vs k_pack). In this header it is kfactor; in gemm_layouts.cc it is kPack; in other places it is k_pack. Pick one (suggest k_pack) for consistency.
src/tl_templates/hip/gemm.h (4)
77-79: New micro_size_k/vec_size are correct; add compile-time guards. Assert the integrality assumptions to prevent silent shape bugs.

Apply:

```diff
   static constexpr int micro_size_x = 16;
   static constexpr int micro_size_y = 16;
   static constexpr int micro_size_k = 32 / sizeof(A_type);
   static constexpr int vec_size = 8 / sizeof(A_type);
+  static_assert((micro_size_x * micro_size_k) % warp_size == 0,
+                "local_size_a must be integral");
+  static_assert((micro_size_y * micro_size_k) % warp_size == 0,
+                "local_size_b must be integral");
```

Also applies to: 88-89
175-177: Ensure 8/16B alignment for vectorized packing. Local arrays feed 64-bit (fp8/i8) or 8-byte (f16x4/bf16x4) loads. Add alignment to avoid misaligned vector loads.

Apply:

```diff
-  A_type A_local[warp_rows * kPack * local_size_a];
-  B_type B_local[warp_cols * kPack * local_size_b];
+  alignas(16) A_type A_local[warp_rows * kPack * local_size_a];
+  alignas(16) B_type B_local[warp_cols * kPack * local_size_b];
```
241-252: Align B_local in the rs path as well.

Apply:

```diff
-  B_type B_local[warp_cols * kPack * local_size_b];
+  alignas(16) B_type B_local[warp_cols * kPack * local_size_b];
```
166-167: Remove the unused variable tx. tx is set but never used in both bodies.

Apply:

```diff
-  auto tx = lane_id;
```
62-69: Validate k_pack at entry.Add a sanity check to prevent undefined layouts on k_pack <= 0.
Apply:
Fragment makeGemmFragmentAB16x16CDNA(const int k_pack) { + ICHECK(k_pack >= 1) << "k_pack must be >= 1";
71-78: Same k_pack validation for transposed variant.Apply:
Fragment makeGemmFragmentAB16x16CDNATransposed(const int k_pack) { + ICHECK(k_pack >= 1) << "k_pack must be >= 1";
80-87: Add k_pack check for 16x32 path.Apply:
Fragment makeGemmFragmentAB16x32CDNA(const int k_pack) { + ICHECK(k_pack >= 1) << "k_pack must be >= 1";
89-96: Add k_pack check for 16x32 transposed path.Apply:
Fragment makeGemmFragmentAB16x32CDNATransposed(const int k_pack) { + ICHECK(k_pack >= 1) << "k_pack must be >= 1";
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (6)
- examples/gemm_fp8/example_tilelang_gemm_amd.py (1 hunks)
- src/layout/gemm_layouts.cc (4 hunks)
- src/layout/layout.h (1 hunks)
- src/op/gemm.cc (1 hunks)
- src/tl_templates/hip/gemm.h (8 hunks)
- src/tl_templates/hip/hip_fp8.h (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (6)
src/tl_templates/hip/hip_fp8.h (1)
- src/tl_templates/cuda/cuda_fp8.h (1): fp8_e4_2_t (8-11)

examples/gemm_fp8/example_tilelang_gemm_amd.py (7)
- tilelang/utils/tensor.py (1): torch_assert_close (220-312)
- tilelang/autotuner/tuner.py (1): autotune (692-785)
- tilelang/jit/__init__.py (1): jit (237-310)
- tilelang/language/allocate.py (2): alloc_fragment (53-64), alloc_shared (21-36)
- tilelang/language/fill.py (1): clear (24-48)
- tilelang/language/pipeline.py (1): Pipelined (9-46)
- tilelang/language/copy.py (1): copy (84-152)

src/layout/layout.h (1)
- tilelang/tileop/gemm/gemm_base.py (1): k_pack (110-111)

src/tl_templates/hip/gemm.h (1)
- tilelang/intrinsics/mfma_layout.py (1): make_mfma_swizzle_layout (130-152)

src/layout/gemm_layouts.cc (2)
- tilelang/tileop/gemm/gemm_base.py (1): k_pack (110-111)
- src/layout/layout.cc (2): Fragment (274-296), Fragment (298-308)

src/op/gemm.cc (1)
- tilelang/tileop/gemm/gemm_base.py (2): A (66-67), trans_A (45-46)
🪛 Ruff (0.12.2)
examples/gemm_fp8/example_tilelang_gemm_amd.py
92-92: Avoid specifying long messages outside the exception class
(TRY003)
🪛 GitHub Actions: CI
examples/gemm_fp8/example_tilelang_gemm_amd.py
[error] 1-1: format.sh reformatted files and exited with code 1. Please review and stage the changes in 'examples/gemm_fp8/example_tilelang_gemm_amd.py'.
🔇 Additional comments (15)
src/tl_templates/hip/hip_fp8.h (1)
23-25: Verify float4 → __hip_fp8x4_e4m3_fnuz conversion availability. The ctor relies on an implicit conversion from float4. Confirm the toolchain provides this; otherwise, gate it or implement an explicit per-lane conversion.
Would you like me to add a guarded implementation that converts per-lane with the available HIP FP8 intrinsics?
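For reference, a guarded per-lane fallback could look roughly like the sketch below. It assumes that fp8_e4_t exposes a constructor from float (as recent ROCm releases provide for the underlying HIP FP8 scalar type); treat that constructor's availability as an assumption to verify, not a confirmed API.

```cpp
// Hypothetical per-lane fallback for the float4 constructor, to be used only
// if the toolchain lacks a direct float4 -> __hip_fp8x4_e4m3_fnuz conversion.
// Assumes fp8_e4_t can be constructed from float; verify on the target ROCm.
__device__ fp8_e4_4_t(const float4 &val) {
  const fp8_e4_t lanes[4] = {fp8_e4_t(val.x), fp8_e4_t(val.y),
                             fp8_e4_t(val.z), fp8_e4_t(val.w)};
  __builtin_memcpy(&data, lanes, sizeof(data));  // byte-wise copy into storage
}
```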
examples/gemm_fp8/example_tilelang_gemm_amd.py (5)
7-9: Reference impl LGTM. FP16 matmul with FP32 accumulate is a reasonable baseline for FP8.

44-46: Decorator order check. autotune outside jit is fine (autotune(jit(func))). Just confirm this matches your expected compilation/execution path.

55-67: CDNA: rs path matches InferLayout constraints. A in local.fragment + B in shared aligns with the CDNA path that disallows B.local.fragment.

69-86: CDNA: ss path OK. Both A and B in shared are supported; transpose_B=True + k_pack threading aligns with the layout updates.

1-106: CI formatting failed; run format.sh locally and commit the changes. Verification in the sandbox failed: ./format.sh exited with "pip: command not found". Run ./format.sh locally and commit the formatting changes to unblock CI; if it still errors, paste the stdout/stderr from ./format.sh and the output of:
```
git diff -- examples/gemm_fp8/example_tilelang_gemm_amd.py
```

src/op/gemm.cc (3)
585-586: kPack threading into the A fragment (CDNA) looks correct. The signature change is respected: makeGemmFragmentACDNA(..., dtype_bits, kPack, trans_A).

426-434: Lowering: kPack emission for the CDNA path LGTM; keep wg_wait optional. No action needed.

461-605: Sanity scan for B local.fragment with CDNA: inconclusive; verify tests/examples. Found B allocated as local.fragment in tests/examples (e.g. testing/python/tilelibrary/test_tilelang_tilelibrary_gemm.py:301, testing/python/kernel/test_tilelang_kernel_gemm.py:330, examples/dequantize_gemm/*). A quick pass did not reveal an obvious src/ kernel that combines TargetIsCDNA with B being local.fragment. Run a targeted check (e.g. grep for "TargetIsCDNA.*alloc_fragment.*B" or review the CI/target matrix) to confirm none of these kernels are built/run for CDNA.
src/layout/layout.h (1)
154-158: Confirmed: all call sites updated to pass k_pack.
Only usage found: src/op/gemm.cc — calls makeGemmFragmentACDNA(..., kPack, ...); declaration (src/layout/layout.h) and definition (src/layout/gemm_layouts.cc) include the k_pack parameter.src/tl_templates/hip/gemm.h (3)
105-112: Index maps now scale with vec_size·kPack — LGTM.
125-126: Swizzle layout generalized to vec_size; matches the Python impl. The C++ mirrors tilelang/intrinsics/mfma_layout.py. Looks good.
Also applies to: 151-154
54-66: Gate the FP8 path by arch if needed. Confirm __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8 is available on the targeted gfx IPs; otherwise guard by __HIP_ARCH_HAS_FP8 or similar.
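One possible shape for such a guard is sketched below; the gfx macro list and the TL_HAS_FP8_MFMA name are assumptions for illustration and should be checked against the ROCm targets actually built (the __HIP_ARCH_HAS_FP8 macro mentioned above may not exist on all toolchains).

```cpp
// Hedged sketch of an arch guard around the FP8 MFMA specialization.
// The gfx94x macro list is an assumption; adjust it to the targets that
// actually provide the fp8 MFMA instructions.
#if defined(HIP_FP8_ENABLED) && \
    (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#define TL_HAS_FP8_MFMA 1
#else
#define TL_HAS_FP8_MFMA 0
#endif
// The MfmaTraits<fp8_e4_t> specialization would then be wrapped in
// #if TL_HAS_FP8_MFMA ... #endif instead of HIP_FP8_ENABLED alone.
```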
src/layout/gemm_layouts.cc (2)
245-276: ACDNA fragment: the mfma_k logic and tiling look correct. Using mfma_k = k_pack*(16|32) and switching AB16x16 vs AB16x32 by element_size matches the HIP MFMA shapes. The repeat/divisor math is consistent.

642-645: CDNA AB layout now always uses the matrix-core swizzle; LGTM. This unifies behavior and leverages kPack; good.
Code under review (src/layout/layout.h):

```diff
 Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
                                const int block_k, const int warp_m,
                                const int warp_n, const int element_size,
-                               bool transposed = false);
+                               const int k_pack, bool transposed = false);
```
🛠️ Refactor suggestion
Public API now requires k_pack; consider defaulting to 1 to minimize breakage.
Adding a mandatory k_pack breaks all external callers. If backward-compat isn’t a concern, fine; otherwise, default it to 1 so existing FP16/BF16 paths keep compiling.
Apply:
```diff
-Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
-                               const int block_k, const int warp_m,
-                               const int warp_n, const int element_size,
-                               const int k_pack, bool transposed = false);
+Fragment makeGemmFragmentACDNA(const int block_m, const int block_n,
+                               const int block_k, const int warp_m,
+                               const int warp_n, const int element_size,
+                               const int k_pack = 1, bool transposed = false);
```
Code under review:

```cpp
#if defined(HIP_FP8_ENABLED)
// Specialization for fp8_e4_t
template <> struct MfmaTraits<fp8_e4_t> {
  template <typename AccType>
  static TL_DEVICE void mfma_op(const fp8_e4_t *b, const fp8_e4_t *a,
                                AccType *c) {
    int64_t a_val = *reinterpret_cast<const int64_t *>(a);
    int64_t b_val = *reinterpret_cast<const int64_t *>(b);
    *c = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(b_val, a_val, *c, 0, 0, 0);
  }
};
#endif
```
🛠️ Refactor suggestion
FP8 MFMA: avoid potential misaligned 64-bit loads and aliasing UB.
Reinterpreting fp8_e4_t* to int64_t* may misalign and violates strict aliasing. Use memcpy to scalars or ensure 8B alignment on sources.
Apply:
```diff
-  static TL_DEVICE void mfma_op(const fp8_e4_t *b, const fp8_e4_t *a,
-                                AccType *c) {
-    int64_t a_val = *reinterpret_cast<const int64_t *>(a);
-    int64_t b_val = *reinterpret_cast<const int64_t *>(b);
-    *c = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(b_val, a_val, *c, 0, 0, 0);
-  }
+  static TL_DEVICE void mfma_op(const fp8_e4_t *b, const fp8_e4_t *a,
+                                AccType *c) {
+    int64_t a_val, b_val;
+    __builtin_memcpy(&a_val, a, sizeof(a_val));
+    __builtin_memcpy(&b_val, b, sizeof(b_val));
+    *c = __builtin_amdgcn_mfma_f32_16x16x32_fp8_fp8(b_val, a_val, *c, 0, 0, 0);
+  }
```
And include <cstring> once per TU:

```diff
 #include "common.h"
+#include <cstring>
```
Code under review (src/tl_templates/hip/hip_fp8.h):

```cpp
struct fp8_e4_4_t {
  union {
    __hip_fp8x4_e4m3_fnuz data;
    struct {
      fp8_e4_t x, y, z, w;
    };
  };
```
🛠️ Refactor suggestion
Union-based type punning risks UB; add layout guards or switch to a safer representation.
Reading a different union member than the last-written one is UB in C++. Since generated code will access x/y/z/w and conversions read/write data, add size/alignment guards and use byte-wise copies to/from the underlying type to avoid aliasing issues.
-struct fp8_e4_4_t {
- union {
- __hip_fp8x4_e4m3_fnuz data;
- struct {
- fp8_e4_t x, y, z, w;
- };
- };
+struct fp8_e4_4_t {
+ // Storage
+ __hip_fp8x4_e4m3_fnuz data;
+ // Named accessors
+ __device__ fp8_e4_t& x() { return *reinterpret_cast<fp8_e4_t*>(reinterpret_cast<unsigned char*>(&data) + 0); }
+ __device__ fp8_e4_t& y() { return *reinterpret_cast<fp8_e4_t*>(reinterpret_cast<unsigned char*>(&data) + 1); }
+ __device__ fp8_e4_t& z() { return *reinterpret_cast<fp8_e4_t*>(reinterpret_cast<unsigned char*>(&data) + 2); }
+ __device__ fp8_e4_t& w() { return *reinterpret_cast<fp8_e4_t*>(reinterpret_cast<unsigned char*>(&data) + 3); }
+ __device__ const fp8_e4_t& x() const { return *reinterpret_cast<const fp8_e4_t*>(reinterpret_cast<const unsigned char*>(&data) + 0); }
+ __device__ const fp8_e4_t& y() const { return *reinterpret_cast<const fp8_e4_t*>(reinterpret_cast<const unsigned char*>(&data) + 1); }
+ __device__ const fp8_e4_t& z() const { return *reinterpret_cast<const fp8_e4_t*>(reinterpret_cast<const unsigned char*>(&data) + 2); }
+ __device__ const fp8_e4_t& w() const { return *reinterpret_cast<const fp8_e4_t*>(reinterpret_cast<const unsigned char*>(&data) + 3); }
// Default constructor
__device__ fp8_e4_4_t() = default;
// Constructor from __hip_fp8x4_e4m3_fnuz
__device__ fp8_e4_4_t(const __hip_fp8x4_e4m3_fnuz &val) : data(val) {}
- // Constructor from float4
- __device__ fp8_e4_4_t(const float4 &val) : data(val) {}
+ // Note: float4 ctor requires a well-defined conversion; see comment below.
// Conversion operator to __hip_fp8x4_e4m3_fnuz
__device__ operator __hip_fp8x4_e4m3_fnuz() const { return data; }
// Assignment operator
__device__ fp8_e4_4_t &operator=(const __hip_fp8x4_e4m3_fnuz &val) {
data = val;
return *this;
}
};

Follow-up:
- If you must keep direct field syntax (obj.x), we can wrap with proxy references but that’s more invasive. The accessor approach above keeps aliasing-safe semantics.
- Add guards near the type definitions:
- static_assert(sizeof(fp8_e4_t) == 1, "fp8_e4_t must be 1 byte");
- static_assert(sizeof(__hip_fp8x4_e4m3_fnuz) == 4, "__hip_fp8x4_e4m3_fnuz must be 4 bytes");
Also applies to: 18-34
Actionable comments posted: 0
♻️ Duplicate comments (2)
examples/gemm_fp8/example_tilelang_gemm_amd.py (2)
27-50: Use a filtered comprehension and prune invalid k_pack combos. Also include a pipelined option in num_stages and skip configs where block_K % k_pack != 0.

```diff
 def get_configs():
     block_Ms = [32, 64, 128]
     block_Ns = [32, 64, 128]
     block_Ks = [64, 128]
-    num_stages = [0]
+    num_stages = [0, 2]
     num_threads = [256]
     k_packs = [1, 2]
     gemm_types = ["ss", "rs"]
-
-    valid_configs = []
-
-    for m, n, k, stages, t, kp, gemm_type in itertools.product(block_Ms, block_Ns, block_Ks,
-                                                               num_stages, num_threads, k_packs,
-                                                               gemm_types):
-        valid_configs.append({
-            "block_M": m,
-            "block_N": n,
-            "block_K": k,
-            "num_stages": stages,
-            "num_threads": t,
-            "k_pack": kp,
-            "gemm_type": gemm_type,
-        })
-    return valid_configs
+    return [
+        {
+            "block_M": m,
+            "block_N": n,
+            "block_K": k,
+            "num_stages": stages,
+            "num_threads": t,
+            "k_pack": kp,
+            "gemm_type": gemm_type,
+        }
+        for m, n, k, stages, t, kp, gemm_type in itertools.product(
+            block_Ms, block_Ns, block_Ks, num_stages, num_threads, k_packs, gemm_types
+        )
+        if k % kp == 0
+    ]
```
64-115: Deduplicate the rs/ss kernels by selecting A's storage based on gemm_type. Both bodies are identical aside from A_local vs A_shared. Consolidate to one prim_func for maintainability and to avoid drift.

```diff
-    @T.prim_func
-    def gemm_fp8_rs(
-            A: T.Tensor((M, K), dtype),
-            B: T.Tensor((N, K), dtype),
-            C: T.Tensor((M, N), accum_dtype),
-    ):
-        with T.Kernel(
-                T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
-            A_local = T.alloc_fragment((block_M, block_K), dtype)
-            B_shared = T.alloc_shared((block_N, block_K), dtype)
-            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
-
-            T.clear(C_local)
-            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
-                T.copy(A[by * block_M, k * block_K], A_local)
-                T.copy(B[bx * block_N, k * block_K], B_shared)
-                T.gemm(
-                    A_local,
-                    B_shared,
-                    C_local,
-                    transpose_B=True,
-                    k_pack=k_pack,
-                    policy=T.GemmWarpPolicy.FullRow)
-
-            T.copy(C_local, C[by * block_M, bx * block_N])
-
-    @T.prim_func
-    def gemm_fp8_ss(
-            A: T.Tensor((M, K), dtype),
-            B: T.Tensor((N, K), dtype),
-            C: T.Tensor((M, N), accum_dtype),
-    ):
-        with T.Kernel(
-                T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
-            A_shared = T.alloc_shared((block_M, block_K), dtype)
-            B_shared = T.alloc_shared((block_N, block_K), dtype)
-            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
-
-            T.clear(C_local)
-            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
-                T.copy(A[by * block_M, k * block_K], A_shared)
-                T.copy(B[bx * block_N, k * block_K], B_shared)
-                T.gemm(
-                    A_shared,
-                    B_shared,
-                    C_local,
-                    transpose_B=True,
-                    k_pack=k_pack,
-                    policy=T.GemmWarpPolicy.FullRow)
-
-            T.copy(C_local, C[by * block_M, bx * block_N])
+    @T.prim_func
+    def gemm_fp8(
+            A: T.Tensor((M, K), dtype),
+            B: T.Tensor((N, K), dtype),
+            C: T.Tensor((M, N), accum_dtype),
+    ):
+        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=num_threads) as (bx, by):
+            if gemm_type == "rs":
+                A_mem = T.alloc_fragment((block_M, block_K), dtype)
+            elif gemm_type == "ss":
+                A_mem = T.alloc_shared((block_M, block_K), dtype)
+            else:
+                raise ValueError("Invalid gemm_type")
+            B_shared = T.alloc_shared((block_N, block_K), dtype)
+            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
+
+            T.clear(C_local)
+            for k in T.Pipelined(T.ceildiv(K, block_K), num_stages=num_stages):
+                T.copy(A[by * block_M, k * block_K], A_mem, coalesced_width=16)
+                T.copy(B[bx * block_N, k * block_K], B_shared, coalesced_width=16)
+                T.gemm(A_mem, B_shared, C_local, transpose_B=True, k_pack=k_pack, policy=T.GemmWarpPolicy.FullRow)
+            T.copy(C_local, C[by * block_M, bx * block_N])
```

And below:

```diff
-    if gemm_type == "ss":
-        return gemm_fp8_ss
-    elif gemm_type == "rs":
-        return gemm_fp8_rs
-    else:
-        raise ValueError(f"Invalid gemm_type: {gemm_type}")
+    return gemm_fp8
```
🧹 Nitpick comments (5)
examples/gemm_fp8/example_tilelang_gemm_amd.py (5)
12-14: Check the whole tensor during tuning, not just the first row. Limiting validation to C[0] risks missing layout-wide bugs. Consider checking the full tensor (or a stratified sample) to keep autotuning trustworthy.

```diff
-def manual_check_prog(C, C_ref):
-    torch_assert_close(C[0], C_ref[0], rtol=0.01, atol=0.1)
+def manual_check_prog(C, C_ref):
+    # Full-tensor check; relax mismatch ratio to keep tuning fast.
+    torch_assert_close(C, C_ref, rtol=0.01, atol=0.1, max_mismatched_ratio=0.01)
```
60-63: Be explicit: use float32 for accumulators. "float" may alias to float32, but "float32" removes ambiguity and matches other APIs.

```diff
-    accum_dtype = "float"
+    accum_dtype = "float32"
```
78-80: Hint the vector width for better global-memory coalescing. An explicit coalesced_width (e.g., 16 or 32) can help the HIP path choose better vectorized LD/ST on CDNA.

If you keep separate kernels, minimally:

```diff
-                T.copy(A[by * block_M, k * block_K], A_local)
-                T.copy(B[bx * block_N, k * block_K], B_shared)
+                T.copy(A[by * block_M, k * block_K], A_local, coalesced_width=16)
+                T.copy(B[bx * block_N, k * block_K], B_shared, coalesced_width=16)
```

Also applies to: 104-106
116-121: Minor: TRY003 and a clearer error path. Ruff flagged TRY003 here. If you adopt the single-prim_func pattern above, this line disappears; otherwise, consider a constant message or pre-validating gemm_type before decoration.
124-134: Stabilize the test and broaden coverage slightly. Seed the RNG and drop the trailing tilde in the message.

```diff
 def test_gemm_fp8(M, N, K):
-    kernel = fp8_matmul(M, N, K)
+    torch.manual_seed(0)
+    kernel = fp8_matmul(M, N, K)
@@
-    print("passed~")
+    print("passed")
```

Optionally add an odd-dimension smoke test under __main__:

```diff
 if __name__ == "__main__":
     test_gemm_fp8(512, 512, 512)
+    test_gemm_fp8(513, 517, 509)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- examples/gemm_fp8/example_tilelang_gemm_amd.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
examples/gemm_fp8/example_tilelang_gemm_amd.py (7)
- tilelang/utils/tensor.py (1): torch_assert_close (220-312)
- tilelang/autotuner/tuner.py (1): autotune (692-785)
- tilelang/jit/__init__.py (1): jit (237-310)
- tilelang/language/allocate.py (2): alloc_fragment (53-64), alloc_shared (21-36)
- tilelang/language/fill.py (1): clear (24-48)
- tilelang/language/pipeline.py (1): Pipelined (9-46)
- tilelang/language/copy.py (1): copy (84-152)
🪛 Ruff (0.12.2)
examples/gemm_fp8/example_tilelang_gemm_amd.py
121-121: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: build-test-nvidia
- GitHub Check: bot-task
🔇 Additional comments (2)
examples/gemm_fp8/example_tilelang_gemm_amd.py (2)
16-24: Verify device selection on ROCm runners. This example targets AMD; PyTorch on ROCm typically still uses the cuda device type, but environments differ. Please confirm this runs on your ROCm CI/hosts; otherwise consider parameterizing the device.
70-89: Edge tiles: please confirm OOB safety on non-multiple shapes. The unguarded T.copy of full tiles relies on TileLang to handle partial tiles safely. Please run a smoke test with non-multiples (e.g., M=513, N=517, K=509) to confirm no OOB or correctness regressions.

Also applies to: 96-115
Awesome Contribution! Merged :)
Summary by CodeRabbit
New Features
Performance
Refactor
Documentation