Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/convolution/example_convolution.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def main(argv=None):
out_c = kernel(a, b)
ref_c = ref_program(S, P, D)(a, b)
torch.testing.assert_close(out_c, ref_c, rtol=1e-2, atol=1e-2)
print("All checks passed.✅")


if __name__ == "__main__":
Expand Down
9 changes: 5 additions & 4 deletions src/transform/inject_tma_barrier.cc
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ class TmaExpectTxRewriter : public IRMutatorWithAnalyzer {
}

PrimExpr VisitExpr_(const CallNode *op) {
if (op->op.same_as(tma_load())) {
if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) {
auto arg0 = op->args[0].as<Call>();
bool is_1d_tma_load =
arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) &&
Expand Down Expand Up @@ -203,7 +203,7 @@ class TmaBarrierCollector : public IRVisitorWithAnalyzer {

void VisitStmt_(const EvaluateNode *op) final {
if (const auto *call = op->value.as<CallNode>()) {
if (call->op.same_as(tma_load())) {
if (call->op.same_as(tma_load()) || call->op.same_as(tma_load_im2col())) {
pending_tma_ops_.push_back(GetRef<Call>(call));
} else if (call->op.same_as(mbarrier_expect_tx())) {
pending_tma_ops_.push_back(GetRef<Call>(call));
Expand Down Expand Up @@ -451,15 +451,16 @@ class TmaBarrierRewriter : public IRMutatorWithAnalyzer {
}

PrimExpr VisitExpr_(const CallNode *op) {
if (op->op.same_as(tma_load())) {
if (op->op.same_as(tma_load()) || op->op.same_as(tma_load_im2col())) {
// check this must be in the tma_op_to_barrier_id_
ICHECK(tma_op_to_barrier_id_.count(GetRef<Call>(op)))
<< "tma_load must be in the tma_op_to_barrier_id_";
auto barrier_id = tma_op_to_barrier_id_[GetRef<Call>(op)];
auto new_args = op->args;
auto arg0 = op->args[0].as<Call>();
auto is_1d_tma_load =
arg0 && !arg0.value()->op.same_as(create_tma_descriptor());
arg0 && !arg0.value()->op.same_as(create_tma_descriptor()) &&
!arg0.value()->op.same_as(create_tma_im2col_descriptor());
if (is_1d_tma_load) {
new_args.Set(2, barrier_id);
} else {
Expand Down
143 changes: 107 additions & 36 deletions tilelang/jit/adapter/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,35 @@ def call({}):
\t}}
"""

TMA_IM2COL_DESC_INIT_FUNC = """
\tCUtensorMap {0};
\tCUtensorMapDataType {0}_type= (CUtensorMapDataType){1};
\tcuuint32_t {0}_tensorRank= {2};
\tvoid *{0}_globalAddress= {3};
\tcuuint64_t {0}_globalDim[{2}]= {{{4}}};
\tcuuint64_t {0}_globalStride[{2}]= {{{5}}};
\tcuuint32_t {0}_elementStrides[{2}]= {{{6}}};
\tint {0}_lowerCorner[{2} - 2]= {{{7}}};
\tint {0}_upperCorner[{2} - 2]= {{{8}}};
\tcuuint32_t {0}_channelsPerPixel= {9};
\tcuuint32_t {0}_pixelsPerColumn= {10};
\tCUtensorMapInterleave {0}_interleave= (CUtensorMapInterleave){11};
\tCUtensorMapSwizzle {0}_swizzle= (CUtensorMapSwizzle){12};
\tCUtensorMapL2promotion {0}_l2Promotion= (CUtensorMapL2promotion){13};
\tCUtensorMapFloatOOBfill {0}_oobFill= (CUtensorMapFloatOOBfill){14};

\tCUresult {0}_result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeIm2col)(
&{0}, {0}_type, {0}_tensorRank, {0}_globalAddress, {0}_globalDim, {0}_globalStride + 1,
{0}_lowerCorner, {0}_upperCorner, {0}_channelsPerPixel, {0}_pixelsPerColumn, {0}_elementStrides, {0}_interleave, {0}_swizzle, {0}_l2Promotion, {0}_oobFill);

\tif ({0}_result != CUDA_SUCCESS) {{
\t\tstd::stringstream ss;
\t\tss << "Error: Failed to initialize the TMA descriptor {0}";
\t\tsnprintf(error_buf, ERROR_BUF_SIZE, "%s", ss.str().c_str());
\t\treturn -1;
\t}}
"""

TMA_DESC_INIT_FUNC_PY = """
\t{0}_type = cuda.bindings.driver.CUtensorMapDataType({1})
\t{0}_tensorRank = {2}
Expand Down Expand Up @@ -401,50 +430,92 @@ def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str],
if len(args) < 3:
raise ValueError(
f"TMA descriptor args too short: {len(args)} elements, expected at least 3")
_, dtype, tensor_rank, globalAddress, *remaining_args = args[1:]

tma_create_str, _, dtype, tensor_rank, globalAddress, *remaining_args = args

is_img2col = (tma_create_str.value == "__tvm_tensormap_create_im2col")
dtype = self._pythonic_expr(dtype)
tensor_rank = int(self._pythonic_expr(tensor_rank))

# Validate tensor_rank
if not isinstance(tensor_rank, int) or tensor_rank <= 0:
raise ValueError(f"Invalid tensor_rank: {tensor_rank}. Must be a positive integer")

# Calculate required length for remaining_args
expected_args_len = 4 * tensor_rank + 4 # 4 groups of tensor_rank size + 4 parameters
if len(remaining_args) < expected_args_len:
raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, "
f"expected {expected_args_len} for tensor_rank {tensor_rank}")

# Extract dimensions and strides using list slicing
global_dim = remaining_args[:tensor_rank]
global_stride = remaining_args[tensor_rank:2 * tensor_rank]
box_dim = remaining_args[2 * tensor_rank:3 * tensor_rank]
element_strides = remaining_args[3 * tensor_rank:4 * tensor_rank]

global_dim = [self._pythonic_expr(i) for i in global_dim]
global_stride = [self._pythonic_expr(i) for i in global_stride]
box_dim = [self._pythonic_expr(i) for i in box_dim]
element_strides = [self._pythonic_expr(i) for i in element_strides]

# Extract remaining parameters
try:
interleave, swizzle, l2Promotion, oobFill = remaining_args[4 * tensor_rank:4 *
tensor_rank + 4]
interleave = self._pythonic_expr(interleave)
swizzle = self._pythonic_expr(swizzle)
l2Promotion = self._pythonic_expr(l2Promotion)
oobFill = self._pythonic_expr(oobFill)
except ValueError as e:
raise ValueError(
"Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)"
) from e
if not is_img2col:
# Calculate required length for remaining_args
expected_args_len = 4 * tensor_rank + 4 # 4 groups of tensor_rank size + 4 parameters
if len(remaining_args) < expected_args_len:
raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, "
f"expected {expected_args_len} for tensor_rank {tensor_rank}")

# Extract dimensions and strides using list slicing
global_dim = remaining_args[:tensor_rank]
global_stride = remaining_args[tensor_rank:2 * tensor_rank]
box_dim = remaining_args[2 * tensor_rank:3 * tensor_rank]
element_strides = remaining_args[3 * tensor_rank:4 * tensor_rank]

global_dim = [self._pythonic_expr(i) for i in global_dim]
global_stride = [self._pythonic_expr(i) for i in global_stride]
box_dim = [self._pythonic_expr(i) for i in box_dim]
element_strides = [self._pythonic_expr(i) for i in element_strides]

# Extract remaining parameters
try:
interleave, swizzle, l2Promotion, oobFill = remaining_args[4 * tensor_rank:4 *
tensor_rank + 4]
interleave = self._pythonic_expr(interleave)
swizzle = self._pythonic_expr(swizzle)
l2Promotion = self._pythonic_expr(l2Promotion)
oobFill = self._pythonic_expr(oobFill)
except ValueError as e:
raise ValueError(
"Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)"
) from e

tma_descripter_init += TMA_DESC_INIT_FUNC.format(
handle_name, dtype, tensor_rank, globalAddress, ",".join(global_dim),
",".join(global_stride), ",".join(box_dim), ",".join(element_strides),
interleave, swizzle, l2Promotion, oobFill)
else:
# Calculate required length for remaining_args
expected_args_len = 5 * tensor_rank + 2
if len(remaining_args) < expected_args_len:
raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, "
f"expected {expected_args_len} for tensor_rank {tensor_rank}")

# Extract dimensions and strides using list slicing
global_dim = remaining_args[:tensor_rank]
global_stride = remaining_args[tensor_rank:2 * tensor_rank]
element_strides = remaining_args[2 * tensor_rank:3 * tensor_rank]
lower_corner = remaining_args[3 * tensor_rank:4 * tensor_rank - 2]
upper_corner = remaining_args[4 * tensor_rank - 2:5 * tensor_rank - 4]
global_dim = [self._pythonic_expr(i) for i in global_dim]
global_stride = [self._pythonic_expr(i) for i in global_stride]
element_strides = [self._pythonic_expr(i) for i in element_strides]
lower_corner = [self._pythonic_expr(i) for i in lower_corner]
upper_corner = [self._pythonic_expr(i) for i in upper_corner]

# Extract remaining parameters
try:
smem_box_pixel, smem_box_channel, interleave, swizzle, l2Promotion, oobFill = remaining_args[
5 * tensor_rank - 4:5 * tensor_rank + 2]
smem_box_pixel = self._pythonic_expr(smem_box_pixel)
smem_box_channel = self._pythonic_expr(smem_box_channel)
interleave = self._pythonic_expr(interleave)
swizzle = self._pythonic_expr(swizzle)
l2Promotion = self._pythonic_expr(l2Promotion)
oobFill = self._pythonic_expr(oobFill)
except ValueError as e:
raise ValueError(
"Failed to unpack the final 6 TMA parameters (smem_box_pixel, smem_box_channel, interleave, swizzle, l2Promotion, oobFill)"
) from e

tma_descripter_init += TMA_IM2COL_DESC_INIT_FUNC.format(
handle_name, dtype, tensor_rank, globalAddress, ",".join(global_dim),
",".join(global_stride), ",".join(element_strides), ",".join(lower_corner),
",".join(upper_corner), smem_box_channel, smem_box_pixel, interleave, swizzle,
l2Promotion, oobFill)
Comment on lines +434 to +517
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

NVRTC path will mis-parse the new layout and lacks im2col support.

You changed the descriptor-arg layout (tma_create_str first) and added im2col handling for the CUDA host path, but TLNVRTCSourceWrapper.generate_tma_descriptor_args still:

  • assumes the old layout (uses args[1:]) and
  • only supports cuTensorMapEncodeTiled, not cuTensorMapEncodeIm2col.

This will break the NVRTC backend and Python wrapper when tma_descriptor_args follow the new layout or when im2col is used.

Apply the following updates:

  • Add a Python IM2COL init template.
  • Parse the new layout (no args[1:]).
  • Branch on create_im2col and call cuTensorMapEncodeIm2col.
*** a/tilelang/jit/adapter/wrapper.py
@@
 TMA_DESC_INIT_FUNC_PY = """
@@
 """
+
+TMA_IM2COL_DESC_INIT_FUNC_PY = """
+\t{0}_type = cuda.bindings.driver.CUtensorMapDataType({1})
+\t{0}_tensorRank = {2}
+\t{0}_globalAddress = {3}.data_ptr()
+\t{0}_globalDim = [{4}]
+\t{0}_globalStride = [{5}][1:]
+\t{0}_elementStrides = [{6}]
+\t{0}_lowerCorner = [{7}]
+\t{0}_upperCorner = [{8}]
+\t{0}_channelsPerPixel = {9}
+\t{0}_pixelsPerColumn = {10}
+\t{0}_interleave = cuda.bindings.driver.CUtensorMapInterleave({11})
+\t{0}_swizzle = cuda.bindings.driver.CUtensorMapSwizzle({12})
+\t{0}_l2Promotion = cuda.bindings.driver.CUtensorMapL2promotion({13})
+\t{0}_oobFill = cuda.bindings.driver.CUtensorMapFloatOOBfill({14})
+
+\tres, {0} = cuda.bindings.driver.cuTensorMapEncodeIm2col(
+\t\t{0}_type,
+\t\t{0}_tensorRank,
+\t\t{0}_globalAddress,
+\t\t{0}_globalDim,
+\t\t{0}_globalStride,
+\t\t{0}_lowerCorner,
+\t\t{0}_upperCorner,
+\t\t{0}_channelsPerPixel,
+\t\t{0}_pixelsPerColumn,
+\t\t{0}_elementStrides,
+\t\t{0}_interleave,
+\t\t{0}_swizzle,
+\t\t{0}_l2Promotion,
+\t\t{0}_oobFill,
+\t)
+\tif res != cuda.bindings.driver.CUresult.CUDA_SUCCESS:
+\t\traise RuntimeError(f"Failed to initialize the TMA im2col descriptor {0}: {res}")
+"""
@@ class TLNVRTCSourceWrapper(TLCUDASourceWrapper):
     def generate_tma_descriptor_args(self, desc_name_map: Dict[str, str]) -> str:
@@
-            _, dtype, tensor_rank, globalAddress, *remaining_args = args[1:]
+            # New layout: [tma_create_str, <desc or dummy>, dtype, tensor_rank, globalAddress, ...]
+            tma_create_str, _, dtype, tensor_rank, globalAddress, *remaining_args = args
+            is_im2col = (getattr(tma_create_str, "value", tma_create_str) == "__tvm_tensormap_create_im2col")
@@
-            # Calculate required length for remaining_args
-            # 4 groups of tensor_rank size + 4 parameters
-            expected_args_len = 4 * tensor_rank + 4
-            if len(remaining_args) < expected_args_len:
-                raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, "
-                                 f"expected {expected_args_len} for tensor_rank {tensor_rank}")
-
-            # Extract dimensions and strides using list slicing
-            global_dim = remaining_args[:tensor_rank]
-            global_stride = remaining_args[tensor_rank:2 * tensor_rank]
-            box_dim = remaining_args[2 * tensor_rank:3 * tensor_rank]
-            element_strides = remaining_args[3 * tensor_rank:4 * tensor_rank]
-
-            global_dim = [str(i) for i in global_dim]
-            global_stride = [str(i) for i in global_stride]
-            box_dim = [str(i) for i in box_dim]
-            element_strides = [str(i) for i in element_strides]
-
-            # Extract remaining parameters
-            try:
-                interleave, swizzle, l2Promotion, oobFill = remaining_args[4 * tensor_rank:4 *
-                                                                           tensor_rank + 4]
-            except ValueError as e:
-                raise ValueError(
-                    "Failed to unpack the final 4 TMA parameters (interleave, swizzle, l2Promotion, oobFill)"
-                ) from e
-
-            tma_descripter_init += TMA_DESC_INIT_FUNC_PY.format(
-                handle_name, dtype, tensor_rank, globalAddress,
-                ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint64_t({x})", global_dim)),
-                ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint64_t({x})", global_stride)),
-                ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint32_t({x})", box_dim)),
-                ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint32_t({x})",
-                              element_strides)), interleave, swizzle, l2Promotion, oobFill)
+            if not is_im2col:
+                expected_args_len = 4 * tensor_rank + 4
+                if len(remaining_args) < expected_args_len:
+                    raise ValueError(f"Insufficient remaining args: got {len(remaining_args)}, expected {expected_args_len} for tensor_rank {tensor_rank}")
+                global_dim      = remaining_args[:tensor_rank]
+                global_stride   = remaining_args[tensor_rank:2 * tensor_rank]
+                box_dim         = remaining_args[2 * tensor_rank:3 * tensor_rank]
+                element_strides = remaining_args[3 * tensor_rank:4 * tensor_rank]
+                interleave, swizzle, l2Promotion, oobFill = remaining_args[4 * tensor_rank:4 * tensor_rank + 4]
+                tma_descripter_init += TMA_DESC_INIT_FUNC_PY.format(
+                    handle_name, dtype, tensor_rank, globalAddress,
+                    ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint64_t({x})", map(str, global_dim))),
+                    ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint64_t({x})", map(str, global_stride))),
+                    ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint32_t({x})", map(str, box_dim))),
+                    ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint32_t({x})", map(str, element_strides))),
+                    interleave, swizzle, l2Promotion, oobFill)
+            else:
+                if tensor_rank < 4:
+                    raise ValueError(f"Invalid tensor_rank for im2col: {tensor_rank}. Expected >= 4")
+                expected_args_len = 5 * tensor_rank + 2
+                if len(remaining_args) < expected_args_len:
+                    raise ValueError(f"Insufficient remaining args for im2col: got {len(remaining_args)}, expected {expected_args_len} for tensor_rank {tensor_rank}")
+                global_dim       = remaining_args[:tensor_rank]
+                global_stride    = remaining_args[tensor_rank:2 * tensor_rank]
+                element_strides  = remaining_args[2 * tensor_rank:3 * tensor_rank]
+                lower_corner     = remaining_args[3 * tensor_rank:4 * tensor_rank - 2]
+                upper_corner     = remaining_args[4 * tensor_rank - 2:5 * tensor_rank - 4]
+                smem_box_pixel, smem_box_channel, interleave, swizzle, l2Promotion, oobFill = remaining_args[5 * tensor_rank - 4:5 * tensor_rank + 2]
+                tma_descripter_init += TMA_IM2COL_DESC_INIT_FUNC_PY.format(
+                    handle_name, dtype, tensor_rank, globalAddress,
+                    ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint64_t({x})", map(str, global_dim))),
+                    ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint64_t({x})", map(str, global_stride))),
+                    ", ".join(map(lambda x: f"cuda.bindings.driver.cuuint32_t({x})", map(str, element_strides))),
+                    ", ".join(map(str, lower_corner)),
+                    ", ".join(map(str, upper_corner)),
+                    smem_box_channel,  # channelsPerPixel
+                    smem_box_pixel,    # pixelsPerColumn
+                    interleave, swizzle, l2Promotion, oobFill)

