
Conversation

@chengyupku
Contributor

@chengyupku chengyupku commented Oct 21, 2025

This pull request introduces support for TMA Im2Col descriptor initialization and ensures that all relevant code paths and IR transformations handle the new tma_load_im2col operation. Additionally, it improves user feedback in the convolution example.

Support for TMA Im2Col descriptors:

  • Added the TMA_IM2COL_DESC_INIT_FUNC string template for initializing TMA Im2Col descriptors in tilelang/jit/adapter/wrapper.py (a rough sketch of the host-side call such a template emits follows this list).
  • Updated the descriptor argument generation logic in generate_tma_descriptor_args to correctly parse and initialize Im2Col descriptors, including validation and extraction of Im2Col-specific parameters. [1] [2]
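A minimal sketch of the host code such a template needs to emit, assuming the generated source calls the CUDA driver's cuTensorMapEncodeIm2col directly (the parameter order follows the driver API; the helper name, variable names, and error handling are illustrative and not the exact TMA_IM2COL_DESC_INIT_FUNC string):

# Hypothetical sketch only; the real template lives in tilelang/jit/adapter/wrapper.py
# and uses positional .format fields rather than this helper function.
def im2col_desc_init_sketch(handle, dtype, rank, global_address,
                            global_dim, global_stride, element_strides,
                            lower_corner, upper_corner,
                            channels_per_pixel, pixels_per_column,
                            interleave, swizzle, l2_promotion, oob_fill):
    # Returns C++ host code that encodes one im2col TMA descriptor.
    return f"""
  CUtensorMap {handle};
  cuuint64_t {handle}_globalDim[{rank}] = {{{global_dim}}};
  cuuint64_t {handle}_globalStride[{rank}] = {{{global_stride}}};
  cuuint32_t {handle}_elementStrides[{rank}] = {{{element_strides}}};
  int {handle}_lowerCorner[{rank - 2}] = {{{lower_corner}}};
  int {handle}_upperCorner[{rank - 2}] = {{{upper_corner}}};
  CUresult {handle}_result = cuTensorMapEncodeIm2col(
      &{handle}, (CUtensorMapDataType){dtype}, {rank}, (void*)({global_address}),
      {handle}_globalDim, {handle}_globalStride + 1,  // driver takes tensorRank - 1 stride entries
      {handle}_lowerCorner, {handle}_upperCorner,
      {channels_per_pixel}, {pixels_per_column},
      {handle}_elementStrides,
      (CUtensorMapInterleave){interleave}, (CUtensorMapSwizzle){swizzle},
      (CUtensorMapL2promotion){l2_promotion}, (CUtensorMapFloatOOBfill){oob_fill});
  if ({handle}_result != CUDA_SUCCESS) {{
    return -1;  // the real template reports a descriptive error instead
  }}
"""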

IR and transformation updates for Im2Col:

  • Modified IR mutators and visitors in src/transform/inject_tma_barrier.cc to recognize and handle both tma_load and tma_load_im2col operations, ensuring correct barrier injection and rewriting. [1] [2] [3]

User experience improvement:

  • Added a print statement to the convolution example to notify users when all checks have passed.

Summary by CodeRabbit

  • New Features

    • Added support for IM2COL tensor operations with enhanced TMA descriptor initialization for memory access patterns.
  • Improvements

    • Example scripts now display clear success confirmation messages upon passing validation checks.

@coderabbitai
Contributor

coderabbitai bot commented Oct 21, 2025

Walkthrough

Support for tma_load_im2col() and create_tma_im2col_descriptor() is added across the codebase. Barrier handling logic is extended to treat im2col TMA operations equivalently to standard TMA operations. A new im2col descriptor initialization function constant is introduced, and descriptor argument generation is refactored to branch on im2col type.
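For reference, a minimal sketch (in Python, mirroring the slicing logic quoted in the review comments below) of the two trailing-argument layouts the new branch distinguishes:

# Sketch of the expected trailing-argument counts in generate_tma_descriptor_args().
def expected_trailing_args(tensor_rank: int, is_im2col: bool) -> int:
    if not is_im2col:
        # Tiled path: global_dim, global_stride, box_dim, element_strides (tensor_rank each),
        # plus interleave, swizzle, l2Promotion, oobFill.
        return 4 * tensor_rank + 4
    # Im2col path: global_dim, global_stride, element_strides (tensor_rank each),
    # lower_corner and upper_corner (tensor_rank - 2 each), plus smem_box_pixel,
    # smem_box_channel, interleave, swizzle, l2Promotion, oobFill.
    return 5 * tensor_rank + 2

assert expected_trailing_args(4, is_im2col=False) == 20
assert expected_trailing_args(4, is_im2col=True) == 22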

Changes

Cohort / File(s) and summary:

  • Example verification (examples/convolution/example_convolution.py):
    Added success confirmation message print after validation assertion.
  • TMA barrier handling (src/transform/inject_tma_barrier.cc):
    Extended all-TMA-load checks and barrier logic (TmaExpectTxRewriter, TmaBarrierCollector, TmaBarrierRewriter) to include tma_load_im2col(). Updated 1D TMA load detection to exclude both create_tma_descriptor and create_tma_im2col_descriptor.
  • Descriptor initialization (tilelang/jit/adapter/wrapper.py):
    Introduced TMA_IM2COL_DESC_INIT_FUNC constant for im2col descriptor construction. Refactored generate_tma_descriptor_args() to branch on the is_img2col flag, extracting different parameters and formatting the appropriate descriptor initialization code for each path.

Estimated code review effort

🎯 3 (Moderate) | ⏱️ ~20 minutes

The changes span multiple files with mixed complexity: trivial output addition in examples, high-density logic extensions in barrier handling (requiring careful verification of all conditional branches), and medium-complexity conditional branching in descriptor generation. The heterogeneous nature and interconnected im2col support across subsystems warrant careful cross-file verification.

Possibly related PRs

  • tile-ai/tilelang#761: Both PRs modify inject_tma_barrier.cc's TMA barrier/argument-handling logic, with this PR adding tma_load_im2col() support while that PR adds 1D TMA handling.
  • tile-ai/tilelang#744: Direct overlap in barrier-management refactoring that introduces tma_load_im2col() and create_tma_im2col_descriptor() alongside barrier-related logic updates.

Suggested reviewers

  • LeiWang1999

Poem

🐰✨ A hop through descriptor space so wide,
IM2COL convolutions now tagged with pride,
Barriers bend to the new op's call,
Wrapped and threaded—we support it all! 🎯

Pre-merge checks

❌ Failed checks (1 warning)

  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 0.00%, below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve docstring coverage.

✅ Passed checks (2 passed)

  • Description Check: ✅ Passed. Check skipped because CodeRabbit’s high-level summary is enabled.
  • Title Check: ✅ Passed. The PR title "[Bugfix] Fix missing host cuTensorMapEncodeIm2col call" directly relates to a core component of the changeset. According to the raw summary, the new TMA_IM2COL_DESC_INIT_FUNC constant in wrapper.py includes "a call to cuTensorMapEncodeIm2col with error handling," which is being added to support IM2COL descriptor initialization. The title is specific, clear, and identifies a key fix within the larger feature of adding IM2COL support. While the changeset includes additional updates to IR transformations and a minor user feedback improvement, the title appropriately focuses on the primary fix without needing to enumerate every detail.


@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀


@coderabbitai coderabbitai bot left a comment


Actionable comments posted: 1

🧹 Nitpick comments (4)
src/transform/inject_tma_barrier.cc (2)

165-179: Treat 1D‑load detection consistently across passes (cosmetic).

Here you special‑case 1D only for tma_load(). For consistency with TmaBarrierRewriter, also exclude create_tma_im2col_descriptor() when deciding 1D. No behavior change, just symmetry.

-      bool is_1d_tma_load =
-          arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) &&
-          op->op.same_as(tma_load());
+      bool is_1d_tma_load =
+          arg0 &&
+          !arg0.value()->op.same_as(create_tma_descriptor()) &&
+          !arg0.value()->op.same_as(create_tma_im2col_descriptor()) &&
+          op->op.same_as(tma_load());

454-495: Minor: clarify error text to include im2col.

Message currently says “tma_load must be …” but applies to both tma_load and tma_load_im2col. Suggest rewording.

-      ICHECK(tma_op_to_barrier_id_.count(GetRef<Call>(op)))
-          << "tma_load must be in the tma_op_to_barrier_id_";
+      ICHECK(tma_op_to_barrier_id_.count(GetRef<Call>(op)))
+          << "TMA load op must be in tma_op_to_barrier_id_ (tma_load or tma_load_im2col).";
tilelang/jit/adapter/wrapper.py (2)

109-136: Confirm IM2COL param order and stride semantics.

  • channelsPerPixel vs pixelsPerColumn: you unpack (pixel, channel, …) but pass (channel, pixel). The names suggest this is intentional; please confirm the mapping matches the CUDA driver signature to avoid silent layout bugs (a worked example follows the suggested diff below).
  • globalStride + 1: verify cuTensorMapEncodeIm2col expects rank‑1 stride array like the tiled path.
  • Rank guard: im2col is 4D in our kernels; add a check to enforce tensor_rank >= 4 before formatting.
-            is_img2col = (tma_create_str.value == "__tvm_tensormap_create_im2col")
+            is_img2col = (tma_create_str.value == "__tvm_tensormap_create_im2col")
+            if is_img2col and tensor_rank < 4:
+                raise ValueError(f"Invalid tensor_rank for im2col: {tensor_rank}. Expected >= 4")
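To make the mapping concrete, a worked rank-4 example that mirrors the im2col slicing added in this PR (the values are just list indices, not real arguments):

# Worked example for tensor_rank = 4 (e.g. an NHWC conv input).
tensor_rank = 4
remaining_args = list(range(5 * tensor_rank + 2))          # 22 trailing arguments

global_dim      = remaining_args[:tensor_rank]                             # indices 0..3
global_stride   = remaining_args[tensor_rank:2 * tensor_rank]              # indices 4..7
element_strides = remaining_args[2 * tensor_rank:3 * tensor_rank]          # indices 8..11
lower_corner    = remaining_args[3 * tensor_rank:4 * tensor_rank - 2]      # indices 12..13
upper_corner    = remaining_args[4 * tensor_rank - 2:5 * tensor_rank - 4]  # indices 14..15
(smem_box_pixel, smem_box_channel,
 interleave, swizzle, l2Promotion, oobFill) = remaining_args[5 * tensor_rank - 4:]  # 16..21

# The format call passes smem_box_channel before smem_box_pixel, so smem_box_channel
# maps to channelsPerPixel and smem_box_pixel to pixelsPerColumn, provided the C++
# template's placeholders follow the cuTensorMapEncodeIm2col parameter order.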

471-474: Minor: long exception messages (TRY003).

A few raise ValueError blocks embed long messages inline. Consider shorter messages or constants to satisfy linting, but this is non‑blocking.

Based on static analysis hints.

Also applies to: 483-484, 509-511

📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between bddb125 and 544b21e.

📒 Files selected for processing (3)
  • examples/convolution/example_convolution.py (1 hunks)
  • src/transform/inject_tma_barrier.cc (3 hunks)
  • tilelang/jit/adapter/wrapper.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/inject_tma_barrier.cc (3)
src/transform/warp_specialized_rewriter.cc (18)
  • op (38-43)
  • op (38-38)
  • op (73-83)
  • op (73-73)
  • op (85-93)
  • op (85-85)
  • op (95-100)
  • op (95-95)
  • op (102-110)
  • op (102-102)
  • op (134-146)
  • op (134-134)
  • op (148-177)
  • op (148-148)
  • op (179-189)
  • op (179-179)
  • call (31-36)
  • call (31-31)
src/tl_templates/cuda/copy_sm90.h (7)
  • tma_load (18-27)
  • tma_load (43-60)
  • tma_load (64-82)
  • tma_load (86-104)
  • tma_load (107-126)
  • tma_load (130-150)
  • tma_load_im2col (154-172)
src/transform/lower_hopper_intrin.cc (2)
  • call (102-132)
  • call (102-102)
🪛 Ruff (0.14.1)
tilelang/jit/adapter/wrapper.py

442-442: Avoid specifying long messages outside the exception class

(TRY003)


448-449: Avoid specifying long messages outside the exception class

(TRY003)


471-473: Avoid specifying long messages outside the exception class

(TRY003)


483-484: Avoid specifying long messages outside the exception class

(TRY003)


509-511: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (2)
src/transform/inject_tma_barrier.cc (1)

204-229: Collector update for im2col looks good.

Detection and tracking of tma_load_im2col() align with existing tma_load() handling. No issues spotted.

examples/convolution/example_convolution.py (1)

125-125: Nice UX touch.

Print after successful assertion is helpful and side‑effect free.

Comment on lines +434 to +517
tma_create_str, _, dtype, tensor_rank, globalAddress, *remaining_args = args

is_img2col = (tma_create_str.value == "__tvm_tensormap_create_im2col")
dtype = self._pythonic_expr(dtype)
tensor_rank = int(self._pythonic_expr(tensor_rank))

# Validate tensor_rank
if not isinstance(tensor_rank, int) or tensor_rank <= 0:
raise ValueError(f"Invalid tensor_rank: {tensor_rank}. Must be a positive integer")

# Calculate required length for remaining_args
expected_args_len = 4 * tensor_rank + 4 # 4 groups of tensor_rank size + 4 parameters
if len(remaining_args) < expected_args_len:
raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, "
f"expected {expected_args_len} for tensor_rank {tensor_rank}")

# Extract dimensions and strides using list slicing
global_dim = remaining_args[:tensor_rank]
global_stride = remaining_args[tensor_rank:2 * tensor_rank]
box_dim = remaining_args[2 * tensor_rank:3 * tensor_rank]
element_strides = remaining_args[3 * tensor_rank:4 * tensor_rank]

global_dim = [self._pythonic_expr(i) for i in global_dim]
global_stride = [self._pythonic_expr(i) for i in global_stride]
box_dim = [self._pythonic_expr(i) for i in box_dim]
element_strides = [self._pythonic_expr(i) for i in element_strides]

# Extract remaining parameters
try:
interleave, swizzle, l2Promotion, oobFill = remaining_args[4 * tensor_rank:4 *
tensor_rank + 4]
interleave = self._pythonic_expr(interleave)
swizzle = self._pythonic_expr(swizzle)
l2Promotion = self._pythonic_expr(l2Promotion)
oobFill = self._pythonic_expr(oobFill)
except ValueError as e:
raise ValueError(
"Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)"
) from e
if not is_img2col:
# Calculate required length for remaining_args
expected_args_len = 4 * tensor_rank + 4 # 4 groups of tensor_rank size + 4 parameters
if len(remaining_args) < expected_args_len:
raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, "
f"expected {expected_args_len} for tensor_rank {tensor_rank}")

# Extract dimensions and strides using list slicing
global_dim = remaining_args[:tensor_rank]
global_stride = remaining_args[tensor_rank:2 * tensor_rank]
box_dim = remaining_args[2 * tensor_rank:3 * tensor_rank]
element_strides = remaining_args[3 * tensor_rank:4 * tensor_rank]

global_dim = [self._pythonic_expr(i) for i in global_dim]
global_stride = [self._pythonic_expr(i) for i in global_stride]
box_dim = [self._pythonic_expr(i) for i in box_dim]
element_strides = [self._pythonic_expr(i) for i in element_strides]

# Extract remaining parameters
try:
interleave, swizzle, l2Promotion, oobFill = remaining_args[4 * tensor_rank:4 *
tensor_rank + 4]
interleave = self._pythonic_expr(interleave)
swizzle = self._pythonic_expr(swizzle)
l2Promotion = self._pythonic_expr(l2Promotion)
oobFill = self._pythonic_expr(oobFill)
except ValueError as e:
raise ValueError(
"Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)"
) from e

tma_descripter_init += TMA_DESC_INIT_FUNC.format(
handle_name, dtype, tensor_rank, globalAddress, ",".join(global_dim),
",".join(global_stride), ",".join(box_dim), ",".join(element_strides),
interleave, swizzle, l2Promotion, oobFill)
else:
# Calculate required length for remaining_args
expected_args_len = 5 * tensor_rank + 2
if len(remaining_args) < expected_args_len:
raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, "
f"expected {expected_args_len} for tensor_rank {tensor_rank}")

# Extract dimensions and strides using list slicing
global_dim = remaining_args[:tensor_rank]
global_stride = remaining_args[tensor_rank:2 * tensor_rank]
element_strides = remaining_args[2 * tensor_rank:3 * tensor_rank]
lower_corner = remaining_args[3 * tensor_rank:4 * tensor_rank - 2]
upper_corner = remaining_args[4 * tensor_rank - 2:5 * tensor_rank - 4]
global_dim = [self._pythonic_expr(i) for i in global_dim]
global_stride = [self._pythonic_expr(i) for i in global_stride]
element_strides = [self._pythonic_expr(i) for i in element_strides]
lower_corner = [self._pythonic_expr(i) for i in lower_corner]
upper_corner = [self._pythonic_expr(i) for i in upper_corner]

# Extract remaining parameters
try:
smem_box_pixel, smem_box_channel, interleave, swizzle, l2Promotion, oobFill = remaining_args[
5 * tensor_rank - 4:5 * tensor_rank + 2]
smem_box_pixel = self._pythonic_expr(smem_box_pixel)
smem_box_channel = self._pythonic_expr(smem_box_channel)
interleave = self._pythonic_expr(interleave)
swizzle = self._pythonic_expr(swizzle)
l2Promotion = self._pythonic_expr(l2Promotion)
oobFill = self._pythonic_expr(oobFill)
except ValueError as e:
raise ValueError(
"Failed to unpack the final 6 TMA parameters (smem_box_pixel, smem_box_channel, interleave, swizzle, l2Promotion, oobFill)"
) from e

tma_descripter_init += TMA_IM2COL_DESC_INIT_FUNC.format(
handle_name, dtype, tensor_rank, globalAddress, ",".join(global_dim),
",".join(global_stride), ",".join(element_strides), ",".join(lower_corner),
",".join(upper_corner), smem_box_channel, smem_box_pixel, interleave, swizzle,
l2Promotion, oobFill)
⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

NVRTC path will mis-parse the new layout and lacks im2col support.

You changed the descriptor-arg layout (tma_create_str first) and added im2col handling for the CUDA host path, but TLNVRTCSourceWrapper.generate_tma_descriptor_args still:

  • assumes the old layout (uses args[1:]) and
  • only supports cuTensorMapEncodeTiled, not cuTensorMapEncodeIm2col.

This will break the NVRTC backend and Python wrapper when tma_descriptor_args follow the new layout or when im2col is used.

Apply the following updates:

  • Add a Python IM2COL init template.
  • Parse the new layout (no args[1:]).
  • Branch on create_im2col and call cuTensorMapEncodeIm2col.
*** a/tilelang/jit/adapter/wrapper.py
@@
 TMA_DESC_INIT_FUNC_PY = """
@@
 """
+
+TMA_IM2COL_DESC_INIT_FUNC_PY = """
+\t{0}_type = cuda.bindings.driver.CUtensorMapDataType({1})
+\t{0}_tensorRank = {2}
+\t{0}_globalAddress = {3}.data_ptr()
+\t{0}_globalDim = [{4}]
+\t{0}_globalStride = [{5}][1:]
+\t{0}_elementStrides = [{6}]
+\t{0}_lowerCorner = [{7}]
+\t{0}_upperCorner = [{8}]
+\t{0}_channelsPerPixel = {9}
+\t{0}_pixelsPerColumn = {10}
+\t{0}_interleave = cuda.bindings.driver.CUtensorMapInterleave({11})
+\t{0}_swizzle = cuda.bindings.driver.CUtensorMapSwizzle({12})
+\t{0}_l2Promotion = cuda.bindings.driver.CUtensorMapL2promotion({13})
+\t{0}_oobFill = cuda.bindings.driver.CUtensorMapFloatOOBfill({14})
+
+\tres, {0} = cuda.bindings.driver.cuTensorMapEncodeIm2col(
+\t\t{0}_type,
+\t\t{0}_tensorRank,
+\t\t{0}_globalAddress,
+\t\t{0}_globalDim,
+\t\t{0}_globalStride,
+\t\t{0}_lowerCorner,
+\t\t{0}_upperCorner,
+\t\t{0}_channelsPerPixel,
+\t\t{0}_pixelsPerColumn,
+\t\t{0}_elementStrides,
+\t\t{0}_interleave,
+\t\t{0}_swizzle,
+\t\t{0}_l2Promotion,
+\t\t{0}_oobFill,
+\t)
+\tif res != cuda.bindings.driver.CUresult.CUDA_SUCCESS:
+\t\traise RuntimeError(f"Failed to initialize the TMA im2col descriptor {0}: {res}")
+"""
@@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
     def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str]) -> str:
@@
-            _, dtype, tensor_rank, globalAddress, *remaining_args = args[1:]
+            # New layout: [tma_create_str, <desc or dummy>, dtype, tensor_rank, globalAddress, ...]
+            tma_create_str, _, dtype, tensor_rank, globalAddress, *remaining_args = args
+            is_im2col = (getattr(tma_create_str, "value", tma_create_str) == "__tvm_tensormap_create_im2col")
@@
-            # Calculate required length for remaining_args
-            # 4 groups of tensor_rank size + 4 parameters
-            expected_args_len = 4 * tensor_rank + 4
-            if len(remaining_args) < expected_args_len:
-                raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, "
-                                 f"expected {expected_args_len} for tensor_rank {tensor_rank}")
-
-            # Extract dimensions and strides using list slicing
-            global_dim = remaining_args[:tensor_rank]
-            global_stride = remaining_args[tensor_rank:2 * tensor_rank]
-            box_dim = remaining_args[2 * tensor_rank:3 * tensor_rank]
-            element_strides = remaining_args[3 * tensor_rank:4 * tensor_rank]
-
-            global_dim = [str(i) for i in global_dim]
-            global_stride = [str(i) for i in global_stride]
-            box_dim = [str(i) for i in box_dim]
-            element_strides = [str(i) for i in element_strides]
-
-            # Extract remaining parameters
-            try:
-                interleave, swizzle, l2Promotion, oobFill = remaining_args[4 * tensor_rank:4 *
-                                                                           tensor_rank + 4]
-            except ValueError as e:
-                raise ValueError(
-                    "Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)"
-                ) from e
-
-            tma_descripter_init += TMA_DESC_INIT_FUNC_PY.format(
-                handle_name, dtype, tensor_rank, globalAddress,
-                ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint64_t({x})", global_dim)),
-                ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint64_t({x})", global_stride)),
-                ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint32_t({x})", box_dim)),
-                ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint32_t({x})",
-                              element_strides)), interleave, swizzle, l2Promotion, oobFill)
+            if not is_im2col:
+                expected_args_len = 4 * tensor_rank + 4
+                if len(remaining_args) < expected_args_len:
+                    raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, expected {expected_args_len} for tensor_rank {tensor_rank}")
+                global_dim      = remaining_args[:tensor_rank]
+                global_stride   = remaining_args[tensor_rank:2 * tensor_rank]
+                box_dim         = remaining_args[2 * tensor_rank:3 * tensor_rank]
+                element_strides = remaining_args[3 * tensor_rank:4 * tensor_rank]
+                interleave, swizzle, l2Promotion, oobFill = remaining_args[4 * tensor_rank:4 * tensor_rank + 4]
+                tma_descripter_init += TMA_DESC_INIT_FUNC_PY.format(
+                    handle_name, dtype, tensor_rank, globalAddress,
+                    ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint64_t({x})", map(str, global_dim))),
+                    ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint64_t({x})", map(str, global_stride))),
+                    ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint32_t({x})", map(str, box_dim))),
+                    ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint32_t({x})", map(str, element_strides))),
+                    interleave, swizzle, l2Promotion, oobFill)
+            else:
+                if tensor_rank < 4:
+                    raise ValueError(f"Invalid tensor_rank for im2col: {tensor_rank}. Expected >= 4")
+                expected_args_len = 5 * tensor_rank + 2
+                if len(remaining_args) < expected_args_len:
+                    raise ValueError(f"Insufficient remaining args for im2col: got {len(remaining_args)}, expected {expected_args_len} for tensor_rank {tensor_rank}")
+                global_dim       = remaining_args[:tensor_rank]
+                global_stride    = remaining_args[tensor_rank:2 * tensor_rank]
+                element_strides  = remaining_args[2 * tensor_rank:3 * tensor_rank]
+                lower_corner     = remaining_args[3 * tensor_rank:4 * tensor_rank - 2]
+                upper_corner     = remaining_args[4 * tensor_rank - 2:5 * tensor_rank - 4]
+                smem_box_pixel, smem_box_channel, interleave, swizzle, l2Promotion, oobFill = remaining_args[5 * tensor_rank - 4:5 * tensor_rank + 2]
+                tma_descripter_init += TMA_IM2COL_DESC_INIT_FUNC_PY.format(
+                    handle_name, dtype, tensor_rank, globalAddress,
+                    ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint64_t({x})", map(str, global_dim))),
+                    ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint64_t({x})", map(str, global_stride))),
+                    ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint32_t({x})", map(str, element_strides))),
+                    ", ".join(map(str, lower_corner)),
+                    ", ".join(map(str, upper_corner)),
+                    smem_box_channel,  # channelsPerPixel
+                    smem_box_pixel,    # pixelsPerColumn
+                    interleave, swizzle, l2Promotion, oobFill)

Please confirm the NVRTC bindings expose cuTensorMapEncodeIm2col. If not, we’ll need to add it to the Python driver first.
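One quick way to check locally (a sketch; it assumes a cuda-python release that ships the cuda.bindings.driver module referenced by these templates):

# Sketch: confirm the installed cuda-python bindings expose the im2col encoder.
from cuda.bindings import driver  # module layout used by recent cuda-python releases

print(hasattr(driver, "cuTensorMapEncodeIm2col"))  # required for the im2col path
print(hasattr(driver, "cuTensorMapEncodeTiled"))   # already used by the tiled path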


🏁 Script executed:

#!/bin/bash
rg -n "cuTensorMapEncodeIm2col|cuTensorMapEncodeTiled|TLNVRTCSourceWrapper" -C2

Length of output: 4322


NVRTC path will break: lacks im2col support and uses old argument layout.

The NVRTC backend (TLNVRTCSourceWrapper) was not updated alongside the CUDA host changes. Two critical gaps:

  1. Layout mismatch: NVRTC still assumes the old argument order (args[1:]), but the new CUDA path now parses tma_create_str first, without the args[1:] offset.
  2. Missing im2col: Only cuTensorMapEncodeTiled is available in the Python templates; im2col support is absent.

The proposed patch correctly adds TMA_IM2COL_DESC_INIT_FUNC_PY and updates argument parsing for both branches. However, verify that cuda.bindings provides full, 1:1 access to the CUDA host APIs from Python, including cuTensorMapEncodeIm2col. If this binding is unavailable in the installed version, it must be added to the Python driver first before the patch can work.

🤖 Prompt for AI Agents
tilelang/jit/adapter/wrapper.py lines 434-517: NVRTC wrapper is out of sync with
CUDA host changes — it still assumes the old args[1:] layout and lacks im2col
support; update the NVRTC/TLNVRTCSourceWrapper parsing to consume tma_create_str
as the first element (same as the CUDA path) instead of offsetting by 1, add the
im2col handling path using the new TMA_IM2COL_DESC_INIT_FUNC_PY template
(mirroring the non-im2col branch’s slicing/validation logic but with the
im2col-specific slices and final 6 parameters), and ensure the Python
cuda.bindings layer exposes cuTensorMapEncodeIm2col (or add that binding) before
enabling the im2col branch so the template maps 1:1 to the host API.

@LeiWang1999
Member

@codex review


@chatgpt-codex-connector chatgpt-codex-connector bot left a comment


💡 Codex Review

Here are some automated review suggestions for this pull request.


Comment on lines 441 to 519
if not isinstance(tensor_rank, int) or tensor_rank <= 0:
raise ValueError(f"Invalid tensor_rank: {tensor_rank}. Must be a positive integer")

# Calculate required length for remaining_args
expected_args_len = 4 * tensor_rank + 4 # 4 groups of tensor_rank size + 4 parameters
if len(remaining_args) < expected_args_len:
raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, "
f"expected {expected_args_len} for tensor_rank {tensor_rank}")

# Extract dimensions and strides using list slicing
global_dim = remaining_args[:tensor_rank]
global_stride = remaining_args[tensor_rank:2 * tensor_rank]
box_dim = remaining_args[2 * tensor_rank:3 * tensor_rank]
element_strides = remaining_args[3 * tensor_rank:4 * tensor_rank]

global_dim = [self._pythonic_expr(i) for i in global_dim]
global_stride = [self._pythonic_expr(i) for i in global_stride]
box_dim = [self._pythonic_expr(i) for i in box_dim]
element_strides = [self._pythonic_expr(i) for i in element_strides]

# Extract remaining parameters
try:
interleave, swizzle, l2Promotion, oobFill = remaining_args[4 * tensor_rank:4 *
tensor_rank + 4]
interleave = self._pythonic_expr(interleave)
swizzle = self._pythonic_expr(swizzle)
l2Promotion = self._pythonic_expr(l2Promotion)
oobFill = self._pythonic_expr(oobFill)
except ValueError as e:
raise ValueError(
"Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)"
) from e
if not is_img2col:
# Calculate required length for remaining_args
expected_args_len = 4 * tensor_rank + 4 # 4 groups of tensor_rank size + 4 parameters
if len(remaining_args) < expected_args_len:
raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, "
f"expected {expected_args_len} for tensor_rank {tensor_rank}")

# Extract dimensions and strides using list slicing
global_dim = remaining_args[:tensor_rank]
global_stride = remaining_args[tensor_rank:2 * tensor_rank]
box_dim = remaining_args[2 * tensor_rank:3 * tensor_rank]
element_strides = remaining_args[3 * tensor_rank:4 * tensor_rank]

global_dim = [self._pythonic_expr(i) for i in global_dim]
global_stride = [self._pythonic_expr(i) for i in global_stride]
box_dim = [self._pythonic_expr(i) for i in box_dim]
element_strides = [self._pythonic_expr(i) for i in element_strides]

# Extract remaining parameters
try:
interleave, swizzle, l2Promotion, oobFill = remaining_args[4 * tensor_rank:4 *
tensor_rank + 4]
interleave = self._pythonic_expr(interleave)
swizzle = self._pythonic_expr(swizzle)
l2Promotion = self._pythonic_expr(l2Promotion)
oobFill = self._pythonic_expr(oobFill)
except ValueError as e:
raise ValueError(
"Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)"
) from e

tma_descripter_init += TMA_DESC_INIT_FUNC.format(
handle_name, dtype, tensor_rank, globalAddress, ",".join(global_dim),
",".join(global_stride), ",".join(box_dim), ",".join(element_strides),
interleave, swizzle, l2Promotion, oobFill)
else:
# Calculate required length for remaining_args
expected_args_len = 5 * tensor_rank + 2
if len(remaining_args) < expected_args_len:
raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, "
f"expected {expected_args_len} for tensor_rank {tensor_rank}")

# Extract dimensions and strides using list slicing
global_dim = remaining_args[:tensor_rank]
global_stride = remaining_args[tensor_rank:2 * tensor_rank]
element_strides = remaining_args[2 * tensor_rank:3 * tensor_rank]
lower_corner = remaining_args[3 * tensor_rank:4 * tensor_rank - 2]
upper_corner = remaining_args[4 * tensor_rank - 2:5 * tensor_rank - 4]
global_dim = [self._pythonic_expr(i) for i in global_dim]
global_stride = [self._pythonic_expr(i) for i in global_stride]
element_strides = [self._pythonic_expr(i) for i in element_strides]
lower_corner = [self._pythonic_expr(i) for i in lower_corner]
upper_corner = [self._pythonic_expr(i) for i in upper_corner]

# Extract remaining parameters
try:
smem_box_pixel, smem_box_channel, interleave, swizzle, l2Promotion, oobFill = remaining_args[
5 * tensor_rank - 4:5 * tensor_rank + 2]
smem_box_pixel = self._pythonic_expr(smem_box_pixel)
smem_box_channel = self._pythonic_expr(smem_box_channel)
interleave = self._pythonic_expr(interleave)
swizzle = self._pythonic_expr(swizzle)
l2Promotion = self._pythonic_expr(l2Promotion)
oobFill = self._pythonic_expr(oobFill)
except ValueError as e:
raise ValueError(
"Failed to unpack the final 6 TMA parameters (smem_box_pixel, smem_box_channel, interleave, swizzle, l2Promotion, oobFill)"
) from e

tma_descripter_init += TMA_IM2COL_DESC_INIT_FUNC.format(
handle_name, dtype, tensor_rank, globalAddress, ",".join(global_dim),
",".join(global_stride), ",".join(element_strides), ",".join(lower_corner),
",".join(upper_corner), smem_box_channel, smem_box_pixel, interleave, swizzle,
l2Promotion, oobFill)

tma_descripter_init += TMA_DESC_INIT_FUNC.format(handle_name, dtype, tensor_rank,
globalAddress, ",".join(global_dim),
",".join(global_stride),
",".join(box_dim),
",".join(element_strides), interleave,
swizzle, l2Promotion, oobFill)
return tma_descripter_init


P1: NVRTC path still emits tiled descriptor for IM2COL

The new branch in generate_tma_descriptor_args only teaches the CUDA driver wrapper how to parse IM2COL descriptors and dispatch cuTensorMapEncodeIm2col, but the NVRTC backend (TLNVRTCSourceWrapper.generate_tma_descriptor_args around lines 840‑894) was left untouched. That path still assumes the old argument layout and always calls cuTensorMapEncodeTiled, so any kernel compiled via NVRTC that produces a __tvm_tensormap_create_im2col descriptor will either raise the "Insufficient remaining args" ValueError or encode the descriptor with the wrong CUDA API. This means IM2COL kernels will fail when using the NVRTC runtime, which is a common execution mode.


Member


We can support this in the future.

@LeiWang1999 LeiWang1999 merged commit 5cb5c06 into tile-ai:main Oct 21, 2025
7 checks passed