
Conversation

@Rachmanino
Collaborator

@Rachmanino Rachmanino commented Oct 21, 2025

Finished #1034

Add features:

  • Vectorized conversion between fp16/bf16 and fp32
  • Vectorized conversion from fp32 to fp8 (both e4m3 and e5m2)

Summary by CodeRabbit

  • New Features

    • Faster, more reliable CUDA vectorized casts between float16/float32/float8/bfloat16 with prioritized 2- and 4-lane vector paths while retaining per-element fallback.
  • Tests

    • New CUDA test suite validating vectorized cast correctness across multiple source/destination dtypes and lane configurations.
  • Chores / Style

    • Removed explicit compile flags from several decorator configurations and consolidated decorator formatting in example scripts.

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Contributor

coderabbitai bot commented Oct 21, 2025

Walkthrough

Reworks CUDA CastNode codegen to add explicit 2- and 4-lane vectorized conversion branches for float16↔float32, bfloat16↔float32, and float32↔FP8 (E4M3/E5M2) using CUDA intrinsics; removes prior bf16-op path. Adds tests for vectorized casts and removes compile_flags from several example decorators.
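
As background for the intrinsics named above, a minimal hand-written CUDA sketch (illustrative only, not the code the generator emits) of the 2-lane float16↔float32 pattern these branches lower to:

#include <cuda_fp16.h>

// Widen two packed half values to two floats with a single intrinsic call.
__global__ void cast_f16x2_to_f32x2(const __half2 *src, float2 *dst, int n2) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n2) {
    dst[i] = __half22float2(src[i]);
  }
}

// Pack two floats back into one half2 using round-to-nearest-even.
__global__ void cast_f32x2_to_f16x2(const float2 *src, __half2 *dst, int n2) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n2) {
    dst[i] = __float22half2_rn(src[i]);
  }
}

The 4-lane branches apply the same idea across two packed pairs.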

Changes

Cohort / File(s) Summary
CUDA cast vectorization
src/target/codegen_cuda.cc
Reworked CastNode codegen to add explicit vectorized 2-lane and 4-lane branches for float16↔float32 (__half22float2, __float22half2_rn), bfloat16↔float32 (__bfloat1622float2, __float22bfloat162_rn), and float32↔FP8 (E4M3/E5M2) (__nv_cvt_float2_to_fp8x2 and variants). Vectorized paths write into sret and return early. Removed previous ENABLE_BF16-guarded bf16-op path. Fallback: elementwise cast loop remains.
Tests — vectorized cast validations
testing/python/language/test_tilelang_language_vectorized_cast.py
New test module adding str2dtype, a vectorized_cast_kernel(M, dtype_A, dtype_B) kernel factory, run_vectorized_cast(...) runner, and test_vectorized_cast() exercising multiple dtype pairs and lane widths; allocates CUDA tensors, runs kernels, asserts results, and checks kernel source for vectorization hints.
Examples — decorator cleanup
examples/attention_sink/example_gqa_sink_bwd_bhsd.py, examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py, examples/attention_sink/example_mha_sink_bwd_bhsd.py, examples/attention_sink/example_mha_sink_fwd_bhsd.py, examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py
Removed compile_flags argument from multiple @tilelang.jit decorators and condensed pass_configs formatting. No changes to function signatures or runtime logic beyond decorator argument formatting.

Sequence Diagram(s)

sequenceDiagram
  participant IR as CastNode (codegen)
  participant Detect as Type/Width Detector
  participant Emit as CUDA Emitter
  participant Fallback as Elementwise Emitter
  rect rgb(245,250,255)
  Note over IR,Detect: Prefer 2-/4-lane vectorized paths when applicable
  end
  IR->>Detect: inspect src/dst types and lane count
  alt float16↔float32 (2/4 lanes)
    Detect-->>IR: match half↔float vector
    IR->>Emit: emit `__half22float2` / 4-lane equivalent or `__float22half2_rn`
    Emit-->>IR: write vectorized store to sret
    IR->>IR: return early
  else bfloat16↔float32 (2/4 lanes)
    Detect-->>IR: match bf16↔float vector
    IR->>Emit: emit `__bfloat1622float2` / `__float22bfloat162_rn`
    Emit-->>IR: write vectorized store to sret
    IR->>IR: return early
  else float32↔FP8 (2/4 lanes)
    Detect-->>IR: match float↔FP8 vector
    IR->>Emit: emit `__nv_cvt_float2_to_fp8x2` / 4-lane equivalents
    Emit-->>IR: write vectorized store to sret
    IR->>IR: return early
  else no vectorized path
    Detect-->>IR: no vectorized match
    IR->>Fallback: emit per-element cast loop
    Fallback-->>IR: elementwise stores emitted
  end

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~25 minutes

Suggested reviewers

  • LeiWang1999
  • tzj-fxz

Poem

🐇
Two lanes hop, then four leap wide,
Intrinsics hum where kernels glide,
Half, bf, and FP8 take flight,
I bound through casts beneath the light,
Carrots, code, and CUDA pride 🥕

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
  • Docstring Coverage ⚠️ Warning: Docstring coverage is 6.25%, which is below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
  • Description Check ✅ Passed: Check skipped; CodeRabbit's high-level summary is enabled.
  • Title Check ✅ Passed: The pull request title "[Feature] Enhance vectorized conversion support in CUDA codegen" aligns with the main change: src/target/codegen_cuda.cc reworks the CastNode codegen to add explicit vectorized conversion branches for multiple dtype pairs (float16↔float32, bfloat16↔float32, float32↔FP8) on both 2-lane and 4-lane paths. The title is specific enough to convey the primary intent, and the secondary decorator cleanup in the example files does not need to appear in it.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 9f40071 and 0420742.

📒 Files selected for processing (7)
  • examples/attention_sink/example_gqa_sink_bwd_bhsd.py (4 hunks)
  • examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py (1 hunks)
  • examples/attention_sink/example_mha_sink_bwd_bhsd.py (4 hunks)
  • examples/attention_sink/example_mha_sink_fwd_bhsd.py (1 hunks)
  • examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py (1 hunks)
  • src/target/codegen_cuda.cc (1 hunks)
  • testing/python/language/test_tilelang_language_vectorized_cast.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (4)
  • examples/attention_sink/example_gqa_sink_bwd_bhsd.py
  • examples/attention_sink/example_mha_sink_fwd_bhsd_wgmma_pipelined.py
  • examples/attention_sink/example_gqa_sink_fwd_bhsd_wgmma_pipelined.py
  • examples/attention_sink/example_mha_sink_bwd_bhsd.py
🧰 Additional context used
🧬 Code graph analysis (1)
testing/python/language/test_tilelang_language_vectorized_cast.py (2)
tilelang/jit/__init__.py (1)
  • jit (237-310)
tilelang/language/copy.py (1)
  • copy (10-86)
🪛 Ruff (0.14.1)
testing/python/language/test_tilelang_language_vectorized_cast.py

20-20: Unused noqa directive (unused: F821)

Remove unused noqa directive

(RUF100)


21-21: Unused noqa directive (unused: F821)

Remove unused noqa directive

(RUF100)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test for Python 3.12 with ROCm-6.3 (on self-hosted-amd)
🔇 Additional comments (5)
examples/attention_sink/example_mha_sink_fwd_bhsd.py (1)

21-23: No action required—compile_flags removal from this example is intentional and safe.

The search confirms compile_flags remains fully supported throughout the tilelang library (tilelang/jit/__init__.py, adapters, kernel caching). This example simply chose to use the pass_configs approach instead, which is an equivalent and valid refactoring. Other examples (deepseek_mla, deepseek_v32, compile_flags/usecase.py) continue using compile_flags without issues, demonstrating backward compatibility. This is not a breaking change.

src/target/codegen_cuda.cc (2)

903-943: LGTM! Float16 ↔ float32 vectorized conversions are correctly implemented.

The 2-lane and 4-lane vectorized conversion paths using __half22float2 and __float22half2_rn intrinsics are correct. The early returns ensure proper control flow, and the type system constraints adequately prevent misuse without requiring explicit bit-width checks.


945-989: LGTM! BFloat16 ↔ float32 vectorized conversions are correctly implemented.

The vectorized conversion paths for bfloat16 ↔ float32 follow the same correct pattern as float16, using appropriate __bfloat1622float2 and __float22bfloat162_rn intrinsics with proper type casting.
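
For reference, an analogous hand-written sketch of the 2-lane bfloat16↔float32 pattern (illustrative only, assuming a CUDA 11+ toolkit for cuda_bf16.h; not the emitted code):

#include <cuda_bf16.h>

// Widen a packed bf16 pair to float2 in one call.
__global__ void cast_bf16x2_to_f32x2(const __nv_bfloat162 *src, float2 *dst, int n2) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n2) {
    dst[i] = __bfloat1622float2(src[i]);
  }
}

// Pack two floats into one bf16x2 with round-to-nearest.
__global__ void cast_f32x2_to_bf16x2(const float2 *src, __nv_bfloat162 *dst, int n2) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n2) {
    dst[i] = __float22bfloat162_rn(src[i]);
  }
}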

testing/python/language/test_tilelang_language_vectorized_cast.py (2)

5-11: LGTM! Clean dtype mapping.

The str2dtype dictionary correctly maps string representations to PyTorch dtypes for all tested conversion types.


54-77: Good test coverage for vectorized casts.

The test comprehensively covers all dtype pairs (float16, bfloat16, float8_e4m3, float8_e5m2) with both 2-lane and 4-lane vectorization, verifying both correctness and that vectorized intrinsics are actually emitted. This provides strong validation of the CUDA codegen changes.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.


Comment @coderabbitai help to get the list of available commands and usage tips.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bddb125 and 6debbb9.

📒 Files selected for processing (1)
  • src/target/codegen_cuda.cc (1 hunks)

LeiWang1999
LeiWang1999 previously approved these changes Oct 21, 2025
Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6debbb9 and d44ee95.

📒 Files selected for processing (1)
  • src/target/codegen_cuda.cc (2 hunks)
🔇 Additional comments (3)
src/target/codegen_cuda.cc (3)

903-911: Vectorized float16→float32 conversion looks correct.

The implementation properly uses the __half22float2 intrinsic for 2-lane conversions and returns early to avoid the elementwise fallback. The pointer casting pattern *(half2*)(&(src)) is standard CUDA practice for vector type reinterpretation.


912-921: Vectorized float32→float16 conversion looks correct.

The implementation properly uses the __float22half2_rn intrinsic (round-to-nearest mode) for the reverse conversion. The bidirectional support for float16↔float32 is complete and symmetric.


986-986: Helpful clarification comment.

This comment appropriately documents why BF16 handling differs from the FP16/FP8 vectorized paths. It clarifies that BF16 intentionally falls through to the elementwise cast fallback rather than returning early, which helps maintainers understand the design decision.

LeiWang1999
LeiWang1999 previously approved these changes Oct 21, 2025
@Rachmanino Rachmanino changed the title [Feature] Add vectorized float16 and float32 conversion support in CUDA codegen [Feature] Enhance vectorized conversion support in CUDA codegen Oct 21, 2025
@LeiWang1999
Member

We're good to go if we can add a simple unit codegen test for this case.

Contributor

@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 3

🧹 Nitpick comments (3)
testing/python/language/test_tilelang_language_vectorized_cast.py (3)

53-54: Compare in fp32 to avoid dtype-specific assert quirks (esp. float8/bf16).

Casts involving f8/bf16 are most stable when compared after upcasting to fp32.

-    torch.testing.assert_close(A.to(str2dtype[dst_dtype_str]), B)
+    torch.testing.assert_close(
+        A.to(str2dtype[dst_dtype_str]).to(torch.float32),
+        B.to(torch.float32)
+    )

55-59: Make codegen token check robust across backends/versions.

Allow multiple acceptable substrings and assert any-match; also improve failure message.

-    code = kernel.get_kernel_source()
-
-    assert check_str in code, \
-        f"Cast {src_dtype_str} to {dst_dtype_str} with {lanes=} is not vectorized!"
+    code = kernel.get_kernel_source()
+    tokens = (check_str,) if isinstance(check_str, str) else tuple(check_str)
+    assert any(t in code for t in tokens), (
+        f"Cast {src_dtype_str}->{dst_dtype_str} lanes={lanes} not vectorized. "
+        f"Expected one of: {tokens}"
+    )

If needed, pass a tuple for bf16 (e.g., ("fastertransformer", "cvt.rn.bf16x2", "ld.global.v2.b16")) to tolerate NVRTC/PTX differences.


70-77: Optional: stabilize float8/bf16 inputs.

To reduce rare saturation/rounding edge cases, consider clamping inputs for f8 paths (e.g., to [-16, 16]) or seeding RNG.

Also applies to: 81-86

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 0dd40d4 and 50d4d5d.

📒 Files selected for processing (2)
  • src/target/codegen_cuda.cc (2 hunks)
  • testing/python/language/test_tilelang_language_vectorized_cast.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • src/target/codegen_cuda.cc
🧰 Additional context used
🧬 Code graph analysis (1)
testing/python/language/test_tilelang_language_vectorized_cast.py (2)
tilelang/jit/__init__.py (1)
  • jit (237-310)
tilelang/language/copy.py (1)
  • copy (10-86)
🪛 Ruff (0.14.1)
testing/python/language/test_tilelang_language_vectorized_cast.py

21-21: Unused noqa directive (unused: F821)

Remove unused noqa directive

(RUF100)


22-22: Unused noqa directive (unused: F821)

Remove unused noqa directive

(RUF100)

🔇 Additional comments (1)
testing/python/language/test_tilelang_language_vectorized_cast.py (1)

24-26: LGTM: minimal kernel to exercise vectorized copy.

The single-block, 128-thread kernel with T.copy is a clean trigger for the vectorized codegen paths.

Comment on lines 21 to 22
A: T.Tensor[(M), dtype_A], # noqa: F821
B: T.Tensor[(M), dtype_B], # noqa: F821
Contributor


⚠️ Potential issue | 🟡 Minor

Remove unused noqa pragmas.

# noqa: F821 is unnecessary per Ruff; drop both to keep lint clean.

As per static analysis hints.

-        A: T.Tensor[(M), dtype_A],  # noqa: F821
-        B: T.Tensor[(M), dtype_B],  # noqa: F821
+        A: T.Tensor[(M), dtype_A],
+        B: T.Tensor[(M), dtype_B],
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
A: T.Tensor[(M), dtype_A], # noqa: F821
B: T.Tensor[(M), dtype_B], # noqa: F821
A: T.Tensor[(M), dtype_A],
B: T.Tensor[(M), dtype_B],
🧰 Tools
🪛 Ruff (0.14.1)

21-21: Unused noqa directive (unused: F821)

Remove unused noqa directive

(RUF100)


22-22: Unused noqa directive (unused: F821)

Remove unused noqa directive

(RUF100)

🤖 Prompt for AI Agents
In testing/python/language/test_tilelang_language_vectorized_cast.py around
lines 21 to 22, the two inline "# noqa: F821" pragmas on the A and B Tensor type
lines are unnecessary; remove the "# noqa: F821" suffix from both lines so the
linter no longer shows unused noqa pragmas and the file remains lint-clean.

"""Run the vectorized cast kernel and check the correctness.
Args:
src_dtype_str: The source data type string.
dst_dtype_str: The destination data type string.
check_str: Used to ensure vectorized cast is used.
M: The size of the tensor.
lanes: The number of lanes of the source and destination data types.
"""
Contributor


⚠️ Potential issue | 🟡 Minor

Fix docstring: remove nonexistent “M” arg.

“M” isn’t a parameter; it’s derived. Clarify to avoid confusion.

-        M: The size of the tensor.
-        lanes: The number of lanes of the source and destination data types.
+        lanes: The number of lanes of the source and destination data types.
+        Note: M is computed as 128 * lanes to satisfy kernel constraints.
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
"""Run the vectorized cast kernel and check the correctness.
Args:
src_dtype_str: The source data type string.
dst_dtype_str: The destination data type string.
check_str: Used to ensure vectorized cast is used.
M: The size of the tensor.
lanes: The number of lanes of the source and destination data types.
"""
"""Run the vectorized cast kernel and check the correctness.
Args:
src_dtype_str: The source data type string.
dst_dtype_str: The destination data type string.
check_str: Used to ensure vectorized cast is used.
lanes: The number of lanes of the source and destination data types.
Note: M is computed as 128 * lanes to satisfy kernel constraints.
"""
🤖 Prompt for AI Agents
In testing/python/language/test_tilelang_language_vectorized_cast.py around
lines 36 to 43, the docstring incorrectly lists an argument "M" which does not
exist; remove "M" from the args list and update the description to state that
the tensor size is derived (or computed internally) instead of being an input
parameter, keeping the rest of the docstring intact and consistent with actual
function parameters.

f"Cast {src_dtype_str} to {dst_dtype_str} with {lanes=} is not vectorized!"


def test_vectorized_cast():
Contributor


⚠️ Potential issue | 🟠 Major

Skip on CPU-only runners.

Guard the test to avoid CI failures when CUDA isn’t available.

+import pytest
@@
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
 def test_vectorized_cast():
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def test_vectorized_cast():
import pytest
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
def test_vectorized_cast():
🤖 Prompt for AI Agents
In testing/python/language/test_tilelang_language_vectorized_cast.py around line
61, the test function test_vectorized_cast must be skipped on CPU-only runners;
add a guard using pytest to skip when CUDA is unavailable (e.g., add import
pytest and import torch if missing, then decorate the test with
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available -
skip on CPU-only runners") or programmatically call pytest.skip at the start of
the test when torch.cuda.is_available() is False) so the test does not run on CI
machines without GPU.

…DA codegen

* Implemented handling for conversions between float16 and float32 types, specifically for vectorized operations using __half22float2 and __float22half2_rn.
* Enhanced the existing code to support both directions of conversion based on the lane count.
* Improved overall type handling in the VisitExpr_ method for better compatibility with TileLang.
* Implemented handling for conversion from float32 to float8 (E4M3/E5M2) in the VisitExpr_ method.
* Added vectorized conversion support using __nv_cvt_float2_to_fp8x2 for float2 to fp8x2 transformations (see the sketch below).
* Enhanced type handling for better compatibility with TileLang, particularly for float8 types.
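
For illustration, a minimal hand-written sketch of the float2 to fp8x2 conversion described above, assuming cuda_fp8.h from a recent toolkit (E4M3 shown; __NV_E5M2 selects the other encoding). This is not the code the generator emits:

#include <cuda_fp8.h>

// Convert two floats into one packed fp8x2 storage word (saturating, E4M3).
__global__ void cast_f32x2_to_fp8x2_e4m3(const float2 *src,
                                         __nv_fp8x2_storage_t *dst, int n2) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n2) {
    dst[i] = __nv_cvt_float2_to_fp8x2(src[i], __NV_SATFINITE, __NV_E4M3);
  }
}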

# bf16 -> fp32
run_vectorized_cast("bfloat16", "float32", "fastertransformer", 2)
# run_vectorized_cast("bfloat16", "float32", "fastertransformer", 4)
Member


has this been resolved?

Collaborator Author


has this been resolved?

Supported now, ready to be merged once CI passes :)

@Rachmanino
Collaborator Author

cc @xwhzz

@Rachmanino Rachmanino requested a review from xwhzz October 22, 2025 15:05
@xwhzz
Contributor

xwhzz commented Oct 23, 2025

We may no longer need the files cuda_bf16_fallbacks.cuh and cuda_bf16_wrapper.h in the src/tl_templates/cuda directory, and the codegen no longer needs to print their #include directives for alignment.

@Rachmanino
Collaborator Author

We may no longer need the files cuda_bf16_fallbacks.cuh and cuda_bf16_wrapper.h in the src/tl_templates/cuda directory, and the codegen no longer needs to print their #include directives for alignment.

But there are some other utility functions in cuda_bf16_fallbacks.cuh. I wonder whether we'll need them later for other purposes.

@xwhzz
Contributor

xwhzz commented Oct 23, 2025

We may no longer need the files cuda_bf16_fallbacks.cuh and cuda_bf16_wrapper.h in the src/tl_templates/cuda directory, and the codegen no longer needs to print their #include directives for alignment.

But there are some other utility functions in cuda_bf16_fallbacks.cuh. I wonder whether we'll need them later for other purposes.

We can leave it for now. The rest of the changes LGTM.