Please confirm the NVRTC bindings expose cuTensorMapEncodeIm2col. If not, we’ll need to add it to the Python driver first.


🏁 Script executed:

#!/bin/bash
rg -n "cuTensorMapEncodeIm2col|cuTensorMapEncodeTiled|TLNVRTCSourceWrapper" -C2

Length of output: 4322


NVRTC path will break: lacks im2col support and uses old argument layout.

The NVRTC backend (TLNVRTCSourceWrapper) was not updated alongside the CUDA host changes. Two critical gaps:

  1. Layout mismatch: NVRTC still assumes the old argument order (args[1:]), but the new CUDA path now parses tma_create_str first, without the args[1:] offset.
  2. Missing im2col: Only cuTensorMapEncodeTiled is available in the Python templates; im2col support is absent.

The proposed patch correctly adds TMA_IM2COL_DESC_INIT_FUNC_PY and updates argument parsing for both branches. However, verify that cuda.bindings provides full, 1:1 access to the CUDA host APIs from Python, which should include cuTensorMapEncodeIm2col. If this binding is unavailable in the installed version, it must be added to the Python driver first before the patch can work.

🧰 Tools
🪛 Ruff (0.14.1)

442-442: Avoid specifying long messages outside the exception class

(TRY003)


448-449: Avoid specifying long messages outside the exception class

