Revert "[WIP] support more dtypes for tcgen05 (#1229)" #1323
Conversation
This reverts commit 0d101c1.
Walkthrough

This PR removes unused metadata flags (enable_ws, enable_2cta) from the TCGEN5 MMA infrastructure, simplifies template-based packing logic (pack16), and replaces CuTe floating-point types with native FP8 types in SM100 GEMM dispatch. Additionally, it removes an outdated FP8 example and consolidates related type mappings.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
Actionable comments posted: 0
🧹 Nitpick comments (2)
tilelang/intrinsics/tcgen05_macro_generator.py (2)
171-178: Meta tuple now 3-element; enable_ws derived from atom_m

Updating the meta length check to 3 and unpacking `(atom_m, atom_n, atom_k)` aligns with the new `TCGEN5MMAMeta` definition. Deriving `enable_ws` as `atom_m != 128` matches the D/E/G layout convention (only the 128-row atom is non-WS), so this looks consistent with the C++ meta logic.

If you want to quiet TRY003 from Ruff, you could factor the long error message into a helper or shorten it, but that's purely stylistic.
384-389: Store-layout meta handling consistent with 3-field TCGEN5MMAMeta

`make_mma_store_layout` now expects a 3-element meta and uses only `atom_m` and `atom_n` to validate tile divisibility, ignoring `atom_k`. That matches how the store layout is computed (only M×N tiling matters here), and the error message accurately describes the unsupported configuration.

Same as above, the long f-string in the `ValueError` is fine functionally; consider shortening or moving it if you want to satisfy TRY003.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (11)
- examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py (0 hunks)
- src/op/copy.cc (2 hunks)
- src/op/gemm_py.cc (0 hunks)
- src/op/tcgen5_meta.h (2 hunks)
- src/tl_templates/cuda/copy_sm100.h (1 hunks)
- src/tl_templates/cuda/gemm_sm100.h (1 hunks)
- src/tl_templates/cuda/tcgen_05_ld.h (11 hunks)
- tilelang/intrinsics/mma_macro_generator.py (0 hunks)
- tilelang/intrinsics/tcgen05_macro_generator.py (2 hunks)
- tilelang/jit/adapter/wrapper.py (0 hunks)
- tilelang/tileop/gemm/gemm_tcgen05.py (1 hunks)
💤 Files with no reviewable changes (4)
- tilelang/jit/adapter/wrapper.py
- tilelang/intrinsics/mma_macro_generator.py
- src/op/gemm_py.cc
- examples/gemm_fp8/example_tilelang_gemm_fp8_sm100.py
🧰 Additional context used
🧬 Code graph analysis (2)
tilelang/tileop/gemm/gemm_tcgen05.py (1)
tilelang/tileop/gemm/gemm_base.py (1)
accum_dtype (59-60)
src/tl_templates/cuda/copy_sm100.h (1)
src/tl_templates/cuda/tcgen_05_ld.h (5)
tl (10-713), tmem_ld_32dp32bNx (13-182), tmem_ld_32dp64bNx (684-691), tmem_ld_32dp128bNx (694-701), tmem_ld_32dp256bNx (704-711)
🪛 Ruff (0.14.5)
tilelang/intrinsics/tcgen05_macro_generator.py
173-175: Avoid specifying long messages outside the exception class
(TRY003)
386-387: Avoid specifying long messages outside the exception class
(TRY003)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (2)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (18)
src/op/copy.cc (2)
1127-1129: ICHECK message reflow is fine

This just improves readability of the unsupported tmem copy diagnostic; no behavioral change.
1249-1250: Confirm tcgen05 intrinsic name matches new template signature

Appending only `<num_chunks_each_wg>` to `meta.intrinsics_name` assumes that name now refers to the unspecialized function (e.g. `tmem_ld_*Nx::copy`). With the removal of the `Pack16` template parameter on the loader classes, that looks correct, but it's worth double-checking that `tcgen05_layout.h` sets `intrinsics_name` accordingly so codegen still emits valid calls.
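For illustration, a hedged sketch of what that codegen step might look like; the helper name and surrounding code are assumptions, only `meta.intrinsics_name` and `num_chunks_each_wg` come from the diff under review:

```cpp
#include <string>

// Hypothetical helper, not the actual copy.cc code: appends the single
// remaining template argument to the unspecialized loader name.
std::string MakeTmemLoadCallee(const std::string &intrinsics_name,
                               int num_chunks_each_wg) {
  // e.g. "tl::tcgen05_ld_32dp64bNx" -> "tl::tcgen05_ld_32dp64bNx<4>"
  return intrinsics_name + "<" + std::to_string(num_chunks_each_wg) + ">";
}
```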
src/op/tcgen5_meta.h (2)

24-28: FAIL/SUCCESS macros correctly updated for 3-field TCGEN5MMAMeta

Returning `{false, TCGEN5MMAMeta{0, 0, 0}}` on FAIL and `{true, TCGEN5MMAMeta{atom_m, atom_n, atom_k}}` on SUCCESS matches the new struct definition and keeps the API straightforward for FFI callers.
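As a minimal sketch, assuming `TCGEN5MMAMeta` is a plain three-field aggregate (the real header may differ in details, e.g. whether SUCCESS takes arguments), the macro pattern being reviewed looks roughly like:

```cpp
#include <utility>

// Sketch of the 3-field meta struct described in the review.
struct TCGEN5MMAMeta {
  int atom_m, atom_n, atom_k;
};

// FAIL yields a (false, zeroed meta) pair; SUCCESS yields (true, meta).
#define FAIL return std::make_pair(false, TCGEN5MMAMeta{0, 0, 0})
#define SUCCESS(atom_m, atom_n, atom_k)                                        \
  return std::make_pair(true, TCGEN5MMAMeta{(atom_m), (atom_n), (atom_k)})
```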
30-75: Verification confirms narrow TCGEN5MMA dtype space is intentional; minor mismatch found in _FLOAT8_DTYPES set

The review concerns are validated:

c_dtype=float32 enforcement: Confirmed at line 106 in `gemm_tcgen05.py`, which raises an error whenever `accum_dtype != "float32"`. All TCGEN5 operations will enforce this constraint before invoking `get_tcgen5_mma_meta`.

Supported FP8 encodings: No TCGEN5 tests or examples use fnuz variants. All fnuz usage (e.g., `float8_e4m3fnuz`) is confined to AMD MFMA code paths, not TCGEN5. The narrowing to {`float8_e4m3fn`, `float8_e5m2`} is consistent with codebase usage.

Unsupported dtype handling: Confirmed that unsupported dtypes will trigger `FAIL` in `GetTCGEN5MMAMeta`, which returns an empty result, and the caller in `tcgen05_macro_generator.py` (lines 172-175, 386-387) raises `ValueError` with the configuration details.

Minor issue: `_FLOAT8_DTYPES` in `gemm_tcgen05.py` (lines 13-16) still lists fnuz variants (`float8_e4m3fnuz`, `float8_e5m2fnuz`), but these are not validated or enforced. The set appears stale and should be updated to match the actual supported dtypes for clarity and to prevent future confusion.

src/tl_templates/cuda/tcgen_05_ld.h (7)
13-182: Non-templated tmem_ld_32dp32bNx with copy<N> looks correct

The new class with `template<int N> static copy(...)` and the N power-of-two static_assert matches the PTX mnemonics from x1 through x128, and the dst_ptr indexing aligns with the number of output registers per variant. The `trap` fallback for invalid N is also appropriate for debug builds.
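For reference, a hedged sketch of that class shape; only the x1 and x2 variants are spelled out here, while the real header continues the same pattern through x128 and should be consulted for the exact operand lists:

```cpp
#include <cstdint>

// Illustrative stand-in, not the real class from tcgen_05_ld.h.
struct tmem_ld_32dp32bNx_sketch {
  template <int N>
  static __device__ void copy(uint32_t src_addr, uint32_t *dst_ptr) {
    static_assert(N > 0 && (N & (N - 1)) == 0, "N must be a power of two");
    if constexpr (N == 1) {
      asm volatile("tcgen05.ld.sync.aligned.32x32b.x1.b32 {%0}, [%1];"
                   : "=r"(dst_ptr[0])
                   : "r"(src_addr));
    } else if constexpr (N == 2) {
      asm volatile("tcgen05.ld.sync.aligned.32x32b.x2.b32 {%0, %1}, [%2];"
                   : "=r"(dst_ptr[0]), "=r"(dst_ptr[1])
                   : "r"(src_addr));
    } else {
      __trap(); // debug-friendly fallback for N values with no asm variant
    }
  }
};
```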
185-354: tmem_ld_16dp64bNx concrete class is consistent with the 16×64b PTX forms

This mirrors the 32dp32b pattern: a single class with `copy<N>` handling x1..x128 via inline asm, guarded by the same power-of-two static_assert. The register lists and dst_ptr indices line up with the PTX signatures.
356-520: tmem_ld_16dp128bNx: tighter N bound and expanded asm variants

Limiting N to ≤64 for 16×128b loads and explicitly spelling out x1..x64 variants via inline asm is reasonable and keeps register pressure manageable. The static_assert correctly enforces the allowed N values, and the dst_ptr indexing matches the number of 32-bit outputs each variant emits.

522-680: tmem_ld_16dp256bNx: N≤32 and multi-register outputs per pattern

Here N is capped at 32, with each pattern emitting 4·N 32-bit words. The x1..x32 asm blocks have consistent operand lists and write into the expected dst_ptr ranges; invalid N again traps. This is a clean consolidation compared to the old Pack16-templated specializations.
684-691: tmem_ld_32dp64bNx wrapper correctly composes two 16-lane loads

Calling `tmem_ld_16dp64bNx::copy<N>` twice, offsetting src_addr by `(16 << 16)` and dst_ptr by `+N`, matches the idea of stitching two 16-lane 64b patterns into a 32-lane 64b pattern with 2·N outputs.
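The composition reads roughly as follows; this is a sketch under the assumption that tmem addresses carry the lane index in the upper 16 bits of the address word, which is what the `(16 << 16)` offset implies:

```cpp
#include <cstdint>

// Illustrative shape of the wrapper; assumes the real tl::tmem_ld_16dp64bNx
// from tcgen_05_ld.h is in scope.
struct tmem_ld_32dp64bNx_sketch {
  template <int N>
  static __device__ void copy(uint32_t src_addr, uint32_t *dst_ptr) {
    // Lower 16 lanes: N 32-bit outputs starting at dst_ptr.
    tl::tmem_ld_16dp64bNx::copy<N>(src_addr, dst_ptr);
    // Upper 16 lanes: bump the lane field in the address's upper half-word
    // and append the second half's N outputs right after the first.
    tl::tmem_ld_16dp64bNx::copy<N>(src_addr + (16 << 16), dst_ptr + N);
  }
};
```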
694-701: tmem_ld_32dp128bNx dst_ptr stride N * 2 matches 2·N outputs per half

Each 16×128b load produces 2·N 32-bit outputs, so offsetting the second call's dst_ptr by `N * 2` ensures the 32-lane variant writes a contiguous 4·N-word block. This is consistent with the lower-level implementation.
704-711: tmem_ld_32dp256bNx correctly offsets by N * 4 for 4·N outputs per half

For 16×256b, each load yields 4·N 32-bit words, so using `dst_ptr + N * 4` on the second call gives the expected 8·N total outputs for the 32-lane pattern. The composition looks sound.

tilelang/tileop/gemm/gemm_tcgen05.py (1)
105-107: Verify float32 accumulator restriction and add documentation

The restriction `if accum_dtype != "float32"` is unique to GemmTCGEN5; other GEMM variants (MMA, WGMMA, MFMA) accept wider accumulator dtypes without restriction. While the SM100 preview examples show only float32 accumulators, the code lacks an explanation of whether this is a hardware limitation or a conservative preview-stage choice.

No TCGEN5-specific tests or active call sites using non-float32 accumulators were found, but this absence doesn't confirm the restriction is necessary. Consider:
- Adding an inline comment explaining why float32-only (hardware constraint vs. implementation limitation)
- Verifying with hardware/intrinsic documentation whether other dtypes can eventually be supported
- Adding a test case documenting this constraint for future maintainers
src/tl_templates/cuda/copy_sm100.h (4)
98-104: tcgen05_ld_32dp32bNx wrapper matches simplified tmem loader interface

Dropping the `pack16` template parameter and instantiating `tcgen05_ld_core` with `tl::tmem_ld_32dp32bNx` directly is consistent with the `tmem_ld_32dp32bNx` interface (`copy` templated only on `N`); the recursion and fencing remain unchanged.
107-114: tcgen05_ld_32dp64bNx wrapper correctly tracks core loader signature

The 64b variant mirrors the 32b path: template parameters are simplified to `<int N, typename dst_t>`, and `tcgen05_ld_core` is instantiated with `tl::tmem_ld_32dp64bNx` and the same `MAX_LOGN` as before, so behavior is preserved.
116-123: tcgen05_ld_32dp128bNx wrapper remains structurally consistent

This wrapper follows the same pattern as the 32b/64b versions, with only the `target_call_cls` template argument simplified; `MAX_LOGN` is still 6 and the async fence is unchanged, so load tiling and synchronization semantics are intact.
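Putting these pieces together, the wrapper shape implied by the comments is roughly as below; this is a sketch in which `tcgen05_ld_core`'s exact signature is an assumption, while `MAX_LOGN = 6` and the trailing wait come from the review itself:

```cpp
#include <cstdint>

// Hypothetical signature; the real wrapper lives in copy_sm100.h and relies
// on the actual tl::tcgen05_ld_core and tl::tmem_ld_32dp128bNx definitions.
template <int N, typename dst_t>
__device__ void tcgen05_ld_32dp128bNx_sketch(uint32_t src_addr,
                                             dst_t *dst_ptr) {
  // Recursive core load, parameterized directly on the non-pack loader class.
  tl::tcgen05_ld_core<tl::tmem_ld_32dp128bNx, N, /*MAX_LOGN=*/6>(
      src_addr, reinterpret_cast<uint32_t *>(dst_ptr));
  // Wait until the tmem loads have landed in registers before they are read.
  asm volatile("tcgen05.wait::ld.sync.aligned;" ::: "memory");
}
```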
125-132: tcgen05_ld_32dp256bNx wrapper aligns with core recursion contract

The 256b path now also uses the non-pack `tl::tmem_ld_32dp256bNx` while preserving `MAX_LOGN == 5` and the recursive `tcgen05_ld_core` structure; this keeps the segmentation logic and fencing behavior the same as before the pack16 change.

src/tl_templates/cuda/gemm_sm100.h (2)
245-264: FP8 e4 dispatch specializations correctly gate shapes and map to F8F6F4 kernels

The two `DispatchInstruction<fp8_e4_t, fp8_e4_t, float, ...>` specializations cleanly separate the `M == 128 && K == 32` case (non-.ws `SM100_MMA_F8F6F4_SS`) from the `(M == 64 || M == 32) && K == 32` case (.ws `SM100_MMA_F8F6F4_WS_SS`), matching the constraints encoded in the `MMA_Traits<SM100_MMA_F8F6F4*_SS,...>` specialization and aligning with the existing F16/BF16 pattern. Using `MMA_Traits<...>` as the `MMA` alias is consistent with the F8F6F4 traits definition here and should integrate with `GemmTensorOp`'s `make_tiled_mma` usage.
266-285: FP8 e5 dispatch mirrors e4 path and reuses the same F8F6F4 infrastructure

The `DispatchInstruction<fp8_e5_t, fp8_e5_t, float, ...>` specializations mirror the e4 ones: they enforce `K == 32`, choose between `SM100_MMA_F8F6F4_SS` and `SM100_MMA_F8F6F4_WS_SS` based on `M`, and rely on `MMA_Traits` to wire up the K dimension (32 elements) and layouts. This keeps the FP8-e5 path consistent with the e4 path and the underlying F8F6F4 WS traits implementation.
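Schematically, the dispatch split described in these two comments looks like the sketch below; the `*_tag` structs and the `fp8_e4_t` alias are self-contained placeholders for the real CuTe atoms and TileLang typedefs:

```cpp
#include <cuda_fp8.h>
#include <type_traits>

using fp8_e4_t = __nv_fp8_e4m3; // assumed alias for the native FP8 type

struct SM100_MMA_F8F6F4_SS_tag {};    // stand-in for the non-.ws cute atom
struct SM100_MMA_F8F6F4_WS_SS_tag {}; // stand-in for the .ws cute atom

template <typename A, typename B, typename C, int M, int N, int K,
          typename Enable = void>
struct DispatchInstruction; // primary template left undefined

// M == 128, K == 32: plain SS instruction.
template <int M, int N, int K>
struct DispatchInstruction<fp8_e4_t, fp8_e4_t, float, M, N, K,
                           std::enable_if_t<M == 128 && K == 32>> {
  using MMA = SM100_MMA_F8F6F4_SS_tag;
};

// M == 64 or M == 32, K == 32: .ws instruction.
template <int M, int N, int K>
struct DispatchInstruction<fp8_e4_t, fp8_e4_t, float, M, N, K,
                           std::enable_if_t<(M == 64 || M == 32) && K == 32>> {
  using MMA = SM100_MMA_F8F6F4_WS_SS_tag;
};
```

In the actual header, `MMA` is a full `MMA_Traits<...>` instantiation over the real atoms, which is what `GemmTensorOp`'s `make_tiled_mma` consumes.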
Thanks @Hamerlate for helping find the issue; this reverts commit 0d101c1.
Summary by CodeRabbit
Refactor