[Bugfix] Fix missing host cuTensorMapEncodeIm2col call (#1094)
Conversation
Walkthrough

Estimated code review effort: 🎯 3 (Moderate) | ⏱️ ~20 minutes. The changes span multiple files with mixed complexity: trivial output addition in examples, high-density logic extensions in barrier handling (requiring careful verification of all conditional branches), and medium-complexity conditional branching in descriptor generation. The heterogeneous nature and interconnected im2col support across subsystems warrant careful cross-file verification.
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run the required checks. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Actionable comments posted: 1
🧹 Nitpick comments (4)
src/transform/inject_tma_barrier.cc (2)
165-179: Treat 1D-load detection consistently across passes (cosmetic). Here you special-case 1D only for tma_load(). For consistency with TmaBarrierRewriter, also exclude create_tma_im2col_descriptor() when deciding 1D. No behavior change, just symmetry.
```diff
-  bool is_1d_tma_load =
-      arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) &&
-      op->op.same_as(tma_load());
+  bool is_1d_tma_load =
+      arg0 &&
+      !arg0.value()->op.same_as(create_tma_descriptor()) &&
+      !arg0.value()->op.same_as(create_tma_im2col_descriptor()) &&
+      op->op.same_as(tma_load());
```
454-495: Minor: clarify error text to include im2col. The message currently says "tma_load must be …" but applies to both tma_load and tma_load_im2col. Suggest rewording.
```diff
-  ICHECK(tma_op_to_barrier_id_.count(GetRef<Call>(op)))
-      << "tma_load must be in the tma_op_to_barrier_id_";
+  ICHECK(tma_op_to_barrier_id_.count(GetRef<Call>(op)))
+      << "TMA load op must be in tma_op_to_barrier_id_ (tma_load or tma_load_im2col).";
```

tilelang/jit/adapter/wrapper.py (2)
109-136: Confirm IM2COL param order and stride semantics.
- channelsPerPixel vs pixelsPerColumn: you unpack (pixel, channel, …) but pass (channel, pixel). The names suggest this is intentional; please confirm mapping matches the CUDA driver signature to avoid silent layout bugs.
- globalStride + 1: verify cuTensorMapEncodeIm2col expects rank‑1 stride array like the tiled path.
- Rank guard: im2col is 4D in our kernels; add a check to enforce tensor_rank >= 4 before formatting.
```diff
-        is_img2col = (tma_create_str.value == "__tvm_tensormap_create_im2col")
+        is_img2col = (tma_create_str.value == "__tvm_tensormap_create_im2col")
+        if is_img2col and tensor_rank < 4:
+            raise ValueError(f"Invalid tensor_rank for im2col: {tensor_rank}. Expected >= 4")
```
471-474: Minor: long exception messages (TRY003). A few raise ValueError blocks embed long messages inline. Consider shorter messages or constants to satisfy linting, but this is non-blocking.
Based on static analysis hints.
Also applies to: 483-484, 509-511
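One common way to satisfy TRY003 is to move the long message into a dedicated exception class so call sites stay short. A minimal sketch (the class name and `check_args` helper are illustrative, not from the codebase):

```python
class InsufficientTMAArgsError(ValueError):
    """Owns the long message so raise sites stay terse (satisfies TRY003)."""

    def __init__(self, got: int, expected: int, tensor_rank: int):
        super().__init__(f"Insufficient remaining args: got {got}, "
                         f"expected {expected} for tensor_rank {tensor_rank}")


def check_args(remaining_args, tensor_rank):
    # Tiled layout: 4 groups of tensor_rank size + 4 trailing parameters.
    expected = 4 * tensor_rank + 4
    if len(remaining_args) < expected:
        raise InsufficientTMAArgsError(len(remaining_args), expected, tensor_rank)
```

Call sites then read `raise InsufficientTMAArgsError(got, expected, rank)` with no inline f-string, which is exactly the shape Ruff's rule asks for.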
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- examples/convolution/example_convolution.py (1 hunks)
- src/transform/inject_tma_barrier.cc (3 hunks)
- tilelang/jit/adapter/wrapper.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
src/transform/inject_tma_barrier.cc (3)
- src/transform/warp_specialized_rewriter.cc (18): op(38-43), op(38-38), op(73-83), op(73-73), op(85-93), op(85-85), op(95-100), op(95-95), op(102-110), op(102-102), op(134-146), op(134-134), op(148-177), op(148-148), op(179-189), op(179-179), call(31-36), call(31-31)
- src/tl_templates/cuda/copy_sm90.h (7): tma_load(18-27), tma_load(43-60), tma_load(64-82), tma_load(86-104), tma_load(107-126), tma_load(130-150), tma_load_im2col(154-172)
- src/transform/lower_hopper_intrin.cc (2): call(102-132), call(102-102)
🪛 Ruff (0.14.1)
tilelang/jit/adapter/wrapper.py
442-442: Avoid specifying long messages outside the exception class
(TRY003)
448-449: Avoid specifying long messages outside the exception class
(TRY003)
471-473: Avoid specifying long messages outside the exception class
(TRY003)
483-484: Avoid specifying long messages outside the exception class
(TRY003)
509-511: Avoid specifying long messages outside the exception class
(TRY003)
🔇 Additional comments (2)
src/transform/inject_tma_barrier.cc (1)
204-229: Collector update for im2col looks good. Detection and tracking of tma_load_im2col() align with existing tma_load() handling. No issues spotted.
examples/convolution/example_convolution.py (1)
125-125: Nice UX touch. Print after successful assertion is helpful and side-effect free.
The quoted hunk from `tilelang/jit/adapter/wrapper.py` (diff-table markup stripped; the updated if/else branching shown):

```python
tma_create_str, _, dtype, tensor_rank, globalAddress, *remaining_args = args

is_img2col = (tma_create_str.value == "__tvm_tensormap_create_im2col")
dtype = self._pythonic_expr(dtype)
tensor_rank = int(self._pythonic_expr(tensor_rank))

# Validate tensor_rank
if not isinstance(tensor_rank, int) or tensor_rank <= 0:
    raise ValueError(f"Invalid tensor_rank: {tensor_rank}. Must be a positive integer")

if not is_img2col:
    # Calculate required length for remaining_args
    expected_args_len = 4 * tensor_rank + 4  # 4 groups of tensor_rank size + 4 parameters
    if len(remaining_args) < expected_args_len:
        raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, "
                         f"expected {expected_args_len} for tensor_rank {tensor_rank}")

    # Extract dimensions and strides using list slicing
    global_dim = remaining_args[:tensor_rank]
    global_stride = remaining_args[tensor_rank:2 * tensor_rank]
    box_dim = remaining_args[2 * tensor_rank:3 * tensor_rank]
    element_strides = remaining_args[3 * tensor_rank:4 * tensor_rank]

    global_dim = [self._pythonic_expr(i) for i in global_dim]
    global_stride = [self._pythonic_expr(i) for i in global_stride]
    box_dim = [self._pythonic_expr(i) for i in box_dim]
    element_strides = [self._pythonic_expr(i) for i in element_strides]

    # Extract remaining parameters
    try:
        interleave, swizzle, l2Promotion, oobFill = remaining_args[4 * tensor_rank:4 *
                                                                   tensor_rank + 4]
        interleave = self._pythonic_expr(interleave)
        swizzle = self._pythonic_expr(swizzle)
        l2Promotion = self._pythonic_expr(l2Promotion)
        oobFill = self._pythonic_expr(oobFill)
    except ValueError as e:
        raise ValueError(
            "Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)"
        ) from e

    tma_descripter_init += TMA_DESC_INIT_FUNC.format(
        handle_name, dtype, tensor_rank, globalAddress, ",".join(global_dim),
        ",".join(global_stride), ",".join(box_dim), ",".join(element_strides),
        interleave, swizzle, l2Promotion, oobFill)
else:
    # Calculate required length for remaining_args
    expected_args_len = 5 * tensor_rank + 2
    if len(remaining_args) < expected_args_len:
        raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, "
                         f"expected {expected_args_len} for tensor_rank {tensor_rank}")

    # Extract dimensions and strides using list slicing
    global_dim = remaining_args[:tensor_rank]
    global_stride = remaining_args[tensor_rank:2 * tensor_rank]
    element_strides = remaining_args[2 * tensor_rank:3 * tensor_rank]
    lower_corner = remaining_args[3 * tensor_rank:4 * tensor_rank - 2]
    upper_corner = remaining_args[4 * tensor_rank - 2:5 * tensor_rank - 4]
    global_dim = [self._pythonic_expr(i) for i in global_dim]
    global_stride = [self._pythonic_expr(i) for i in global_stride]
    element_strides = [self._pythonic_expr(i) for i in element_strides]
    lower_corner = [self._pythonic_expr(i) for i in lower_corner]
    upper_corner = [self._pythonic_expr(i) for i in upper_corner]

    # Extract remaining parameters
    try:
        (smem_box_pixel, smem_box_channel, interleave, swizzle, l2Promotion,
         oobFill) = remaining_args[5 * tensor_rank - 4:5 * tensor_rank + 2]
        smem_box_pixel = self._pythonic_expr(smem_box_pixel)
        smem_box_channel = self._pythonic_expr(smem_box_channel)
        interleave = self._pythonic_expr(interleave)
        swizzle = self._pythonic_expr(swizzle)
        l2Promotion = self._pythonic_expr(l2Promotion)
        oobFill = self._pythonic_expr(oobFill)
    except ValueError as e:
        raise ValueError(
            "Failed to unpack the final 6 TMA parameters (smem_box_pixel, smem_box_channel, interleave, swizzle, l2Promotion, oobFill)"
        ) from e

    tma_descripter_init += TMA_IM2COL_DESC_INIT_FUNC.format(
        handle_name, dtype, tensor_rank, globalAddress, ",".join(global_dim),
        ",".join(global_stride), ",".join(element_strides), ",".join(lower_corner),
        ",".join(upper_corner), smem_box_channel, smem_box_pixel, interleave, swizzle,
        l2Promotion, oobFill)
```
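The tiled and im2col layouts partition `remaining_args` differently (4·rank + 4 vs. 5·rank + 2 values). A self-contained sketch of the im2col slicing arithmetic, with a rank-4 list of plain integers standing in for the real expressions (the function name is illustrative, not from the codebase):

```python
def split_im2col_args(remaining_args, tensor_rank):
    """Partition a flat TMA im2col argument list, mirroring the slicing above.

    Layout: tensor_rank global dims, tensor_rank global strides,
    tensor_rank element strides, (tensor_rank - 2) lower corners,
    (tensor_rank - 2) upper corners, then 6 trailing scalars
    (smem_box_pixel, smem_box_channel, interleave, swizzle, l2Promotion, oobFill).
    """
    expected = 5 * tensor_rank + 2  # 3*rank + 2*(rank - 2) + 6
    if len(remaining_args) < expected:
        raise ValueError(f"expected {expected} args, got {len(remaining_args)}")
    r = tensor_rank
    return {
        "global_dim": remaining_args[:r],
        "global_stride": remaining_args[r:2 * r],
        "element_strides": remaining_args[2 * r:3 * r],
        "lower_corner": remaining_args[3 * r:4 * r - 2],
        "upper_corner": remaining_args[4 * r - 2:5 * r - 4],
        "scalars": remaining_args[5 * r - 4:5 * r + 2],
    }

# Rank-4 example: 5*4 + 2 = 22 values in total.
parts = split_im2col_args(list(range(22)), 4)
```

For rank 4 this yields three groups of 4, two corner groups of 2, and the 6 trailing scalars, which is why the im2col branch checks `5 * tensor_rank + 2` rather than the tiled path's `4 * tensor_rank + 4`.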
🧩 Analysis chain
NVRTC path will mis-parse the new layout and lacks im2col support.
You changed the descriptor-arg layout (tma_create_str first) and added im2col handling for the CUDA host path, but TLNVRTCSourceWrapper.generate_tma_descriptor_args still:
- assumes the old layout (uses args[1:]) and
- only supports cuTensorMapEncodeTiled, not cuTensorMapEncodeIm2col.
This will break the NVRTC backend and Python wrapper when tma_descriptor_args follow the new layout or when im2col is used.
Apply the following updates:
- Add a Python IM2COL init template.
- Parse the new layout (no args[1:]).
- Branch on create_im2col and call cuTensorMapEncodeIm2col.
```diff
*** a/tilelang/jit/adapter/wrapper.py
@@
 TMA_DESC_INIT_FUNC_PY = """
@@
 """
+
+TMA_IM2COL_DESC_INIT_FUNC_PY = """
+\t{0}_type = cuda.bindings.driver.CUtensorMapDataType({1})
+\t{0}_tensorRank = {2}
+\t{0}_globalAddress = {3}.data_ptr()
+\t{0}_globalDim = [{4}]
+\t{0}_globalStride = [{5}][1:]
+\t{0}_elementStrides = [{6}]
+\t{0}_lowerCorner = [{7}]
+\t{0}_upperCorner = [{8}]
+\t{0}_channelsPerPixel = {9}
+\t{0}_pixelsPerColumn = {10}
+\t{0}_interleave = cuda.bindings.driver.CUtensorMapInterleave({11})
+\t{0}_swizzle = cuda.bindings.driver.CUtensorMapSwizzle({12})
+\t{0}_l2Promotion = cuda.bindings.driver.CUtensorMapL2promotion({13})
+\t{0}_oobFill = cuda.bindings.driver.CUtensorMapFloatOOBfill({14})
+
+\tres, {0} = cuda.bindings.driver.cuTensorMapEncodeIm2col(
+\t\t{0}_type,
+\t\t{0}_tensorRank,
+\t\t{0}_globalAddress,
+\t\t{0}_globalDim,
+\t\t{0}_globalStride,
+\t\t{0}_lowerCorner,
+\t\t{0}_upperCorner,
+\t\t{0}_channelsPerPixel,
+\t\t{0}_pixelsPerColumn,
+\t\t{0}_elementStrides,
+\t\t{0}_interleave,
+\t\t{0}_swizzle,
+\t\t{0}_l2Promotion,
+\t\t{0}_oobFill,
+\t)
+\tif res != cuda.bindings.driver.CUresult.CUDA_SUCCESS:
+\t\traise RuntimeError(f"Failed to initialize the TMA im2col descriptor {0}: {res}")
+"""
@@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
     def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str]) -> str:
@@
-        _, dtype, tensor_rank, globalAddress, *remaining_args = args[1:]
+        # New layout: [tma_create_str, <desc or dummy>, dtype, tensor_rank, globalAddress, ...]
+        tma_create_str, _, dtype, tensor_rank, globalAddress, *remaining_args = args
+        is_im2col = (getattr(tma_create_str, "value", tma_create_str) == "__tvm_tensormap_create_im2col")
@@
-        # Calculate required length for remaining_args
-        # 4 groups of tensor_rank size + 4 parameters
-        expected_args_len = 4 * tensor_rank + 4
-        if len(remaining_args) < expected_args_len:
-            raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, "
-                             f"expected {expected_args_len} for tensor_rank {tensor_rank}")
-
-        # Extract dimensions and strides using list slicing
-        global_dim = remaining_args[:tensor_rank]
-        global_stride = remaining_args[tensor_rank:2 * tensor_rank]
-        box_dim = remaining_args[2 * tensor_rank:3 * tensor_rank]
-        element_strides = remaining_args[3 * tensor_rank:4 * tensor_rank]
-
-        global_dim = [str(i) for i in global_dim]
-        global_stride = [str(i) for i in global_stride]
-        box_dim = [str(i) for i in box_dim]
-        element_strides = [str(i) for i in element_strides]
-
-        # Extract remaining parameters
-        try:
-            interleave, swizzle, l2Promotion, oobFill = remaining_args[4 * tensor_rank:4 *
-                                                                       tensor_rank + 4]
-        except ValueError as e:
-            raise ValueError(
-                "Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)"
-            ) from e
-
-        tma_descripter_init += TMA_DESC_INIT_FUNC_PY.format(
-            handle_name, dtype, tensor_rank, globalAddress,
-            ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint64_t({x})", global_dim)),
-            ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint64_t({x})", global_stride)),
-            ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint32_t({x})", box_dim)),
-            ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint32_t({x})",
-                          element_strides)), interleave, swizzle, l2Promotion, oobFill)
+        if not is_im2col:
+            expected_args_len = 4 * tensor_rank + 4
+            if len(remaining_args) < expected_args_len:
+                raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, expected {expected_args_len} for tensor_rank {tensor_rank}")
+            global_dim = remaining_args[:tensor_rank]
+            global_stride = remaining_args[tensor_rank:2 * tensor_rank]
+            box_dim = remaining_args[2 * tensor_rank:3 * tensor_rank]
+            element_strides = remaining_args[3 * tensor_rank:4 * tensor_rank]
+            interleave, swizzle, l2Promotion, oobFill = remaining_args[4 * tensor_rank:4 * tensor_rank + 4]
+            tma_descripter_init += TMA_DESC_INIT_FUNC_PY.format(
+                handle_name, dtype, tensor_rank, globalAddress,
+                ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint64_t({x})", map(str, global_dim))),
+                ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint64_t({x})", map(str, global_stride))),
+                ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint32_t({x})", map(str, box_dim))),
+                ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint32_t({x})", map(str, element_strides))),
+                interleave, swizzle, l2Promotion, oobFill)
+        else:
+            if tensor_rank < 4:
+                raise ValueError(f"Invalid tensor_rank for im2col: {tensor_rank}. Expected >= 4")
+            expected_args_len = 5 * tensor_rank + 2
+            if len(remaining_args) < expected_args_len:
+                raise ValueError(f"Insufficient remaining args for im2col: got {len(remaining_args)}, expected {expected_args_len} for tensor_rank {tensor_rank}")
+            global_dim = remaining_args[:tensor_rank]
+            global_stride = remaining_args[tensor_rank:2 * tensor_rank]
+            element_strides = remaining_args[2 * tensor_rank:3 * tensor_rank]
+            lower_corner = remaining_args[3 * tensor_rank:4 * tensor_rank - 2]
+            upper_corner = remaining_args[4 * tensor_rank - 2:5 * tensor_rank - 4]
+            smem_box_pixel, smem_box_channel, interleave, swizzle, l2Promotion, oobFill = remaining_args[5 * tensor_rank - 4:5 * tensor_rank + 2]
+            tma_descripter_init += TMA_IM2COL_DESC_INIT_FUNC_PY.format(
+                handle_name, dtype, tensor_rank, globalAddress,
+                ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint64_t({x})", map(str, global_dim))),
+                ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint64_t({x})", map(str, global_stride))),
+                ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint32_t({x})", map(str, element_strides))),
+                ", ".join(map(str, lower_corner)),
+                ", ".join(map(str, upper_corner)),
+                smem_box_channel,  # channelsPerPixel
+                smem_box_pixel,  # pixelsPerColumn
+                interleave, swizzle, l2Promotion, oobFill)
```

Please confirm the NVRTC bindings expose cuTensorMapEncodeIm2col. If not, we'll need to add it to the Python driver first.
🏁 Script executed:

```bash
#!/bin/bash
rg -n "cuTensorMapEncodeIm2col|cuTensorMapEncodeTiled|TLNVRTCSourceWrapper" -C2
```

Length of output: 4322
NVRTC path will break: lacks im2col support and uses old argument layout.
The NVRTC backend (TLNVRTCSourceWrapper) was not updated alongside the CUDA host changes. Two critical gaps:
- Layout mismatch: NVRTC still assumes the old argument order (args[1:]), but the new CUDA path now parses tma_create_str first, without the args[1:] offset.
- Missing im2col: only cuTensorMapEncodeTiled is available in the Python templates; im2col support is absent.
The proposed patch correctly adds TMA_IM2COL_DESC_INIT_FUNC_PY and updates argument parsing for both branches. However, verify that cuda.bindings provides full coverage of and 1:1 access to the CUDA host APIs from Python, which should include cuTensorMapEncodeIm2col. If this binding is unavailable in the installed version, it must be added to the Python driver first before the patch can work.
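A defensive way to verify binding coverage before enabling the im2col branch is to probe the installed cuda-python package at import time. A minimal sketch — it assumes the `cuda.bindings.driver` module layout used elsewhere in this discussion and degrades to `False` when the package is absent:

```python
import importlib


def has_im2col_binding() -> bool:
    """Return True if the installed cuda-python bindings expose
    cuTensorMapEncodeIm2col; False if the module or symbol is missing."""
    try:
        driver = importlib.import_module("cuda.bindings.driver")
    except ImportError:
        return False
    return hasattr(driver, "cuTensorMapEncodeIm2col")
```

Such a check lets the wrapper raise a clear "upgrade cuda-python" error instead of an AttributeError deep inside descriptor generation.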
🤖 Prompt for AI Agents
tilelang/jit/adapter/wrapper.py lines 434-517: NVRTC wrapper is out of sync with
CUDA host changes — it still assumes the old args[1:] layout and lacks im2col
support; update the NVRTC/TLNVRTCSourceWrapper parsing to consume tma_create_str
as the first element (same as the CUDA path) instead of offsetting by 1, add the
im2col handling path using the new TMA_IM2COL_DESC_INIT_FUNC_PY template
(mirroring the non-im2col branch’s slicing/validation logic but with the
im2col-specific slices and final 6 parameters), and ensure the Python
cuda.bindings layer exposes cuTensorMapEncodeIm2col (or add that binding) before
enabling the im2col branch so the template maps 1:1 to the host API.
@codex review
💡 Codex Review
Here are some automated review suggestions for this pull request.
ℹ️ About Codex in GitHub
Your team has set up Codex to review pull requests in this repo. Reviews are triggered when you:
- Open a pull request for review
- Mark a draft as ready
- Comment "@codex review"
If Codex has suggestions, it will comment; otherwise it will react with 👍.
Codex can also answer questions or update the PR. Try commenting "@codex address that feedback".
NVRTC path still emits tiled descriptor for IM2COL
The new branch in generate_tma_descriptor_args only teaches the CUDA driver wrapper how to parse IM2COL descriptors and dispatch cuTensorMapEncodeIm2col, but the NVRTC backend (TLNVRTCSourceWrapper.generate_tma_descriptor_args around lines 840‑894) was left untouched. That path still assumes the old argument layout and always calls cuTensorMapEncodeTiled, so any kernel compiled via NVRTC that produces a __tvm_tensormap_create_im2col descriptor will either raise the "Insufficient remaining args" ValueError or encode the descriptor with the wrong CUDA API. This means IM2COL kernels will fail when using the NVRTC runtime, which is a common execution mode.
We can support this in the future.
This pull request introduces support for TMA Im2Col descriptor initialization and ensures that all relevant code paths and IR transformations handle the new `tma_load_im2col` operation. Additionally, it improves user feedback in the convolution example.

Support for TMA Im2Col descriptors:
- Added the `TMA_IM2COL_DESC_INIT_FUNC` string template for initializing TMA Im2Col descriptors in `tilelang/jit/adapter/wrapper.py`.
- Updated `generate_tma_descriptor_args` to correctly parse and initialize Im2Col descriptors, including validation and extraction of Im2Col-specific parameters. [1] [2]

IR and transformation updates for Im2Col:
- Updated `src/transform/inject_tma_barrier.cc` to recognize and handle both `tma_load` and `tma_load_im2col` operations, ensuring correct barrier injection and rewriting. [1] [2] [3]

User experience improvement:
Summary by CodeRabbit
New Features
Improvements