(TRY003)


471-473: Avoid specifying long messages outside the exception class

(TRY003)


483-484: Avoid specifying long messages outside the exception class

(TRY003)


509-511: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
tilelang/jit/adapter/wrapper.py lines 434-517: NVRTC wrapper is out of sync with
CUDA host changes — it still assumes the old args[1:] layout and lacks im2col
support; update the NVRTC/TLNVRTCSourceWrapper parsing to consume tma_create_str
as the first element (same as the CUDA path) instead of offsetting by 1, add the
im2col handling path using the new TMA_IM2COL_DESC_INIT_FUNC_PY template
(mirroring the non-im2col branch’s slicing/validation logic but with the
im2col-specific slices and final 6 parameters), and ensure the Python
cuda.bindings layer exposes cuTensorMapEncodeIm2col (or add that binding) before
enabling the im2col branch so the template maps 1:1 to the host API.


tma_descripter_init += TMA_DESC_INIT_FUNC.format(handle_name, dtype, tensor_rank,
globalAddress, ",".join(global_dim),
",".join(global_stride),
",".join(box_dim),
",".join(element_strides), interleave,
swizzle, l2Promotion, oobFill)
return tma_descripter_init
Comment on lines 441 to 519

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge NVRTC path still emits tiled descriptor for IM2COL

The new branch in generate_tma_descriptor_args only teaches the CUDA driver wrapper how to parse IM2COL descriptors and dispatch cuTensorMapEncodeIm2col, but the NVRTC backend (TLNVRTCSourceWrapper.generate_tma_descriptor_args around lines 840‑894) was left untouched. That path still assumes the old argument layout and always calls cuTensorMapEncodeTiled, so any kernel compiled via NVRTC that produces a __tvm_tensormap_create_im2col descriptor will either raise the "Insufficient remaining args" ValueError or encode the descriptor with the wrong CUDA API. This means IM2COL kernels will fail when using the NVRTC runtime, which is a common execution mode.

Useful? React with 👍 / 👎.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can support this in the future.


def parse_source_information(self):
Expand Down