
Conversation

Contributor

@Cunxiao2002 Cunxiao2002 commented Oct 21, 2025

Summary by CodeRabbit

  • Improvements

    • Added two FP8 reduced-precision float formats and bfloat16→FP8 conversion support to improve numeric compatibility on CUDA backends.
    • Introduced a consistent type-mapping pathway so runtime numeric types are resolved more reliably, reducing mismatch risk and improving precision handling.
  • Chores

    • Minor import and code hygiene cleanup.

Contributor

coderabbitai bot commented Oct 21, 2025

Walkthrough

Added tl:: FP8 wrapper types with explicit __nv_bfloat16 constructors, switched the fp8 aliases to those wrappers, added tl::to_cute_type mappings, updated the GEMM CUDA headers to resolve A/B types via cute-converted aliases, and added from __future__ import annotations (plus removed a duplicate import) in one Python module.
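The wrapper-plus-mapping pattern can be pictured with a minimal sketch. This is not the actual header text: the include choices, the cutlass:: base type, and the __bfloat162float call are assumptions; only the overall shape (derive from the existing FP8 type, add a default constructor and an explicit __nv_bfloat16 constructor that routes through float, and map the wrapper back to the cute-facing type with a trait) comes from the summary above.

```cpp
// Minimal sketch only -- not the actual common.h / cuda_fp8.h contents.
#include <cuda_bf16.h>       // __nv_bfloat16, __bfloat162float
#include <cutlass/float8.h>  // cutlass::float_e4m3_t

namespace tl {

// Wrap the FP8 type so a BF16 source goes through a real float conversion
// instead of a raw byte copy.
struct float_e4m3_t : public cutlass::float_e4m3_t {
  float_e4m3_t() = default;
  explicit float_e4m3_t(const __nv_bfloat16 &x)
      : cutlass::float_e4m3_t(__bfloat162float(x)) {}
};

// Map a tl:: wrapper back to the type CUTE/CUTLASS GEMM dispatch expects;
// identity for every other type.
template <typename T> struct to_cute_type {
  using type = T;
};
template <> struct to_cute_type<float_e4m3_t> {
  using type = cutlass::float_e4m3_t;
};

} // namespace tl

// The fp8 alias then points at the wrapper rather than at the cute type.
using fp8_e4_t = tl::float_e4m3_t;
```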

Changes

Cohort / File(s) — Summary

  • FP8 wrappers & alias switch — src/tl_templates/cuda/common.h, src/tl_templates/cuda/cuda_fp8.h
    Add tl::float_e4m3_t and tl::float_e5m2_t (derived from the cute:: types, with a default ctor and an explicit __nv_bfloat16→float ctor), add to_cute_type specializations, include common.h, and change fp8_e4_t/fp8_e5_t to alias the tl:: wrappers.
  • GEMM type mapping / selection (CUDA GEMM headers) — src/tl_templates/cuda/gemm_mma.h, src/tl_templates/cuda/gemm_sm90.h, src/tl_templates/cuda/gemm_sm100.h, src/tl_templates/cuda/gemm_sp_sm90.h
    Introduce A_type_cute = tl::to_cute_type<A_type_raw>::type and B_type_cute = tl::to_cute_type<B_type_raw>::type, and change the A_type/B_type conditionals to test A_type_cute/B_type_cute (selecting tfloat32_t vs the cute type). One file has an unexpected else-branch referencing A_type_cute where B_type_cute was likely intended.
  • Python annotations — tilelang/language/allocate.py
    Add from __future__ import annotations and remove a duplicate import; no behavioral change.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant NV as __nv_bfloat16
  participant TL as tl::float_e4m3_t / tl::float_e5m2_t
  participant CUT as cute::float_e4m3_t / cute::float_e5m2_t
  participant FP as fp8 alias (fp8_e4_t / fp8_e5_t)
  participant GEMM as GEMM headers

  NV->>TL: explicit ctor (convert to float, construct TL)
  TL->>CUT: inherits/forwards to cute base
  FP->>TL: fp8 aliases now reference TL wrappers
  TL->>GEMM: tl::to_cute_type maps TL -> CUT type
  GEMM->>GEMM: select A_type/B_type based on *_cute aliases
  note right of GEMM #D3F4FF: conditional: if *_cute is float → tfloat32_t else → cute type

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes

  • Review focus:
    • src/tl_templates/cuda/common.h — correctness of constructors and includes (cutlass/float8.h, cutlass/bfloat16.h).
    • GEMM headers (gemm_sm90.h, gemm_sp_sm90.h, gemm_mma.h, gemm_sm100.h) — ensure the B_type fallback logic is correct; one file appears to reference A_type_cute in B_type's else-branch (see the sketch after this list).
    • src/tl_templates/cuda/cuda_fp8.h — ensure include ordering and alias changes don't break existing uses.
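For orientation, the selection pattern under review looks roughly like the fragment below. It is a schematic rather than the real header text: the struct wrapper, the std::conditional form, and the cutlass/tfloat32.h include are assumptions, and tl::to_cute_type is the trait sketched earlier. The flagged bug would be a B_type else-branch that falls back to A_type_cute instead of B_type_cute.

```cpp
#include <type_traits>
#include <cutlass/tfloat32.h>  // cutlass::tfloat32_t (the type cute aliases)

template <typename A_type_raw, typename B_type_raw>
struct OperandTypes {
  using A_type_cute = typename tl::to_cute_type<A_type_raw>::type;
  using B_type_cute = typename tl::to_cute_type<B_type_raw>::type;

  // float operands are lowered to tfloat32_t for tensor-core MMA; everything
  // else keeps its cute type.
  using A_type =
      typename std::conditional<std::is_same<A_type_cute, float>::value,
                                cutlass::tfloat32_t, A_type_cute>::type;
  // Correct form shown here; the suspect file reportedly uses A_type_cute in
  // this else-branch.
  using B_type =
      typename std::conditional<std::is_same<B_type_cute, float>::value,
                                cutlass::tfloat32_t, B_type_cute>::type;
};
```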

Suggested reviewers

  • LeiWang1999
  • xwhzz
  • tzj-fxz

Poem

🐰

I hop through headers, quick and light,
E4M3 and E5M2 don tl:: tonight.
BFloat whispers float, we wrap and bind,
CUTLASS and GEMM now read the new kind. 🥕

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning — Docstring coverage is 10.00%, below the required threshold of 80.00%. Run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed — Check skipped: CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed — The title "[BugFix] Correct direct copy from bf16 to fp8" matches the main changes: the PR adds wrapper types float_e4m3_t and float_e5m2_t with explicit __nv_bfloat16→float conversion constructors and updates the GEMM implementations to use them, fixing the bf16→fp8 conversion by providing an explicit conversion path instead of a direct copy. The title is concise and specific enough for someone scanning the history.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6d885a4 and 7f1a507.

📒 Files selected for processing (1)
  • src/tl_templates/cuda/gemm_mma.h (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/tl_templates/cuda/gemm_mma.h


@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@Cunxiao2002 Cunxiao2002 marked this pull request as draft October 21, 2025 07:44
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
src/op/copy.cc (1)

302-306: Apply the proposed fix to the scalar-path cast handling in lines 302-306.

The review comment correctly identifies a critical bug: the scalar fast-path in MakeSIMTLoop bypasses the dtype-mismatch handling that the non-scalar path implements. Verification confirms the non-scalar path (lines 327-332) explicitly checks src->dtype != dst->dtype and applies the intermediate float32 cast for BF16→FP8 conversions, whereas the scalar path directly stores without casting.

The proposed fix is correct and aligns with the existing non-scalar implementation pattern. The fix mirrors the approach already used elsewhere: checking for the specific BF16→FP8 combinations and routing through float32 as an intermediate. Additionally, the fix appropriately covers both float8_e4m3 and float8_e5m2, which is more comprehensive than the current non-scalar handling (which only checks e4m3). This is consistent with codegen patterns observed in codegen_cuda.cc where BF16→float8 conversions go through float intermediates.

Add the minimal BF16→FP8 scalar-copy test case to ensure this path is exercised and prevent regression.

🧹 Nitpick comments (3)
src/op/copy.cc (1)

327-334: BF16→FP8: logic OK; fix comment and broaden to e5m2 as well.

  • The wording says “cast dst to fp32” but you actually cast the loaded value. Please fix the comment.
  • Bug likely applies to both FP8 formats; consider handling e5m2 too.

Apply:

-  if (src->dtype != dst->dtype) {
-    // If dst is fp8 and src is bf16, first cast dst to fp32.
-    if (src->dtype.is_bfloat16() && dst->dtype.is_float8_e4m3()) {
-      value = Cast(DataType::Float(32), value);
-    }
-    value = Cast(dst->dtype, value);
-  }
+  if (src->dtype != dst->dtype) {
+    // If dst is FP8 and src is BF16, first cast the loaded value to FP32 to avoid precision/rounding issues.
+    if (src->dtype.is_bfloat16() &&
+        (dst->dtype.is_float8_e4m3() || dst->dtype.is_float8_e5m2())) {
+      value = Cast(DataType::Float(32), value);
+    }
+    value = Cast(dst->dtype, value);
+  }

Add/extend a unit test to compile BF16→FP8 E5M2 as well to confirm identical fix applies.

testing/python/issue/test_tilelang_issue_1046.py (2)

10-24: Turn this into an assertion-based test and cover FP8 e5m2.

  • Avoid print-only behavior; assert that JIT compilation succeeds and that generated TIR contains the BF16→FP32→FP8 cast chain.
  • Parameterize to also exercise out_dtype="float8_e5m2".

Example:

-@tilelang.jit
-def test_kernel(N, in_dtype=BF16, out_dtype=FP8):
+@tilelang.jit
+def test_kernel(N, in_dtype=BF16, out_dtype=FP8):
@@
-    return test_kernel_
+    return test_kernel_
@@
-kernel = test_kernel(128)
-
-print(kernel.get_kernel_source())
+def _contains_cast_chain(f):
+    src = f.get_kernel_source()
+    # Heuristic: both cast to f32 and cast to fp8 should appear
+    return ("cast(float32" in src or "ConvertF32" in src) and ("e4m3" in src or "e5m2" in src)
+
+def test_issue_1046_bf16_to_fp8_e4m3_compiles():
+    k = test_kernel(128, in_dtype=BF16, out_dtype=FP8)
+    assert _contains_cast_chain(k)
+
+def test_issue_1046_bf16_to_fp8_e5m2_compiles():
+    k = test_kernel(128, in_dtype=BF16, out_dtype="float8_e5m2")
+    assert _contains_cast_chain(k)

If CI environments differ, relax the heuristic to simply ensure compilation returns a kernel object without exceptions.


4-4: Global cache toggle in tests can have side effects.

tilelang.disable_cache() affects the global state and may impact other tests. If not required, drop it or confine under a local setup/teardown.

Wrap in a fixture or context; or remove if unnecessary for reproducing the issue.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 42c267e and cb907dd.

📒 Files selected for processing (2)
  • src/op/copy.cc (1 hunks)
  • testing/python/issue/test_tilelang_issue_1046.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
testing/python/issue/test_tilelang_issue_1046.py (6)
tilelang/env.py (1)
  • disable_cache (271-272)
tilelang/jit/__init__.py (1)
  • jit (244-317)
tilelang/language/symbolics.py (1)
  • dynamic (11-22)
tilelang/language/kernel.py (1)
  • threads (215-219)
tilelang/language/allocate.py (1)
  • alloc_shared (21-36)
tilelang/language/copy.py (1)
  • copy (10-86)

@Cunxiao2002 Cunxiao2002 marked this pull request as ready for review October 22, 2025 03:20
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between cb907dd and 8281b05.

📒 Files selected for processing (2)
  • src/op/copy.cc (1 hunks)
  • src/target/codegen_cuda.cc (1 hunks)
✅ Files skipped from review due to trivial changes (1)
  • src/op/copy.cc
🧰 Additional context used
🧬 Code graph analysis (1)
src/target/codegen_cuda.cc (1)
src/target/codegen_hip.cc (2)
  • PrintType (186-429)
  • PrintType (186-186)

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 999e74e and 5c25147.

📒 Files selected for processing (3)
  • src/target/codegen_cuda.cc (1 hunks)
  • src/tl_templates/cuda/common.h (2 hunks)
  • src/tl_templates/cuda/cuda_fp8.h (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/target/codegen_cuda.cc
🧰 Additional context used
🧬 Code graph analysis (1)
src/tl_templates/cuda/cuda_fp8.h (1)
src/tl_templates/cuda/common.h (1)
  • tl (174-268)
🔇 Additional comments (3)
src/tl_templates/cuda/cuda_fp8.h (2)

5-5: LGTM!

The include of common.h is necessary for the new tl::float_e4m3_t and tl::float_e5m2_t types used in the updated aliases below.


7-8: Type alias changes are correct and safe.

Verification confirms:

  • Wrapper types tl::float_e4m3_t and tl::float_e5m2_t are properly defined in common.h with explicit __nv_bfloat16 conversion constructors
  • No direct references to old cute:: fp8 types exist outside the wrapper definitions
  • Type aliases are actively used throughout the codebase (gemm templates, debug utilities, etc.) and will work correctly with the new wrapper types due to their inheritance from cutlass types
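A hypothetical call-site snippet (not from the PR) showing what this means in practice, assuming the fp8_e4_t alias from cuda_fp8.h and the explicit __nv_bfloat16 constructor described above:

```cpp
#include <cuda_bf16.h>

void convert_example() {
  __nv_bfloat16 b = __float2bfloat16(1.5f);
  fp8_e4_t ok = fp8_e4_t(b);  // explicit construction: BF16 -> float -> FP8
  // fp8_e4_t bad = b;        // would not compile: the constructor is explicit
  (void)ok;
}
```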
src/tl_templates/cuda/common.h (1)

13-14: LGTM!

The includes for cutlass/float8.h and cutlass/bfloat16.h are necessary for the base types cutlass::float_e4m3_t and cutlass::float_e5m2_t used in the new wrapper structs below.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

♻️ Duplicate comments (1)
src/tl_templates/cuda/common.h (1)

324-324: Comment inconsistency already flagged.

The comment states "implicit conversion" but the constructors are marked explicit, providing explicit conversion. This discrepancy was already noted in the previous review.

🧹 Nitpick comments (1)
src/tl_templates/cuda/common.h (1)

331-332: Consider using bfloat16_t for parameter consistency.

The constructors at lines 331 and 341 take __nv_bfloat16 parameters, but the file already imports cutlass::bfloat16_t (aliased as bfloat16_t on line 16), the type recommended throughout the codebase. CUTLASS's bfloat16_t has an explicit constructor accepting __nv_bfloat16, so switching the parameter type to bfloat16_t would keep compatibility while matching the established pattern.
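A minimal illustration of that alternative signature (hypothetical; assumes the bfloat16_t alias converts to float, as CUTLASS's type does):

```cpp
// Same conversion path, taking the CUTLASS alias instead of the raw CUDA type.
explicit float_e4m3_t(bfloat16_t x) : cutlass::float_e4m3_t(float(x)) {}
```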

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 5c25147 and cda16d5.

📒 Files selected for processing (2)
  • src/tl_templates/cuda/common.h (2 hunks)
  • src/tl_templates/cuda/cuda_fp8.h (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/tl_templates/cuda/cuda_fp8.h
🔇 Additional comments (2)
src/tl_templates/cuda/common.h (2)

13-14: LGTM! Includes support new FP8 wrapper types.

The new includes provide the necessary CUTLASS FP8 and bfloat16 types for the wrapper structs defined below.


335-343: Consistent implementation with float_e4m3_t.

The float_e5m2_t struct correctly mirrors the pattern of float_e4m3_t, providing the same conversion capabilities for the E5M2 format. The same verification regarding __nv_bfloat16 vs cutlass::bfloat16_t applies here as well.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 0

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tilelang/language/allocate.py (1)

22-22: Fix type annotation and remove unused import

The code at line 71 uses Union[PrimExpr] = None, which is wrong: Union[PrimExpr] collapses to just PrimExpr and does not admit None, so the annotation doesn't match the None default. The # noqa: UP007 comment is suppressing the lint warning.

Given the project supports Python >=3.8, Optional is the correct fix for full compatibility (the | syntax only works in 3.10+):

-from typing import Union
+from typing import Optional
@@
-def alloc_var(dtype, *args, scope="local.var", init: Union[PrimExpr] = None):  # noqa: UP007
+def alloc_var(dtype, *args, scope="local.var", init: Optional[PrimExpr] = None):

This also applies to line 71 (the function definition).

🧹 Nitpick comments (1)
tilelang/language/allocate.py (1)

147-151: Fix alloc_tmem docstring params (uses shape, dtype, not num_cols)

Docstring mentions num_cols but the function signature is (shape, dtype) and asserts a 2D shape. This can mislead users.

Apply this doc-only diff:

-    Args:
-        num_cols (int): Number of columns to allocate in TMEM. Must be a power of 2 and >= 32 but less than or equal to 512.
-
-    Returns:
-        T.Buffer: A TVM buffer object allocated in TMEM scope, suitable for use as an accumulator or operand in TCGEN5.MMA operations.
+    Args:
+        shape (tuple[int, int]): 2D shape (rows, cols) for the TMEM buffer. The column count should be a power of 2 and >= 32 and <= 512.
+        dtype (str): Data type of the buffer (e.g., 'float32', 'int32').
+
+    Returns:
+        T.Buffer: A TVM buffer object allocated in TMEM scope, suitable for use as an accumulator or operand in TCGEN5.MMA operations.
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between cda16d5 and 0aad651.

📒 Files selected for processing (1)
  • tilelang/language/allocate.py (2 hunks)
🧰 Additional context used
🪛 Ruff (0.14.1)
tilelang/language/allocate.py

71-71: Unused noqa directive (non-enabled: UP007)

Remove unused noqa directive

(RUF100)

🔇 Additional comments (1)
tilelang/language/allocate.py (1)

1-1: Future annotations import looks good

Safe, no behavior change; annotations are now lazily evaluated, avoiding runtime evaluation of type hints.

Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

♻️ Duplicate comments (1)
src/tl_templates/cuda/common.h (1)

324-343: Comment says “implicit”, code enforces explicit conversion. Please align.

Constructors are marked explicit; update the comment to avoid confusion. This was already raised earlier.

Apply:

-// and add the desired implicit conversion from bfloat16_t.
+// and add explicit conversion from __nv_bfloat16 via float.
🧹 Nitpick comments (1)
src/tl_templates/cuda/gemm_sm90.h (1)

18-26: Deduplicate to_cute_type into a shared header.

The same mapper exists in gemm_mma.h; move it to a common header (e.g., common.h) to avoid divergence.

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0aad651 and dad541c.

📒 Files selected for processing (3)
  • src/tl_templates/cuda/common.h (2 hunks)
  • src/tl_templates/cuda/gemm_mma.h (1 hunks)
  • src/tl_templates/cuda/gemm_sm90.h (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
src/tl_templates/cuda/common.h (2)
src/tl_templates/cuda/gemm_mma.h (2)
  • cute (14-22)
  • cute (108-457)
src/tl_templates/cuda/gemm_sm90.h (1)
  • cute (11-158)
src/tl_templates/cuda/gemm_sm90.h (1)
src/tl_templates/cuda/gemm_mma.h (5)
  • tl (263-265)
  • tl (266-268)
  • tl (459-494)
  • cute (14-22)
  • cute (108-457)
src/tl_templates/cuda/gemm_mma.h (1)
src/tl_templates/cuda/gemm_sm90.h (4)
  • tl (21-23)
  • tl (24-26)
  • tl (242-396)
  • cute (11-158)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
  • GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
  • GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
🔇 Additional comments (1)
src/tl_templates/cuda/gemm_mma.h (1)

260-268: Mapper looks good.

Identity primary and the two tl:: specializations are correct.

Comment on lines +13 to +15
#include <cutlass/bfloat16.h>
#include <cutlass/float8.h>

Contributor

⚠️ Potential issue | 🟡 Minor

Explicitly include cuda_bf16.h to guarantee __nv_bfloat16 availability.

Avoid relying on transitive includes; add the CUDA header so host/RTC builds consistently see __nv_bfloat16.

Apply this diff near the existing cuda_runtime include:

 #ifndef __CUDACC_RTC__
 #include <cuda_runtime.h>
+#include <cuda_bf16.h>
 #endif

Committable suggestion skipped: line range outside the PR's diff.

🤖 Prompt for AI Agents
In src/tl_templates/cuda/common.h around lines 13 to 15, the file relies on
transitive includes for the CUDA bfloat16 type (__nv_bfloat16); explicitly add
the CUDA header cuda_bf16.h (near the existing cuda_runtime include) to
guarantee __nv_bfloat16 is available for host and RTC builds, avoiding
transitive-include fragility.

@LeiWang1999
Member

We're good to go if we can resolve this conflict
