[Feature] Enhance vectorized conversion support in CUDA codegen #1095
Conversation
👋 Hi! Thank you for contributing to the TileLang project. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Walkthrough

Reworks CUDA CastNode codegen to add explicit 2- and 4-lane vectorized conversion branches for float16↔float32, bfloat16↔float32, and float32↔FP8 (E4M3/E5M2) using CUDA intrinsics; removes the prior bf16-op path. Adds tests for vectorized casts.

Changes
Sequence Diagram(s)

```mermaid
sequenceDiagram
  participant IR as CastNode (codegen)
  participant Detect as Type/Width Detector
  participant Emit as CUDA Emitter
  participant Fallback as Elementwise Emitter
  rect rgb(245,250,255)
  Note over IR,Detect: Prefer 2-/4-lane vectorized paths when applicable
  end
  IR->>Detect: inspect src/dst types and lane count
  alt float16↔float32 (2/4 lanes)
    Detect-->>IR: match half↔float vector
    IR->>Emit: emit `__half22float2` / 4-lane equivalent or `__float22half2_rn`
    Emit-->>IR: write vectorized store to sret
    IR->>IR: return early
  else bfloat16↔float32 (2/4 lanes)
    Detect-->>IR: match bf16↔float vector
    IR->>Emit: emit `__bfloat1622float2` / `__float22bfloat162_rn`
    Emit-->>IR: write vectorized store to sret
    IR->>IR: return early
  else float32↔FP8 (2/4 lanes)
    Detect-->>IR: match float↔FP8 vector
    IR->>Emit: emit `__nv_cvt_float2_to_fp8x2` / 4-lane equivalents
    Emit-->>IR: write vectorized store to sret
    IR->>IR: return early
  else no vectorized path
    Detect-->>IR: no vectorized match
    IR->>Fallback: emit per-element cast loop
    Fallback-->>IR: elementwise stores emitted
  end
```
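The dispatch in the diagram above can be sketched as a lookup from (source type, destination type, lane count) to an intrinsic name, with a miss signalling the elementwise fallback. This is an illustrative Python sketch only; the real logic is C++ in src/target/codegen_cuda.cc, and the table shape and function name here are assumptions.

```python
# Illustrative sketch of the vectorized-cast dispatch from the diagram.
# The mapping below is an assumption for illustration, not the actual
# codegen_cuda.cc implementation; intrinsic names come from the walkthrough.
INTRINSICS = {
    ("float16", "float32", 2): "__half22float2",
    ("float32", "float16", 2): "__float22half2_rn",
    ("bfloat16", "float32", 2): "__bfloat1622float2",
    ("float32", "bfloat16", 2): "__float22bfloat162_rn",
    ("float32", "float8_e4m3", 2): "__nv_cvt_float2_to_fp8x2",
}


def select_cast_intrinsic(src: str, dst: str, lanes: int):
    """Return the CUDA intrinsic name for a vectorized cast, or None to
    signal that the per-element fallback loop should be emitted."""
    return INTRINSICS.get((src, dst, lanes))
```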
Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~25 minutes
Actionable comments posted: 1
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (1)
- src/target/codegen_cuda.cc (2 hunks)
🔇 Additional comments (3)
src/target/codegen_cuda.cc (3)
903-911: Vectorized float16→float32 conversion looks correct. The implementation properly uses the `__half22float2` intrinsic for 2-lane conversions and returns early to avoid the elementwise fallback. The pointer casting pattern `*(half2*)(&(src))` is standard CUDA practice for vector type reinterpretation.
912-921: Vectorized float32→float16 conversion looks correct. The implementation properly uses the `__float22half2_rn` intrinsic (round-to-nearest mode) for the reverse conversion. The bidirectional support for float16↔float32 is complete and symmetric.
986-986: Helpful clarification comment. This comment appropriately documents why BF16 handling differs from the FP16/FP8 vectorized paths. It clarifies that BF16 intentionally falls through to the elementwise cast fallback rather than returning early, which helps maintainers understand the design decision.
We're good to go if we can add a simple unit codegen test for this case.
Actionable comments posted: 3
🧹 Nitpick comments (3)
testing/python/language/test_tilelang_language_vectorized_cast.py (3)
53-54: Compare in fp32 to avoid dtype-specific assert quirks (esp. float8/bf16). Casts involving f8/bf16 are most stable when compared after upcasting to fp32.

```diff
- torch.testing.assert_close(A.to(str2dtype[dst_dtype_str]), B)
+ torch.testing.assert_close(
+     A.to(str2dtype[dst_dtype_str]).to(torch.float32),
+     B.to(torch.float32)
+ )
```
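To illustrate why the upcast-then-compare pattern is safe, here is a numpy stand-in (numpy in place of torch is an assumption for portability): both sides go through the same low-precision cast, so comparing after upcasting to float32 is exact and avoids dtype-specific tolerance handling.

```python
import numpy as np

# Reference input and a simulated kernel output (fp32 -> fp16 cast).
a = np.linspace(-1.0, 1.0, 8, dtype=np.float32)
b = a.astype(np.float16)

# Compare in fp32: upcast both sides so the assert uses plain
# float32 semantics rather than fp16-specific tolerances.
np.testing.assert_allclose(
    a.astype(np.float16).astype(np.float32),
    b.astype(np.float32),
)
```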
55-59: Make codegen token check robust across backends/versions. Allow multiple acceptable substrings and assert any-match; also improve the failure message.

```diff
- code = kernel.get_kernel_source()
-
- assert check_str in code, \
-     f"Cast {src_dtype_str} to {dst_dtype_str} with {lanes=} is not vectorized!"
+ code = kernel.get_kernel_source()
+ tokens = (check_str,) if isinstance(check_str, str) else tuple(check_str)
+ assert any(t in code for t in tokens), (
+     f"Cast {src_dtype_str}->{dst_dtype_str} lanes={lanes} not vectorized. "
+     f"Expected one of: {tokens}"
+ )
```

If needed, pass a tuple for bf16 (e.g., ("fastertransformer", "cvt.rn.bf16x2", "ld.global.v2.b16")) to tolerate NVRTC/PTX differences.
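A standalone, runnable version of that any-match check might look like this (the helper name and message format are illustrative, not part of the project's test API):

```python
def assert_vectorized(code: str, check, src: str, dst: str, lanes: int) -> None:
    """Assert that at least one expected token appears in the kernel source.

    `check` may be a single substring or a tuple of acceptable substrings,
    which tolerates backend/PTX differences in the emitted code.
    """
    tokens = (check,) if isinstance(check, str) else tuple(check)
    assert any(t in code for t in tokens), (
        f"Cast {src}->{dst} lanes={lanes} not vectorized. "
        f"Expected one of: {tokens}"
    )
```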
70-77: Optional: stabilize float8/bf16 inputs. To reduce rare saturation/rounding edge cases, consider clamping inputs for f8 paths (e.g., to [-16, 16]) or seeding the RNG.
Also applies to: 81-86
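A minimal sketch of both stabilizers, seeding and clamping (numpy used for illustration; the clamp range follows the [-16, 16] suggestion above, well inside E4M3's representable range):

```python
import numpy as np

rng = np.random.default_rng(0)            # seeded RNG for reproducible inputs
x = rng.standard_normal(128).astype(np.float32)
x = np.clip(x, -16.0, 16.0)               # avoid FP8 saturation edge cases
```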
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- src/target/codegen_cuda.cc (2 hunks)
- testing/python/language/test_tilelang_language_vectorized_cast.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- src/target/codegen_cuda.cc
🧰 Additional context used
🧬 Code graph analysis (1)
testing/python/language/test_tilelang_language_vectorized_cast.py (2)
tilelang/jit/__init__.py (1)
- jit (237-310)

tilelang/language/copy.py (1)
- copy (10-86)
🪛 Ruff (0.14.1)
testing/python/language/test_tilelang_language_vectorized_cast.py
21-21: Unused noqa directive (unused: F821)
Remove unused noqa directive
(RUF100)
22-22: Unused noqa directive (unused: F821)
Remove unused noqa directive
(RUF100)
🔇 Additional comments (1)
testing/python/language/test_tilelang_language_vectorized_cast.py (1)
24-26: LGTM: minimal kernel to exercise vectorized copy. The single-block, 128-thread kernel with T.copy is a clean trigger for the vectorized codegen paths.
```python
A: T.Tensor[(M), dtype_A],  # noqa: F821
B: T.Tensor[(M), dtype_B],  # noqa: F821
```
Remove unused noqa pragmas.
# noqa: F821 is unnecessary per Ruff; drop both to keep lint clean.
As per static analysis hints.
```diff
-        A: T.Tensor[(M), dtype_A],  # noqa: F821
-        B: T.Tensor[(M), dtype_B],  # noqa: F821
+        A: T.Tensor[(M), dtype_A],
+        B: T.Tensor[(M), dtype_B],
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
🤖 Prompt for AI Agents
In testing/python/language/test_tilelang_language_vectorized_cast.py around
lines 21 to 22, the two inline "# noqa: F821" pragmas on the A and B Tensor type
lines are unnecessary; remove the "# noqa: F821" suffix from both lines so the
linter no longer shows unused noqa pragmas and the file remains lint-clean.
```python
"""Run the vectorized cast kernel and check the correctness.

Args:
    src_dtype_str: The source data type string.
    dst_dtype_str: The destination data type string.
    check_str: Used to ensure vectorized cast is used.
    M: The size of the tensor.
    lanes: The number of lanes of the source and destination data types.
"""
```
Fix docstring: remove nonexistent “M” arg.
“M” isn’t a parameter; it’s derived. Clarify to avoid confusion.
```diff
-        M: The size of the tensor.
-        lanes: The number of lanes of the source and destination data types.
+        lanes: The number of lanes of the source and destination data types.
+        Note: M is computed as 128 * lanes to satisfy kernel constraints.
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
🤖 Prompt for AI Agents
In testing/python/language/test_tilelang_language_vectorized_cast.py around
lines 36 to 43, the docstring incorrectly lists an argument "M" which does not
exist; remove "M" from the args list and update the description to state that
the tensor size is derived (or computed internally) instead of being an input
parameter, keeping the rest of the docstring intact and consistent with actual
function parameters.
```python
f"Cast {src_dtype_str} to {dst_dtype_str} with {lanes=} is not vectorized!"


def test_vectorized_cast():
```
Skip on CPU-only runners.
Guard the test to avoid CI failures when CUDA isn’t available.
```diff
+import pytest
@@
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA required")
 def test_vectorized_cast():
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
🤖 Prompt for AI Agents
In testing/python/language/test_tilelang_language_vectorized_cast.py around line
61, the test function test_vectorized_cast must be skipped on CPU-only runners;
add a guard using pytest to skip when CUDA is unavailable (e.g., add import
pytest and import torch if missing, then decorate the test with
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available -
skip on CPU-only runners") or programmatically call pytest.skip at the start of
the test when torch.cuda.is_available() is False) so the test does not run on CI
machines without GPU.
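One way to implement the guard so it also works when torch itself is absent (the helper name and structure are illustrative, not project API):

```python
import importlib.util


def cuda_available() -> bool:
    """Return True only if torch is importable and reports a CUDA device.

    Safe on CPU-only runners and on machines without torch installed.
    """
    if importlib.util.find_spec("torch") is None:
        return False
    import torch
    return torch.cuda.is_available()
```

The test can then be decorated with `@pytest.mark.skipif(not cuda_available(), reason="CUDA required")`.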
…DA codegen

* Implemented handling for conversions between float16 and float32 types, specifically for vectorized operations using __half22float2 and __float22half2_rn.
* Enhanced the existing code to support both directions of conversion based on the lane count.
* Improved overall type handling in the VisitExpr_ method for better compatibility with TileLang.

* Implemented handling for conversion from float32 to float8 (E4M3/E5M2) in the VisitExpr_ method.
* Added vectorized conversion support using __nv_cvt_float2_to_fp8x2 for float2 to fp8x2 transformations.
* Enhanced type handling for better compatibility with TileLang, particularly for float8 types.
```python
# bf16 -> fp32
run_vectorized_cast("bfloat16", "float32", "fastertransformer", 2)
# run_vectorized_cast("bfloat16", "float32", "fastertransformer", 4)
```
has this been resolved?
Supported now, ready to be merged once ci passed :)
has this been resolved?
cc @xwhzz

We may no longer need the files 

But there are some other utility functions in 

We can leave it for now. The rest of the changes LGTM.
Finished #1034
Add features:
Summary by CodeRabbit

- New Features
- Tests
- Chores / Style