[BugFix] Correct direct copy from bf16 to fp8 #1090
base: main
Conversation
Walkthrough

Added tl:: FP8 wrapper types with explicit __nv_bfloat16 constructors, switched the fp8 aliases to those wrappers, added tl::to_cute_type mappings, updated the GEMM CUDA headers to resolve A/B types via cute-converted aliases, and removed a duplicate Python future import.
Sequence Diagram(s)

```mermaid
sequenceDiagram
    autonumber
    participant NV as __nv_bfloat16
    participant TL as tl::float_e4m3_t / tl::float_e5m2_t
    participant CUT as cute::float_e4m3_t / cute::float_e5m2_t
    participant FP as fp8 alias (fp8_e4_t / fp8_e5_t)
    participant GEMM as GEMM headers
    NV->>TL: explicit ctor (convert to float, construct TL)
    TL->>CUT: inherits/forwards to cute base
    FP->>TL: fp8 aliases now reference TL wrappers
    TL->>GEMM: tl::to_cute_type maps TL -> CUT type
    GEMM->>GEMM: select A_type/B_type based on *_cute aliases
    note right of GEMM #D3F4FF: conditional: if *_cute is float → tfloat32_t else → cute type
```
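To make the flow above concrete, here is a minimal sketch of the wrapper pattern the walkthrough describes. The struct names and the BF16-via-float conversion follow the summary and the diagram; the exact qualifiers and constructors in src/tl_templates/cuda/common.h may differ.

```cpp
#include <cuda_bf16.h>
#include <cutlass/bfloat16.h>
#include <cutlass/float8.h>

namespace tl {

// Sketch only: wrap the cutlass FP8 types and add an explicit BF16 constructor
// that widens to float first, as described in the walkthrough.
struct float_e4m3_t : cutlass::float_e4m3_t {
  using cutlass::float_e4m3_t::float_e4m3_t;  // keep the base constructors
  explicit __host__ __device__ float_e4m3_t(__nv_bfloat16 x)
      : cutlass::float_e4m3_t(__bfloat162float(x)) {}
};

struct float_e5m2_t : cutlass::float_e5m2_t {
  using cutlass::float_e5m2_t::float_e5m2_t;
  explicit __host__ __device__ float_e5m2_t(__nv_bfloat16 x)
      : cutlass::float_e5m2_t(__bfloat162float(x)) {}
};

}  // namespace tl
```

Because the constructor is explicit, a BF16→FP8 conversion has to be spelled out at the call site rather than happening implicitly, which matches the review's emphasis on making the narrowing visible.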
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
Pre-merge checks and finishing touches
❌ Failed checks (1 warning)
✅ Passed checks (2 passed)
📜 Recent review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
🚧 Files skipped from review as they are similar to previous changes (1)
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run … We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
src/op/copy.cc (1)
302-306: Apply the proposed fix to the scalar-path cast handling in lines 302-306.

The review comment correctly identifies a critical bug: the scalar fast-path in `MakeSIMTLoop` bypasses the dtype-mismatch handling that the non-scalar path implements. Verification confirms the non-scalar path (lines 327-332) explicitly checks `src->dtype != dst->dtype` and applies the intermediate float32 cast for BF16→FP8 conversions, whereas the scalar path stores directly without casting.

The proposed fix is correct and aligns with the existing non-scalar implementation pattern: it checks for the specific BF16→FP8 combinations and routes through float32 as an intermediate. It also covers both `float8_e4m3` and `float8_e5m2`, which is more comprehensive than the current non-scalar handling (which only checks e4m3), and is consistent with the codegen patterns in `codegen_cuda.cc`, where BF16→float8 conversions go through float intermediates.

Add a minimal BF16→FP8 scalar-copy test case to ensure this path is exercised and prevent regression.
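For reference, a hedged sketch of what the scalar-path handling could look like, mirroring the non-scalar branch quoted in the nitpick below. The surrounding function shape, helper name, and variable names are assumptions for illustration, not the actual src/op/copy.cc code; only the cast chain follows the quoted diff.

```cpp
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt.h>

using namespace tvm;
using namespace tvm::tir;

// Hypothetical helper: build the scalar store with the same
// BF16 -> FP32 -> FP8 cast chain the non-scalar path uses.
Stmt MakeScalarStore(const Buffer& src, const Buffer& dst,
                     const Array<PrimExpr>& src_idx,
                     const Array<PrimExpr>& dst_idx) {
  PrimExpr value = BufferLoad(src, src_idx);
  if (src->dtype != dst->dtype) {
    // Route BF16 -> FP8 through FP32 first, covering both e4m3 and e5m2.
    if (src->dtype.is_bfloat16() &&
        (dst->dtype.is_float8_e4m3() || dst->dtype.is_float8_e5m2())) {
      value = Cast(DataType::Float(32), value);
    }
    value = Cast(dst->dtype, value);
  }
  return BufferStore(dst, value, dst_idx);
}
```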
🧹 Nitpick comments (3)
src/op/copy.cc (1)
327-334: BF16→FP8: logic OK; fix comment and broaden to e5m2 as well.
- The wording says “cast dst to fp32” but you actually cast the loaded value. Please fix the comment.
- Bug likely applies to both FP8 formats; consider handling e5m2 too.
Apply:
```diff
-  if (src->dtype != dst->dtype) {
-    // If dst is fp8 and src is bf16, first cast dst to fp32.
-    if (src->dtype.is_bfloat16() && dst->dtype.is_float8_e4m3()) {
-      value = Cast(DataType::Float(32), value);
-    }
-    value = Cast(dst->dtype, value);
-  }
+  if (src->dtype != dst->dtype) {
+    // If dst is FP8 and src is BF16, first cast the loaded value to FP32 to
+    // avoid precision/rounding issues.
+    if (src->dtype.is_bfloat16() &&
+        (dst->dtype.is_float8_e4m3() || dst->dtype.is_float8_e5m2())) {
+      value = Cast(DataType::Float(32), value);
+    }
+    value = Cast(dst->dtype, value);
+  }
```

Add/extend a unit test to compile BF16→FP8 E5M2 as well to confirm the identical fix applies.
testing/python/issue/test_tilelang_issue_1046.py (2)
10-24: Turn this into an assertion-based test and cover FP8 e5m2.
- Avoid print-only behavior; assert that JIT compilation succeeds and that generated TIR contains the BF16→FP32→FP8 cast chain.
- Parameterize to also exercise out_dtype="float8_e5m2".
Example:
```diff
-@tilelang.jit
-def test_kernel(N, in_dtype=BF16, out_dtype=FP8):
+@tilelang.jit
+def test_kernel(N, in_dtype=BF16, out_dtype=FP8):
@@
-    return test_kernel_
+    return test_kernel_
@@
-kernel = test_kernel(128)
-
-print(kernel.get_kernel_source())
+def _contains_cast_chain(f):
+    src = f.get_kernel_source()
+    # Heuristic: both a cast to f32 and a cast to fp8 should appear
+    return ("cast(float32" in src or "ConvertF32" in src) and ("e4m3" in src or "e5m2" in src)
+
+
+def test_issue_1046_bf16_to_fp8_e4m3_compiles():
+    k = test_kernel(128, in_dtype=BF16, out_dtype=FP8)
+    assert _contains_cast_chain(k)
+
+
+def test_issue_1046_bf16_to_fp8_e5m2_compiles():
+    k = test_kernel(128, in_dtype=BF16, out_dtype="float8_e5m2")
+    assert _contains_cast_chain(k)
```

If CI environments differ, relax the heuristic to simply ensure compilation returns a kernel object without exceptions.
4-4: Global cache toggle in tests can have side effects. `tilelang.disable_cache()` affects global state and may impact other tests. If not required, drop it or confine it to a local setup/teardown.
Wrap in a fixture or context; or remove if unnecessary for reproducing the issue.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- src/op/copy.cc (1 hunks)
- testing/python/issue/test_tilelang_issue_1046.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
testing/python/issue/test_tilelang_issue_1046.py (6)
tilelang/env.py (1)
  disable_cache (271-272)
tilelang/jit/__init__.py (1)
  jit (244-317)
tilelang/language/symbolics.py (1)
  dynamic (11-22)
tilelang/language/kernel.py (1)
  threads (215-219)
tilelang/language/allocate.py (1)
  alloc_shared (21-36)
tilelang/language/copy.py (1)
  copy (10-86)
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- src/op/copy.cc (1 hunks)
- src/target/codegen_cuda.cc (1 hunks)
✅ Files skipped from review due to trivial changes (1)
- src/op/copy.cc
🧰 Additional context used
🧬 Code graph analysis (1)
src/target/codegen_cuda.cc (1)
src/target/codegen_hip.cc (2)
  PrintType (186-429)
  PrintType (186-186)
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- src/target/codegen_cuda.cc (1 hunks)
- src/tl_templates/cuda/common.h (2 hunks)
- src/tl_templates/cuda/cuda_fp8.h (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- src/target/codegen_cuda.cc
🧰 Additional context used
🧬 Code graph analysis (1)
src/tl_templates/cuda/cuda_fp8.h (1)
src/tl_templates/cuda/common.h (1)
  tl (174-268)
🔇 Additional comments (3)
src/tl_templates/cuda/cuda_fp8.h (2)
5-5: LGTM! The include of `common.h` is necessary for the new `tl::float_e4m3_t` and `tl::float_e5m2_t` types used in the updated aliases below.
7-8: Type alias changes are correct and safe. Verification confirms:
- Wrapper types `tl::float_e4m3_t` and `tl::float_e5m2_t` are properly defined in `common.h` with explicit `__nv_bfloat16` conversion constructors
- No direct references to the old `cute::` fp8 types exist outside the wrapper definitions
- The type aliases are actively used throughout the codebase (GEMM templates, debug utilities, etc.) and will work correctly with the new wrapper types because they inherit from the cutlass types
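For illustration, a hypothetical call site using the `fp8_e4_t` alias. The alias names come from the diagram above; the kernel, its name, and the include path are assumptions, not code from the repository.

```cuda
#include "cuda_fp8.h"  // assumed: provides fp8_e4_t = tl::float_e4m3_t
#include <cuda_bf16.h>

// Hypothetical kernel: the explicit wrapper constructor makes the BF16 -> FP8
// narrowing visible at the call site instead of happening implicitly.
__global__ void bf16_to_fp8(const __nv_bfloat16* in, fp8_e4_t* out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) {
    out[i] = fp8_e4_t(in[i]);  // explicit conversion via float
  }
}
```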
src/tl_templates/cuda/common.h (1)
13-14: LGTM! The includes for `cutlass/float8.h` and `cutlass/bfloat16.h` are necessary for the base types `cutlass::float_e4m3_t` and `cutlass::float_e5m2_t` used in the new wrapper structs below.
Actionable comments posted: 0
♻️ Duplicate comments (1)
src/tl_templates/cuda/common.h (1)
324-324: Comment inconsistency already flagged. The comment states "implicit conversion" but the constructors are marked `explicit`, providing explicit conversion. This discrepancy was already noted in the previous review.
🧹 Nitpick comments (1)
src/tl_templates/cuda/common.h (1)
331-332: Consider using `bfloat16_t` for parameter consistency. The constructors at lines 331 and 341 accept `__nv_bfloat16` as a parameter, but the file already imports `cutlass::bfloat16_t` (recommended for use whenever possible). CUTLASS's `bfloat16_t` has an explicit constructor accepting `__nv_bfloat16`, so changing the parameter type from `__nv_bfloat16` to `bfloat16_t` (or just `bfloat16_t`, since it's aliased on line 16) would maintain compatibility while improving consistency with the codebase's established pattern.
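A minimal sketch of what the suggested signature change could look like; this is an assumed form for illustration, not the committed code.

```cpp
#include <cutlass/bfloat16.h>
#include <cutlass/float8.h>

namespace tl {

// Sketch of the reviewer's suggestion: accept cutlass::bfloat16_t instead of
// __nv_bfloat16 so the wrapper matches the bfloat16_t alias used elsewhere in
// the header; cutlass::bfloat16_t still constructs explicitly from __nv_bfloat16.
struct float_e4m3_t : cutlass::float_e4m3_t {
  using cutlass::float_e4m3_t::float_e4m3_t;
  explicit CUTLASS_HOST_DEVICE float_e4m3_t(cutlass::bfloat16_t x)
      : cutlass::float_e4m3_t(float(x)) {}
};

}  // namespace tl
```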
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- src/tl_templates/cuda/common.h (2 hunks)
- src/tl_templates/cuda/cuda_fp8.h (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- src/tl_templates/cuda/cuda_fp8.h
🔇 Additional comments (2)
src/tl_templates/cuda/common.h (2)
13-14: LGTM! Includes support new FP8 wrapper types. The new includes provide the necessary CUTLASS FP8 and bfloat16 types for the wrapper structs defined below.
335-343: Consistent implementation with `float_e4m3_t`. The `float_e5m2_t` struct correctly mirrors the pattern of `float_e4m3_t`, providing the same conversion capabilities for the E5M2 format. The same observation regarding `__nv_bfloat16` vs `cutlass::bfloat16_t` applies here as well.
Actionable comments posted: 0
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
tilelang/language/allocate.py (1)
22-22: Fix the type annotation and remove the unused import.

The code at line 71 uses `Union[PrimExpr] = None`, which is incorrect (`Union` requires at least two types and doesn't include `None`). The `# noqa: UP007` comment is suppressing the lint warning. Given that the project supports Python >= 3.8, `Optional` is the correct fix for full compatibility (the `|` syntax only works in 3.10+):

```diff
-from typing import Union
+from typing import Optional
@@
-def alloc_var(dtype, *args, scope="local.var", init: Union[PrimExpr] = None):  # noqa: UP007
+def alloc_var(dtype, *args, scope="local.var", init: Optional[PrimExpr] = None):
```

This also applies to line 71 (the function definition).
🧹 Nitpick comments (1)
tilelang/language/allocate.py (1)
147-151: Fix `alloc_tmem` docstring params (uses `shape`, `dtype`, not `num_cols`). The docstring mentions `num_cols`, but the function signature is `(shape, dtype)` and asserts a 2D shape. This can mislead users.

Apply this doc-only diff:

```diff
-    Args:
-        num_cols (int): Number of columns to allocate in TMEM. Must be a power of 2 and >= 32 but less than or equal to 512.
-
-    Returns:
-        T.Buffer: A TVM buffer object allocated in TMEM scope, suitable for use as an accumulator or operand in TCGEN5.MMA operations.
+    Args:
+        shape (tuple[int, int]): 2D shape (rows, cols) for the TMEM buffer. The column count should be a power of 2 and >= 32 and <= 512.
+        dtype (str): Data type of the buffer (e.g., 'float32', 'int32').
+
+    Returns:
+        T.Buffer: A TVM buffer object allocated in TMEM scope, suitable for use as an accumulator or operand in TCGEN5.MMA operations.
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- tilelang/language/allocate.py (2 hunks)
🧰 Additional context used
🪛 Ruff (0.14.1)
tilelang/language/allocate.py
71-71: Unused noqa directive (non-enabled: UP007)
Remove unused noqa directive
(RUF100)
🔇 Additional comments (1)
tilelang/language/allocate.py (1)
1-1: Future annotations import looks good. Safe, no behavior change; keeps annotations lazy and avoids runtime imports.
Actionable comments posted: 3
♻️ Duplicate comments (1)
src/tl_templates/cuda/common.h (1)
324-343: Comment says “implicit”, code enforces explicit conversion. Please align. The constructors are marked explicit; update the comment to avoid confusion. This was already raised earlier.
Apply:
```diff
-// and add the desired implicit conversion from bfloat16_t.
+// and add explicit conversion from __nv_bfloat16 via float.
```
🧹 Nitpick comments (1)
src/tl_templates/cuda/gemm_sm90.h (1)
18-26: Deduplicate to_cute_type into a shared header. The same mapper exists in gemm_mma.h; move it to a common header (e.g., common.h) to avoid divergence.
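A sketch of what the shared trait could look like if hoisted into common.h. The names follow the walkthrough above; the convenience alias and the include path are assumptions, and the actual specializations in gemm_mma.h / gemm_sm90.h may differ.

```cpp
#include <cute/numeric/numeric_types.hpp>  // assumed include for the cute FP8 types
// Assumes the tl::float_e4m3_t / tl::float_e5m2_t wrappers from common.h are
// already declared in this translation unit.

namespace tl {

// Identity mapping by default; specializations translate the tl:: wrappers to
// the cute types the GEMM headers expect.
template <typename T>
struct to_cute_type {
  using type = T;
};

template <>
struct to_cute_type<float_e4m3_t> {
  using type = cute::float_e4m3_t;
};

template <>
struct to_cute_type<float_e5m2_t> {
  using type = cute::float_e5m2_t;
};

template <typename T>
using to_cute_type_t = typename to_cute_type<T>::type;  // hypothetical convenience alias

}  // namespace tl
```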
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- src/tl_templates/cuda/common.h (2 hunks)
- src/tl_templates/cuda/gemm_mma.h (1 hunks)
- src/tl_templates/cuda/gemm_sm90.h (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
src/tl_templates/cuda/common.h (2)
src/tl_templates/cuda/gemm_mma.h (2)
  cute (14-22)
  cute (108-457)
src/tl_templates/cuda/gemm_sm90.h (1)
  cute (11-158)
src/tl_templates/cuda/gemm_sm90.h (1)
src/tl_templates/cuda/gemm_mma.h (5)
  tl (263-265)
  tl (266-268)
  tl (459-494)
  cute (14-22)
  cute (108-457)
src/tl_templates/cuda/gemm_mma.h (1)
src/tl_templates/cuda/gemm_sm90.h (4)
  tl (21-23)
  tl (24-26)
  tl (242-396)
  cute (11-158)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
- GitHub Check: Test for Python 3.12 with CUDA-12.8 (on self-hosted-nvidia)
- GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
🔇 Additional comments (1)
src/tl_templates/cuda/gemm_mma.h (1)
260-268: Mapper looks good. Identity primary and the two tl:: specializations are correct.
```cpp
#include <cutlass/bfloat16.h>
#include <cutlass/float8.h>
```
Explicitly include cuda_bf16.h to guarantee __nv_bfloat16 availability.
Avoid relying on transitive includes; add the CUDA header so host/RTC builds consistently see __nv_bfloat16.
Apply this diff near the existing cuda_runtime include:
```diff
 #ifndef __CUDACC_RTC__
 #include <cuda_runtime.h>
+#include <cuda_bf16.h>
 #endif
```

Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In src/tl_templates/cuda/common.h around lines 13 to 15, the file relies on
transitive includes for the CUDA bfloat16 type (__nv_bfloat16); explicitly add
the CUDA header cuda_bf16.h (near the existing cuda_runtime include) to
guarantee __nv_bfloat16 is available for host and RTC builds, avoiding
transitive-include fragility.
We're good to go if we can resolve this conflict.