Skip to content

Conversation

@kurisu6912
Copy link
Collaborator

@kurisu6912 kurisu6912 commented Oct 24, 2025

  • tilelang frontend v2

Summary by CodeRabbit

  • New Features

    • v2 language surface: AST-based DSL quoting/mutation, Builder/IR generation, prim_func/macro decorators, dtype system, and compilation utilities with disk cache.
    • Typed JIT API and tooling: JITImpl/JITKernel generics, parallel compile, and redesigned autotune wrapper.
  • Bug Fixes

    • Corrected example kernel indexing and improved kernel name resolution.
  • Chores

    • Updated tracked TVM submodule and switched tqdm backend.
  • Tests

    • New/updated tests for dtypes, JIT par-compile, and IRModule-returning fixtures.

@github-actions
Copy link

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Oct 24, 2025

Walkthrough

Adds a new TileLang v2 (AST DSL, Builder, dtypes, utils), re-exports prim_func/macro from v2, introduces JIT/autotuner generics and par-compile, updates tests and an example indexing fix, tweaks Metal kernel name resolution, and updates the tracked TVM submodule pointer.

Changes

Cohort / File(s) Summary
Example fix
examples/gdn/example_chunk_o_bwd.py
Adjust kernel unpacking var name and indexing: (i_k, i_v)(i_k, i_v_1) and use h_shared[i_k, i_v_1] * dh_shared[i_k, i_v_1]. No other control-flow or shape changes.
Top-level re-exports
tilelang/language/__init__.py, tilelang/language/v2/__init__.py
Re-route prim_func/macro exports to v2; re-export PrimFunc and dtypes symbols from v2.
AST-based DSL layer
tilelang/language/v2/ast.py
New AST DSL utilities: span helpers, QuoteVisitor, quote helpers, operator mappings/evaluators, BaseBuilder, DSLMutator, and mutate decorator.
IR builder & macros/prim_func
tilelang/language/v2/builder.py
New Builder runtime with thread-local context, frame classes, unwrap helpers, IRGenerator/PrimFunc/Macro types, Torch-tensor arg binding, and macro/prim_func decorator factories.
Dtype system
tilelang/language/v2/dtypes.py
Add dtype conversion registry, many scalar/vector/FP8 aliases, get_tvm_dtype, and enhanced tvm.DataType behavior and mappings.
Compilation & utils
tilelang/language/v2/utils.py
Disk-backed compile/cache (disk_compile), source normalization, closure capture inspection, get_ast, and get_compiled_object.
JIT typing & par-compile
tilelang/jit/__init__.py, tilelang/jit/kernel.py, tilelang/autotuner/tuner.py
Introduce generics/ParamSpec typing, JITImpl/JITKernel generics, par_compile parallel compilation, and AutoTuneImpl for autotuning and caching.
JIT adapter — kernel naming
tilelang/jit/adapter/torch/metal.py
Use tir.PrimFunc.global_symbol when available for kernel base name; fallback to __name__ otherwise; error messages use resolved base name.
Tests — IRModule construction updates
testing/python/transform/test_tilelang_transform_layout_inference.py, testing/python/transform/test_tilelang_transform_lower_tile_op.py
Replace @ir_module class Before/After with def before()/after() returning tvm.IRModule({'main': main}); update call sites.
Tests — typed loop bound
testing/python/transform/test_tilelang_transform_legalize_safe_memory_access.py
Replace T.dynamic('num_tokens') with T.Var('num_tokens', 'int32') for loop bound declaration.
Tests — new/updated tests
testing/python/language/test_tilelang_language_dtype.py, testing/python/jit/test_tilelang_jit_parcompile.py, testing/python/transform/test_tilelang_transform_multi_version_buffer.py
Add dtype coverage tests, par-compile JIT test; remove some local type annotations; adapt tests to new module patterns.
Misc tooling
tilelang/__init__.py
Switch from tqdm import tqdmfrom tqdm.auto import tqdm.
Third-party submodule
3rdparty/tvm
Tracked submodule pointer updated from 5bf17a39cda9b6 (submodule commit change only).

Sequence Diagram(s)

sequenceDiagram
    participant User as User Code
    participant Decorator as @mutate / @prim_func / @macro
    participant Utils as v2.utils
    participant AST as v2.ast (DSLMutator)
    participant Builder as v2.builder
    participant TVM as TVM IRBuilder

    rect rgb(245,250,255)
    Note over User,Decorator: Decoration & capture
    User->>Decorator: apply decorator (source func)
    Decorator->>Utils: get_ast / inspect_function_capture
    Decorator->>AST: transform AST with DSLMutator
    end

    rect rgb(245,255,245)
    Note over AST,Builder: AST → DSL → IR
    AST->>Builder: emit DSL constructs (binds, control flow)
    Builder->>Builder: manage frames, unwrap exprs, bind args
    Builder->>TVM: generate PrimFunc / IR nodes
    end

    rect rgb(255,250,240)
    Note over Builder,User: Return generated object
    Builder->>Decorator: return IRGenerator / PrimFunc / Macro
    Decorator->>User: symbol replaced with generated object
    end
Loading

Estimated code review effort

🎯 5 (Critical) | ⏱️ ~120 minutes

Potential focal points:

  • DSLMutator correctness, span propagation, and AST→DSL translation (tilelang/language/v2/ast.py)
  • Builder frame lifecycle, thread-local handling, arg binding and Torch tensor interop (tilelang/language/v2/builder.py)
  • Dtype mapping correctness and TVM/C-FFI conversions (tilelang/language/v2/dtypes.py)
  • Autotuner/JIT generics, par_compile and caching semantics (tilelang/jit/*, tilelang/autotuner/tuner.py)
  • Disk compilation, closure capture, and tests relying on new IRModule construction
  • Verify the 3rdparty/tvm submodule update in CI

Possibly related issues

Possibly related PRs

Suggested reviewers

  • LeiWang1999
  • tzj-fxz

Poem

🐇 I hopped through AST and bound each name,

I cached the bytes and called the builder's frame,
I mapped the dtypes and unwrapped each arg,
One index fixed — the kernel found its flag,
A tiny hop, a tidy game.

Pre-merge checks and finishing touches

❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 6.06% which is insufficient. The required threshold is 80.00%. You can run @coderabbitai generate docstrings to improve docstring coverage.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title Check ✅ Passed The title "[Language] Initial version of tilelang frontend v2" directly and clearly describes the primary change of this pull request. The changeset introduces a comprehensive new v2 frontend for TileLang, including multiple new modules (ast.py, builder.py, dtypes.py, utils.py), extensive public API exports, updated infrastructure in the jit and autotuner layers, and corresponding test files. The title accurately captures this main objective by explicitly naming the feature being introduced—"tilelang frontend v2"—with the "Initial version" qualifier appropriately reflecting the scope. The title is specific, concise, and avoids vague terminology.
✨ Finishing touches
  • 📝 Generate docstrings
🧪 Generate unit tests (beta)
  • Create PR with unit tests
  • Post copyable unit tests in a comment

Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 7

🧹 Nitpick comments (12)
examples/gdn/example_chunk_o_bwd.py (1)

259-260: Index unflattening is correct; minor readability nit.

The quotient/remainder split is right for a shape (block_DK, block_DV). Consider divmod for clarity and to avoid two passes. Also confirm the nearby scalar reduction path still produces correct results with current pass configs.

-    for i_kv in T.Parallel(block_DK * block_DV):
-        i_k, i_v_1 = i_kv // block_DV, i_kv % block_DV
-        dg_last_fragment[i_kv] = h_shared[i_k, i_v_1] * dh_shared[i_k, i_v_1]
+    for i_kv in T.Parallel(block_DK * block_DV):
+        i_k, i_v_1 = divmod(i_kv, block_DV)
+        dg_last_fragment[i_kv] = h_shared[i_k, i_v_1] * dh_shared[i_k, i_v_1]
tilelang/language/v2/dtypes.py (4)

63-109: Remove invalid req/rne and rely on eq/ne only.

Python has no reflected equality methods; these won’t be called. They add noise and could confuse readers.

 class dtype:
@@
-    def __req__(self, other: AnyDType):
-        if isinstance(other, str):
-            return str.__eq__(self.name, other)
-        if other in self.__cvt:
-            return str.__eq__(self.name, self.__cvt[other])
-        return NotImplemented
@@
-    def __rne__(self, other: AnyDType):
-        if isinstance(other, str):
-            return str.__ne__(self.name, other)
-        if other in self.__cvt:
-            return str.__ne__(self.name, self.__cvt[other])
-        return NotImplemented

117-123: Ensure get_tvm_dtype returns tvm.DataType as annotated.

When value is an ir.Type, you currently return it unchanged, violating the signature.

 def get_tvm_dtype(value: AnyDType) -> tvm.DataType:
-    if isinstance(value, (tvm.DataType, ir.Type)):
-        return value
+    if isinstance(value, tvm.DataType):
+        return value
+    if isinstance(value, ir.Type):
+        # Convert PrimType/etc. to tvm.DataType
+        return tvm.DataType(str(value))
     if isinstance(value, dtype):
         return value.get_tvm_dtype()
     return dtype(value).get_tvm_dtype()

10-47: Confirm torch is a hard dependency; if optional, guard import and mappings.

This module imports torch unconditionally and builds _dtype_cvt using torch symbols. If torch isn’t installed, import fails at module import time.

-from tvm import ir
-import torch
+from tvm import ir
+try:
+    import torch  # noqa: F401
+    _HAS_TORCH = True
+except Exception:  # pragma: no cover
+    _HAS_TORCH = False
@@
-_dtype_cvt = [
-    (None, 'handle', ctypes.c_long, 'long'),  # use long to repr void*
-    (bool, 'bool', ctypes.c_bool, 'bool'),
-    (int, 'int32', ctypes.c_int32, 'int'),
-    (float, 'float32', ctypes.c_float, 'float'),
-    (torch.short, 'int16', ctypes.c_int16, 'short'),
-    (torch.int, 'int32', ctypes.c_int32, 'int'),
-    (torch.long, 'int64', ctypes.c_int64, 'long long'),
-    (torch.half, 'float16', None, None),
-    (torch.float, 'float32', ctypes.c_float, 'float'),
-    (torch.double, 'float64', ctypes.c_double, 'double'),
-    ...
-]
+_dtype_cvt = [
+    (None, 'handle', ctypes.c_long, 'long'),  # use long to repr void*
+    (bool, 'bool', ctypes.c_bool, 'bool'),
+    (int, 'int32', ctypes.c_int32, 'int'),
+    (float, 'float32', ctypes.c_float, 'float'),
+]
+if _HAS_TORCH:
+    _dtype_cvt += [
+        (torch.short, 'int16', ctypes.c_int16, 'short'),
+        (torch.int, 'int32', ctypes.c_int32, 'int'),
+        (torch.long, 'int64', ctypes.c_int64, 'long long'),
+        (torch.half, 'float16', None, None),
+        (torch.float, 'float32', ctypes.c_float, 'float'),
+        (torch.double, 'float64', ctypes.c_double, 'double'),
+        # pytype, tvm dtype str, ctypes, cffi
+        (torch.bool, 'bool', ctypes.c_bool, 'bool'),
+        (torch.int8, 'int8', ctypes.c_int8, 'char'),
+        (torch.int16, 'int16', ctypes.c_int16, 'short'),
+        (torch.int32, 'int32', ctypes.c_int32, 'int'),
+        (torch.int64, 'int64', ctypes.c_int64, 'long long'),
+        (torch.uint8, 'uint8', ctypes.c_uint8, 'unsigned char'),
+        (torch.uint16, 'uint16', ctypes.c_uint16, 'unsigned short'),
+        (torch.uint32, 'uint32', ctypes.c_uint32, 'unsigned int'),
+        (torch.uint64, 'uint64', ctypes.c_uint64, 'unsigned long long'),
+        (torch.float16, 'float16', None, None),
+        (torch.float32, 'float32', ctypes.c_float, 'float'),
+        (torch.float64, 'float64', ctypes.c_double, 'double'),
+        (torch.float8_e4m3fn, 'float8_e4m3fn', None, None),
+        (torch.float8_e4m3fnuz, 'float8_e4m3fnuz', None, None),
+        (torch.float8_e5m2, 'float8_e5m2', None, None),
+        (torch.float8_e5m2fnuz, 'float8_e5m2fnuz', None, None),
+        (torch.float8_e8m0fnu, 'float8_e8m0fnu', None, None),
+        (torch.bfloat16, 'bfloat16', None, None),
+    ]

110-112: Harden call error path for missing FFI entry.

Provide an actionable error if tb_ffi lacks the dtype constructor.

-    def __call__(self, expr=None, is_size_var: bool = False) -> tir.Var:
-        return getattr(tb_ffi, self.name.title())(expr, is_size_var)
+    def __call__(self, expr=None, is_size_var: bool = False) -> tir.Var:
+        try:
+            ctor = getattr(tb_ffi, self.name.title())
+        except AttributeError as e:
+            raise NotImplementedError(f"FFI constructor for dtype '{self.name}' not found") from e
+        return ctor(expr, is_size_var)
tilelang/language/v2/__init__.py (1)

1-1: Clean re-export; drop unused noqa and add explicit all.

Keeps lints quiet and clarifies public API.

-from .builder import prim_func, macro  # noqa: F401
+from .builder import prim_func, macro
+__all__ = ["prim_func", "macro"]
tilelang/language/__init__.py (1)

12-13: Remove unused noqa flags; confirm intentional API switch to v2.

Drop the redundant noqa comments. Also, please confirm this public API change (importing prim_func/macro from v2) is intentional and covered by tests, since behavior may differ from the tir version.

-# from .tir import prim_func, macro,  # noqa: F401
-from .v2 import prim_func, macro  # noqa: F401
+# from .tir import prim_func, macro,
+from .v2 import prim_func, macro
tilelang/language/v2/utils.py (1)

24-30: Typo: _remove_leading_ident_remove_leading_indent.

Pure readability; no behavior change.

-def _remove_leading_ident(source: str):
+def _remove_leading_indent(source: str):
@@
-    source = _remove_leading_ident(source)
+    source = _remove_leading_indent(source)
tilelang/language/v2/builder.py (2)

333-341: arg(): missing annotation yields confusing KeyError.

Provide a clear TypeError when a parameter lacks an annotation (since Builder.arg requires it).

Apply:

@@ def arg(self, name, value):
-        else:
-            annot = self.arg_annot[name]
+        else:
+            try:
+                annot = self.arg_annot[name]
+            except KeyError:
+                raise TypeError(f"Missing type annotation for argument '{name}' in prim_func") from None
             if callable(annot):
                 annot = annot()
             return tir.arg(name, annot)

27-29: unwrap_expr(): only normalizes IntImm(int32).

Normalize all integer/boolean immediates to Python scalars for consistent downstream handling (and to avoid relying on the PrimExpr path).

Apply:

-    elif isinstance(expr, tir.IntImm) and expr.dtype == 'int32':
-        expr = expr.value
+    elif isinstance(expr, tir.IntImm):
+        if expr.dtype in ('int8','int16','int32','int64','uint8','uint16','uint32','uint64','bool'):
+            expr = expr.value
tilelang/language/v2/ast.py (2)

15-19: ast_get_span() type annotation doesn’t allow None but returns None.

Update the return type to reflect reality.

Apply:

-def ast_get_span(ast: ast.AST) -> tuple[int, int, int, int]:
+def ast_get_span(ast: ast.AST) -> tuple[int, int, int, int] | None:

176-178: Remove unused noqa.

# noqa: B027 is unnecessary here.

Apply:

-    def eval(self, val: Any):  # noqa: B027
+    def eval(self, val: Any):
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 50e789d and b5f36ad.

📒 Files selected for processing (7)
  • examples/gdn/example_chunk_o_bwd.py (1 hunks)
  • tilelang/language/__init__.py (1 hunks)
  • tilelang/language/v2/__init__.py (1 hunks)
  • tilelang/language/v2/ast.py (1 hunks)
  • tilelang/language/v2/builder.py (1 hunks)
  • tilelang/language/v2/dtypes.py (1 hunks)
  • tilelang/language/v2/utils.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (4)
tilelang/language/v2/__init__.py (1)
tilelang/language/v2/builder.py (4)
  • prim_func (107-111)
  • prim_func (374-383)
  • macro (114-123)
  • macro (361-371)
tilelang/language/v2/builder.py (3)
tilelang/language/kernel.py (1)
  • KernelLaunchFrame (95-226)
tilelang/language/v2/ast.py (24)
  • BaseBuilder (160-230)
  • eval_op (87-114)
  • mutate (482-489)
  • ctx_if (165-166)
  • ctx_then (168-170)
  • ctx_else (172-174)
  • eval (176-177)
  • ctx_for (179-180)
  • ctx_continue (182-183)
  • ctx_break (185-186)
  • ctx_while (188-190)
  • bind (192-193)
  • get_parent_locals (162-163)
  • assign_slice (195-196)
  • aug_assign (198-199)
  • aug_assign_slice (201-202)
  • boolop (204-209)
  • ifexp (211-212)
  • ret (214-215)
  • ctx_with (217-218)
  • assert_expr (220-221)
  • rval (223-224)
  • arg (226-227)
  • override (229-230)
tilelang/language/tir/op.py (1)
  • if_then_else (2907-2937)
tilelang/language/__init__.py (2)
tilelang/language/v2/builder.py (4)
  • prim_func (107-111)
  • prim_func (374-383)
  • macro (114-123)
  • macro (361-371)
tilelang/language/tir/entry.py (2)
  • prim_func (10-60)
  • macro (66-117)
tilelang/language/v2/ast.py (2)
tilelang/language/v2/builder.py (19)
  • ctx_if (148-155)
  • ctx_then (157-164)
  • ctx_else (166-173)
  • eval (175-190)
  • ctx_for (192-202)
  • ctx_continue (204-205)
  • ctx_break (207-208)
  • ctx_while (210-211)
  • bind (213-226)
  • assign_slice (246-250)
  • aug_assign (252-258)
  • aug_assign_slice (260-264)
  • ifexp (278-284)
  • ret (286-305)
  • ctx_with (307-311)
  • assert_expr (313-318)
  • rval (320-331)
  • arg (333-340)
  • override (342-345)
tilelang/language/v2/utils.py (3)
  • get_ast (76-83)
  • get_compiled_object (89-106)
  • inspect_function_capture (56-73)
🪛 Ruff (0.14.1)
tilelang/language/v2/__init__.py

1-1: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/language/v2/utils.py

39-39: Avoid specifying long messages outside the exception class

(TRY003)


91-91: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


92-92: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


103-103: Avoid specifying long messages outside the exception class

(TRY003)


105-105: Use of exec detected

(S102)

tilelang/language/v2/builder.py

37-37: Avoid specifying long messages outside the exception class

(TRY003)


116-118: Avoid specifying long messages outside the exception class

(TRY003)


190-190: Avoid specifying long messages outside the exception class

(TRY003)


198-200: Avoid specifying long messages outside the exception class

(TRY003)


205-205: Avoid specifying long messages outside the exception class

(TRY003)


208-208: Avoid specifying long messages outside the exception class

(TRY003)


210-210: Unused method argument: cond

(ARG002)


211-211: Avoid specifying long messages outside the exception class

(TRY003)


256-256: Prefer TypeError exception for invalid type

(TRY004)


256-256: Avoid specifying long messages outside the exception class

(TRY003)


274-274: Avoid specifying long messages outside the exception class

(TRY003)


324-326: Avoid specifying long messages outside the exception class

(TRY003)


345-345: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/language/__init__.py

12-12: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


13-13: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/language/v2/ast.py

114-114: Avoid specifying long messages outside the exception class

(TRY003)


157-157: Avoid specifying long messages outside the exception class

(TRY003)


176-176: Unused noqa directive (unused: B027)

Remove unused noqa directive

(RUF100)


192-192: Unused method argument: name

(ARG002)


209-209: Avoid specifying long messages outside the exception class

(TRY003)


223-223: Unused method argument: name

(ARG002)


226-226: Unused method argument: name

(ARG002)


276-276: Prefer TypeError exception for invalid type

(TRY004)


276-276: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/language/v2/dtypes.py

71-72: Avoid specifying long messages outside the exception class

(TRY003)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Quick Lint
🔇 Additional comments (1)
tilelang/language/v2/builder.py (1)

313-317: tir.Assert is available and correctly used—no action needed.

The API tir.Assert(condition, message) returning AssertFrame is documented in TVM's ir_builder module, and the code at builder.py:316 correctly invokes it with matching argument types (PrimExpr and str). No version-specific issues with AssertStmt or alternative builder helpers were found in the codebase.

Comment on lines +398 to +406
node = self.generic_visit(node)
all_args = node.args.posonlyargs + node.args.args
if node.args.vararg is not None:
all_args += node.args.vararg
all_args += node.args.kwonlyargs
stmts = []
for arg in all_args:
name = arg.arg
if arg.annotation is not None:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

FunctionDef argument collection: vararg concatenation bug and missing kwarg.

  • all_args += node.args.vararg raises TypeError (vararg is an ast.arg, not list).
  • **kwargs (node.args.kwarg) is ignored.

Apply:

@@     def visit_FunctionDef(self, node: ast.FunctionDef):
-        all_args = node.args.posonlyargs + node.args.args
-        if node.args.vararg is not None:
-            all_args += node.args.vararg
-        all_args += node.args.kwonlyargs
+        all_args = list(node.args.posonlyargs) + list(node.args.args)
+        if node.args.vararg is not None:
+            all_args.append(node.args.vararg)
+        all_args += list(node.args.kwonlyargs)
+        if node.args.kwarg is not None:
+            all_args.append(node.args.kwarg)
🤖 Prompt for AI Agents
In tilelang/language/v2/ast.py around lines 398 to 406, the code incorrectly
attempts to concatenate node.args.vararg (an ast.arg) as if it were a list and
also omits node.args.kwarg; change the vararg handling to append the single
ast.arg (e.g., use all_args.append(node.args.vararg) when not None) and likewise
include the kwarg by appending node.args.kwarg when it is not None so both *args
and **kwargs are treated as individual arg objects in all_args.

Comment on lines +106 to +113
@contextmanager
def prim_func(self, name):
thread_local_storage.builder = self
with self.ir_builder, self.with_frame(tir.prim_func()):
tir.func_name(name)
yield

@contextmanager
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Thread-local builder is not restored after prim_func context.

thread_local_storage.builder is set but never restored, leaking across nested builders/threads.

Apply:

@@
-    def prim_func(self, name):
-        thread_local_storage.builder = self
-        with self.ir_builder, self.with_frame(tir.prim_func()):
-            tir.func_name(name)
-            yield
+    def prim_func(self, name):
+        prev = getattr(thread_local_storage, "builder", None)
+        thread_local_storage.builder = self
+        try:
+            with self.ir_builder, self.with_frame(tir.prim_func()):
+                tir.func_name(name)
+                yield
+        finally:
+            thread_local_storage.builder = prev
🤖 Prompt for AI Agents
In tilelang/language/v2/builder.py around lines 106 to 113, the context manager
prim_func sets thread_local_storage.builder but never restores the previous
value, leaking state across nested contexts/threads; modify prim_func to save
the previous builder into a local variable before setting
thread_local_storage.builder, yield inside a try block, and in a finally block
restore thread_local_storage.builder back to the saved previous value (ensuring
None if no previous value existed) so the thread-local builder is always
restored after the context exits.

Comment on lines 175 to 189
def eval(self, val: Any):
val = unwrap_expr(val)
if val is None:
pass
elif isinstance(val, tir.frame.IRBuilderFrame):
self.enter_frame(val)
elif isinstance(val, PrimExpr):
tir.evaluate(val)
elif isinstance(val, (int, bool)):
self.enter_frame(tir.evaluate(tvm.tir.const(val)))
elif isinstance(val, str):
pass
elif isinstance(val, tvm.tir.stmt.BufferStore):
self.enter_frame(tir.buffer_store(val.buffer, val.value, val.indices, val.predicate))
else:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

PrimExpr evaluate is not emitted (frame not entered).

For PrimExprs you call tir.evaluate(val) but don’t enter the frame, so the statement is dropped. Fix by entering the frame (same as ints/bools).

Apply:

@@ def eval(self, val: Any):
-        elif isinstance(val, PrimExpr):
-            tir.evaluate(val)
+        elif isinstance(val, PrimExpr):
+            self.enter_frame(tir.evaluate(val))
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def eval(self, val: Any):
val = unwrap_expr(val)
if val is None:
pass
elif isinstance(val, tir.frame.IRBuilderFrame):
self.enter_frame(val)
elif isinstance(val, PrimExpr):
tir.evaluate(val)
elif isinstance(val, (int, bool)):
self.enter_frame(tir.evaluate(tvm.tir.const(val)))
elif isinstance(val, str):
pass
elif isinstance(val, tvm.tir.stmt.BufferStore):
self.enter_frame(tir.buffer_store(val.buffer, val.value, val.indices, val.predicate))
else:
def eval(self, val: Any):
val = unwrap_expr(val)
if val is None:
pass
elif isinstance(val, tir.frame.IRBuilderFrame):
self.enter_frame(val)
elif isinstance(val, PrimExpr):
self.enter_frame(tir.evaluate(val))
elif isinstance(val, (int, bool)):
self.enter_frame(tir.evaluate(tvm.tir.const(val)))
elif isinstance(val, str):
pass
elif isinstance(val, tvm.tir.stmt.BufferStore):
self.enter_frame(tir.buffer_store(val.buffer, val.value, val.indices, val.predicate))
else:
🤖 Prompt for AI Agents
In tilelang/language/v2/builder.py around lines 175 to 189, PrimExpr values call
tir.evaluate(val) but the resulting statement frame is never entered so the
evaluation is dropped; change the PrimExpr branch to enter the frame with the
evaluated statement (i.e., call self.enter_frame(tir.evaluate(val))) similar to
how ints/bools are handled, ensuring the evaluated PrimExpr is wrapped via
enter_frame.

Comment on lines 252 to 258
def aug_assign(self, op, target, aug_value):
if isinstance(target, Buffer) and target.scope() == 'local.var':
tir.buffer_store(target, eval_op(op, target, aug_value), 0)
if isinstance(target, Buffer):
raise RuntimeError("Augmented assignment is not supported for Buffer")
else:
return super().aug_assign(op, target, aug_value)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Aug-assignment on local scalar Buffer uses Buffer (not BufferLoad) and falls through to error.

  • You compute eval_op(op, target, aug_value) with target as a Buffer, not its scalar value.
  • After storing, control falls through and raises the “Buffer not supported” error.

Use BufferLoad([0]) and return the computed value to match the Name path semantics:

@@ def aug_assign(self, op, target, aug_value):
-        if isinstance(target, Buffer) and target.scope() == 'local.var':
-            tir.buffer_store(target, eval_op(op, target, aug_value), 0)
-        if isinstance(target, Buffer):
-            raise RuntimeError("Augmented assignment is not supported for Buffer")
-        else:
-            return super().aug_assign(op, target, aug_value)
+        if isinstance(target, Buffer) and target.scope() == 'local.var':
+            cur = tir.BufferLoad(target, [0])
+            new_val = eval_op(op, cur, aug_value)
+            self.enter_frame(tir.buffer_store(target, new_val, 0))
+            return new_val
+        if isinstance(target, Buffer):
+            raise TypeError("Augmented assignment is not supported for Buffer")
+        return super().aug_assign(op, target, aug_value)
🧰 Tools
🪛 Ruff (0.14.1)

256-256: Prefer TypeError exception for invalid type

(TRY004)


256-256: Avoid specifying long messages outside the exception class

(TRY003)

🤖 Prompt for AI Agents
In tilelang/language/v2/builder.py around lines 252-258, the code computes
eval_op using the Buffer object itself and then falls through to raise the
"Buffer not supported" error; change the local-scalar Buffer branch to load its
scalar value (use BufferLoad([0]) for the target), compute the new value via
eval_op(op, BufferLoad([0]), aug_value), call tir.buffer_store(target,
computed_value, 0), and then return computed_value so execution does not fall
through to the error path and semantics match the Name path.

Comment on lines +10 to +22
def disk_compile(source, name):
cache_dir = env.TILELANG_CACHE_DIR
if cache_dir is not None:
import os
save_dir = os.path.join(cache_dir, "py-cache")
os.makedirs(save_dir, exist_ok=True)
hash_sfx = sha256(source.encode('utf-8')).hexdigest()[:8]
path = os.path.join(save_dir, f"{name}.{hash_sfx}.py")
with open(path, 'w') as f:
f.write(source)
linecache.cache[path] = (len(source), None, source.splitlines(), path)
return compile(source, path, "exec")

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Bug: path may be undefined when TILELANG_CACHE_DIR is None.

If no cache dir is set, path is referenced before assignment. Provide a stable pseudo-filename and keep linecache consistent.

 def disk_compile(source, name):
     cache_dir = env.TILELANG_CACHE_DIR
-    if cache_dir is not None:
+    if cache_dir is not None:
         import os
         save_dir = os.path.join(cache_dir, "py-cache")
         os.makedirs(save_dir, exist_ok=True)
         hash_sfx = sha256(source.encode('utf-8')).hexdigest()[:8]
         path = os.path.join(save_dir, f"{name}.{hash_sfx}.py")
-        with open(path, 'w') as f:
+        with open(path, 'w', encoding='utf-8') as f:
             f.write(source)
+    else:
+        # Stable pseudo filename for in-memory compiled sources
+        hash_sfx = sha256(source.encode('utf-8')).hexdigest()[:8]
+        path = f"<tilelang:{name}:{hash_sfx}>"
     linecache.cache[path] = (len(source), None, source.splitlines(), path)
     return compile(source, path, "exec")
🤖 Prompt for AI Agents
In tilelang/language/v2/utils.py around lines 10 to 22, the variable `path` can
be referenced before assignment when TILELANG_CACHE_DIR is None; change the
function to always set a stable pseudo-filename when no cache dir is provided
(e.g. path = f"<{name}.py>") so that both linecache.cache[path] and
compile(source, path, "exec") use the same name; keep the existing behavior of
writing to disk only when cache dir is present, and ensure path is defined
outside the if/else so linecache and compile are consistent.

Comment on lines 89 to 106
def get_compiled_object(source: str | ast.AST,
name: str,
filename: str = None,
globals: dict[str, Any] = None):
if isinstance(source, ast.AST):
assert filename is not None, "filename must be provided when source is an AST"
try:
if isinstance(source, ast.AST):
ast.fix_missing_locations(source)
compiled = compile(source, filename, 'exec')
else:
compiled = disk_compile(source, name)
except Exception as e:
source_str = source if isinstance(source, str) else ast.unparse(source)
raise RuntimeError(f'Failed to compile source for {name}:\n{source_str}') from e
locs = {}
exec(compiled, globals, locs)
return locs[name]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Fix exec globals handling, avoid shadowing built-in, and tighten type hints.

exec requires a dict for globals; passing None raises TypeError. Also avoid naming the parameter globals, and annotate Optional types.

-from typing import Any, Callable, Literal
+from typing import Any, Callable, Literal, Optional
@@
-def get_compiled_object(source: str | ast.AST,
-                        name: str,
-                        filename: str = None,
-                        globals: dict[str, Any] = None):
+def get_compiled_object(source: str | ast.AST,
+                        name: str,
+                        filename: Optional[str] = None,
+                        globals_: Optional[dict[str, Any]] = None):
@@
-    try:
+    try:
         if isinstance(source, ast.AST):
             ast.fix_missing_locations(source)
             compiled = compile(source, filename, 'exec')
         else:
             compiled = disk_compile(source, name)
     except Exception as e:
         source_str = source if isinstance(source, str) else ast.unparse(source)
         raise RuntimeError(f'Failed to compile source for {name}:\n{source_str}') from e
-    locs = {}
-    exec(compiled, globals, locs)
+    locs: dict[str, Any] = {}
+    gbls: dict[str, Any] = {} if globals_ is None else dict(globals_)
+    gbls.setdefault('__builtins__', __builtins__)
+    exec(compiled, gbls, locs)
     return locs[name]
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def get_compiled_object(source: str | ast.AST,
name: str,
filename: str = None,
globals: dict[str, Any] = None):
if isinstance(source, ast.AST):
assert filename is not None, "filename must be provided when source is an AST"
try:
if isinstance(source, ast.AST):
ast.fix_missing_locations(source)
compiled = compile(source, filename, 'exec')
else:
compiled = disk_compile(source, name)
except Exception as e:
source_str = source if isinstance(source, str) else ast.unparse(source)
raise RuntimeError(f'Failed to compile source for {name}:\n{source_str}') from e
locs = {}
exec(compiled, globals, locs)
return locs[name]
def get_compiled_object(source: str | ast.AST,
name: str,
filename: Optional[str] = None,
globals_: Optional[dict[str, Any]] = None):
if isinstance(source, ast.AST):
assert filename is not None, "filename must be provided when source is an AST"
try:
if isinstance(source, ast.AST):
ast.fix_missing_locations(source)
compiled = compile(source, filename, 'exec')
else:
compiled = disk_compile(source, name)
except Exception as e:
source_str = source if isinstance(source, str) else ast.unparse(source)
raise RuntimeError(f'Failed to compile source for {name}:\n{source_str}') from e
locs: dict[str, Any] = {}
gbls: dict[str, Any] = {} if globals_ is None else dict(globals_)
gbls.setdefault('__builtins__', __builtins__)
exec(compiled, gbls, locs)
return locs[name]
🧰 Tools
🪛 Ruff (0.14.1)

91-91: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


92-92: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


103-103: Avoid specifying long messages outside the exception class

(TRY003)


105-105: Use of exec detected

(S102)

🤖 Prompt for AI Agents
In tilelang/language/v2/utils.py around lines 89 to 106, the function should
rename the parameter `globals` to avoid shadowing the built-in, tighten type
hints to use Optional[dict[str, Any]] for the source and globals param, and
ensure exec always receives a dict: accept a parameter like `globals_dict:
Optional[dict[str, Any]] = None`, create a new globals mapping when None (e.g.
globals_dict = {"__builtins__": __builtins__} or globals_dict.copy() with
builtins preserved), and pass that mapping to exec along with a locals dict;
keep the AST filename assertion and error handling as-is but use the new
parameter name when calling exec and returning locs[name].

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 7

♻️ Duplicate comments (4)
tilelang/language/v2/ast.py (1)

438-461: FunctionDef argument collection: vararg concatenation bug and missing kwarg.

This issue was previously flagged. The code attempts to concatenate node.args.vararg (an ast.arg object) to a list using +=, which raises TypeError. Additionally, **kwargs (node.args.kwarg) is completely ignored.

Apply this diff to fix both issues:

     def visit_FunctionDef(self, node: ast.FunctionDef):
         node = self.generic_visit(node)
-        all_args = node.args.posonlyargs + node.args.args
+        all_args = list(node.args.posonlyargs) + list(node.args.args)
         if node.args.vararg is not None:
-            all_args += node.args.vararg
-        all_args += node.args.kwonlyargs
+            all_args.append(node.args.vararg)
+        all_args += list(node.args.kwonlyargs)
+        if node.args.kwarg is not None:
+            all_args.append(node.args.kwarg)
         stmts = []
tilelang/language/v2/builder.py (3)

114-120: Thread-local builder state is not restored.

The thread_local_storage.builder is set but never restored to its previous value, which can leak state across nested builders or threads.

Apply this diff to save and restore the previous builder:

     @contextmanager
     def prim_func(self, name):
+        prev = getattr(thread_local_storage, "builder", None)
         thread_local_storage.builder = self
-        with self.ir_builder, self.with_frame(tir.prim_func()):
-            tir.func_name(name)
-            yield
+        try:
+            with self.ir_builder, self.with_frame(tir.prim_func()):
+                tir.func_name(name)
+                yield
+        finally:
+            thread_local_storage.builder = prev

183-199: PrimExpr evaluation statement is not emitted.

On line 190, tir.evaluate(val) creates an evaluation frame but doesn't enter it with self.enter_frame(). This causes the evaluation to be dropped. Compare with lines 192 where int/bool values correctly use self.enter_frame().

Apply this diff:

         elif isinstance(val, PrimExpr):
-            tir.evaluate(val)
+            self.enter_frame(tir.evaluate(val))

288-294: Augmented assignment on local.var Buffer has multiple issues.

Three problems:

  1. Line 290 computes eval_op(op, target, aug_value) using the Buffer object instead of loading its scalar value
  2. Line 290 doesn't enter the frame for the store operation
  3. Missing return statement causes execution to fall through to line 292, raising an error even for valid local.var cases

Apply this diff:

     def aug_assign(self, op, target, aug_value):
         if isinstance(target, Buffer) and target.scope() == 'local.var':
-            tir.buffer_store(target, eval_op(op, target, aug_value), 0)
+            cur = tir.BufferLoad(target, [0])
+            new_val = eval_op(op, cur, aug_value)
+            self.enter_frame(tir.buffer_store(target, new_val, 0))
+            return new_val
         if isinstance(target, Buffer):
-            raise RuntimeError("Augmented assignment is not supported for Buffer")
+            raise TypeError("Augmented assignment is not supported for Buffer")
-        else:
-            return super().aug_assign(op, target, aug_value)
+        return super().aug_assign(op, target, aug_value)
🧹 Nitpick comments (4)
tilelang/language/v2/ast.py (1)

181-182: Remove unused noqa directive.

The # noqa: B027 directive is not needed as static analysis indicates it's unused.

Apply this diff:

-    def eval(self, val: Any):  # noqa: B027
+    def eval(self, val: Any):
         pass
tilelang/language/__init__.py (2)

12-12: Remove commented-out code.

The commented import should be removed rather than left in the codebase. If this import needs to be preserved for reference, consider documenting the migration in a separate changelog or migration guide.

Apply this diff:

-# from .tir import prim_func, macro,  # noqa: F401

13-13: Remove unnecessary noqa directive.

The # noqa: F401 directive is unnecessary as F401 is not an active linting violation.

Apply this diff:

-from .v2 import prim_func, macro  # noqa: F401
+from .v2 import prim_func, macro
tilelang/language/v2/builder.py (1)

102-102: Use explicit Optional type hint.

PEP 484 prohibits implicit Optional. The parameter should explicitly use Optional[dict[str, Any]].

Apply this diff:

+from typing import Optional
+
-    def __init__(self, arg_annot: dict[str, Any] = None):
+    def __init__(self, arg_annot: Optional[dict[str, Any]] = None):
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b5f36ad and 09d8aec.

📒 Files selected for processing (4)
  • 3rdparty/tvm (1 hunks)
  • tilelang/language/__init__.py (1 hunks)
  • tilelang/language/v2/ast.py (1 hunks)
  • tilelang/language/v2/builder.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
tilelang/language/__init__.py (1)
tilelang/language/v2/builder.py (4)
  • prim_func (115-119)
  • prim_func (441-454)
  • macro (122-131)
  • macro (435-438)
tilelang/language/v2/ast.py (2)
tilelang/language/v2/builder.py (21)
  • boolop (302-312)
  • ctx_if (156-163)
  • ctx_then (165-172)
  • ctx_else (174-181)
  • eval (183-198)
  • ctx_for (200-210)
  • ctx_continue (212-213)
  • ctx_break (215-216)
  • ctx_while (218-219)
  • bind (221-250)
  • unwrap_value (252-259)
  • assign_slice (279-286)
  • aug_assign (288-294)
  • aug_assign_slice (296-300)
  • ifexp (314-320)
  • ret (322-341)
  • ctx_with (343-347)
  • assert_expr (349-354)
  • rval (356-367)
  • arg (369-379)
  • override (381-384)
tilelang/language/v2/utils.py (3)
  • get_ast (76-83)
  • get_compiled_object (89-106)
  • inspect_function_capture (56-73)
tilelang/language/v2/builder.py (4)
tilelang/language/kernel.py (1)
  • KernelLaunchFrame (95-226)
tilelang/language/v2/ast.py (25)
  • BaseBuilder (164-238)
  • eval_op (87-114)
  • mutate (523-530)
  • ctx_if (170-171)
  • ctx_then (173-175)
  • ctx_else (177-179)
  • eval (181-182)
  • ctx_for (184-185)
  • ctx_continue (187-188)
  • ctx_break (190-191)
  • ctx_while (193-195)
  • bind (197-198)
  • get_parent_locals (167-168)
  • unwrap_value (200-201)
  • assign_slice (203-204)
  • aug_assign (206-207)
  • aug_assign_slice (209-210)
  • boolop (212-217)
  • ifexp (219-220)
  • ret (222-223)
  • ctx_with (225-226)
  • assert_expr (228-229)
  • rval (231-232)
  • arg (234-235)
  • override (237-238)
tilelang/language/v2/dtypes.py (3)
  • get_tvm_dtype (113-114)
  • get_tvm_dtype (117-122)
  • dtype (63-114)
tilelang/language/ast/ir.py (7)
  • If (1096-1112)
  • Then (1115-1123)
  • Else (1126-1134)
  • evaluate (1319-1331)
  • buffer_store (1263-1300)
  • alloc_buffer (441-508)
  • target (1682-1713)
🪛 Ruff (0.14.1)
tilelang/language/__init__.py

12-12: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


13-13: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/language/v2/ast.py

114-114: Avoid specifying long messages outside the exception class

(TRY003)


157-157: Avoid specifying long messages outside the exception class

(TRY003)


181-181: Unused noqa directive (unused: B027)

Remove unused noqa directive

(RUF100)


197-197: Unused method argument: name

(ARG002)


197-197: Unused method argument: annot

(ARG002)


203-203: Unused method argument: annot

(ARG002)


217-217: Avoid specifying long messages outside the exception class

(TRY003)


231-231: Unused method argument: name

(ARG002)


234-234: Unused method argument: name

(ARG002)


314-314: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

tilelang/language/v2/builder.py

41-41: Avoid specifying long messages outside the exception class

(TRY003)


102-102: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


124-126: Avoid specifying long messages outside the exception class

(TRY003)


198-198: Avoid specifying long messages outside the exception class

(TRY003)


206-208: Avoid specifying long messages outside the exception class

(TRY003)


213-213: Avoid specifying long messages outside the exception class

(TRY003)


216-216: Avoid specifying long messages outside the exception class

(TRY003)


218-218: Unused method argument: cond

(ARG002)


219-219: Avoid specifying long messages outside the exception class

(TRY003)


292-292: Prefer TypeError exception for invalid type

(TRY004)


292-292: Avoid specifying long messages outside the exception class

(TRY003)


310-310: Avoid specifying long messages outside the exception class

(TRY003)


360-362: Avoid specifying long messages outside the exception class

(TRY003)


379-379: Avoid specifying long messages outside the exception class

(TRY003)


384-384: Avoid specifying long messages outside the exception class

(TRY003)


387-387: Unused function argument: builder

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Quick Lint
🔇 Additional comments (8)
tilelang/language/v2/ast.py (6)

28-49: LGTM!

The QuoteVisitor implementation correctly handles name substitution and pass injection for AST transformation.


51-69: LGTM!

The quote helper functions provide a clean API for AST generation with proper assertions.


71-158: LGTM!

The operator utilities comprehensively handle all Python operators with correct implementations.


251-436: LGTM!

The DSLMutator visitor methods for control flow, assignments, and expressions correctly transform Python AST nodes into DSL builder calls.


463-517: LGTM!

The visitor methods for boolean operations, comparisons, conditionals, returns, context managers, assertions, and name references correctly transform Python constructs into DSL builder API calls.


523-530: LGTM!

The mutate decorator correctly orchestrates AST extraction, transformation, compilation, and source attachment.

3rdparty/tvm (1)

1-1: I need to verify the commit exists in the official TVM repository. Let me try fetching from the remote:

Verify the TVM submodule commit exists in the official repository and document the changes.

The commit 9cda9b611ba9d91a1d42b561767f40aba0afcd78 could not be verified in this environment due to sandbox limitations. Please confirm that this commit exists in the official Apache TVM repository (https://github.com/apache/tvm) and document what specific TVM changes are being pulled in with this update for the v2 TileLang frontend integration.

tilelang/language/v2/builder.py (1)

441-454: The decorator requires fully annotated parameters—verify this is the intended API contract.

Based on a comprehensive search of the codebase, all @T.prim_func decorated functions consistently use full type annotations (e.g., A: T.Tensor(...), B: T.Buffer(...)). No instances of unannotated parameters on decorated functions were found. The issue flagged is technically valid—line 450's ir_gen(builder)(**hints) would fail for missing annotations—but the codebase uniformly adheres to this requirement. This appears to be the intended API design, not a bug. Confirm with the team that annotating all parameters is the required contract for using @T.prim_func.

Comment on lines +15 to +18
def ast_get_span(ast: ast.AST) -> tuple[int, int, int, int]:
if not ast_has_span(ast):
return None
return tuple(getattr(ast, attr) for attr in _span_attrs)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Fix return type annotation to include None.

The function returns None when the AST node lacks span attributes, but the type hint indicates it always returns a tuple.

Apply this diff:

-def ast_get_span(ast: ast.AST) -> tuple[int, int, int, int]:
+def ast_get_span(ast: ast.AST) -> tuple[int, int, int, int] | None:
     if not ast_has_span(ast):
         return None
     return tuple(getattr(ast, attr) for attr in _span_attrs)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def ast_get_span(ast: ast.AST) -> tuple[int, int, int, int]:
if not ast_has_span(ast):
return None
return tuple(getattr(ast, attr) for attr in _span_attrs)
def ast_get_span(ast: ast.AST) -> tuple[int, int, int, int] | None:
if not ast_has_span(ast):
return None
return tuple(getattr(ast, attr) for attr in _span_attrs)
🤖 Prompt for AI Agents
In tilelang/language/v2/ast.py around lines 15 to 18, the function ast_get_span
currently types its return as tuple[int, int, int, int] but returns None when
the AST node has no span; update the type annotation to allow None (e.g.,
Optional[tuple[int, int, int, int]] with a from typing import Optional import,
or tuple[int, int, int, int] | None for Python 3.10+), and ensure any necessary
import is added at the top of the file.

Comment on lines +167 to +168
def get_parent_locals(self):
return inspect.currentframe().f_back.f_back.f_locals
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Add null check for frame inspection.

inspect.currentframe() can return None in some Python implementations or execution contexts, which would cause an AttributeError.

Apply this diff:

 def get_parent_locals(self):
-    return inspect.currentframe().f_back.f_back.f_locals
+    frame = inspect.currentframe()
+    if frame is None or frame.f_back is None or frame.f_back.f_back is None:
+        raise RuntimeError("Unable to access parent frame")
+    return frame.f_back.f_back.f_locals
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def get_parent_locals(self):
return inspect.currentframe().f_back.f_back.f_locals
def get_parent_locals(self):
frame = inspect.currentframe()
if frame is None or frame.f_back is None or frame.f_back.f_back is None:
raise RuntimeError("Unable to access parent frame")
return frame.f_back.f_back.f_locals
🤖 Prompt for AI Agents
In tilelang/language/v2/ast.py around lines 167 to 168, get_parent_locals
currently calls inspect.currentframe().f_back.f_back.f_locals without guarding
for None; modify it to first capture frame = inspect.currentframe(), check that
frame is not None and that frame.f_back and frame.f_back.f_back exist, then
return their f_locals; if any are missing, return an empty dict (or None
consistently) to avoid AttributeError.

Comment on lines +237 to +238
def override(self, name: str):
return globals()[name]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Add error handling for unknown overrides.

Accessing globals()[name] will raise KeyError if the name doesn't exist in the global namespace.

Apply this diff:

 def override(self, name: str):
-    return globals()[name]
+    if name not in globals():
+        raise ValueError(f'Unknown override: {name}')
+    return globals()[name]
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def override(self, name: str):
return globals()[name]
def override(self, name: str):
if name not in globals():
raise ValueError(f'Unknown override: {name}')
return globals()[name]
🤖 Prompt for AI Agents
In tilelang/language/v2/ast.py around lines 237 to 238, the override method
currently returns globals()[name] which raises KeyError for unknown names;
modify it to check for the name in globals() (e.g., using "if name in
globals()") and if present return the value, otherwise raise a clear,
descriptive exception (for example raise NameError(f'Unknown override: {name}')
or ValueError) so callers get a meaningful error instead of an unhandled
KeyError.

Comment on lines +311 to +314
def _emit_assign_target(self,
target: ast.expr,
rval: ast.expr,
annot: ast.expr = None) -> list[ast.AST]:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Fix implicit Optional in type annotation.

The annot parameter defaults to None but lacks Optional in its type annotation, violating PEP 484.

Apply this diff:

     def _emit_assign_target(self,
                             target: ast.expr,
                             rval: ast.expr,
-                            annot: ast.expr = None) -> list[ast.AST]:
+                            annot: ast.expr | None = None) -> list[ast.AST]:
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def _emit_assign_target(self,
target: ast.expr,
rval: ast.expr,
annot: ast.expr = None) -> list[ast.AST]:
def _emit_assign_target(self,
target: ast.expr,
rval: ast.expr,
annot: ast.expr | None = None) -> list[ast.AST]:
🧰 Tools
🪛 Ruff (0.14.1)

314-314: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)

🤖 Prompt for AI Agents
In tilelang/language/v2/ast.py around lines 311 to 314, the parameter `annot`
defaults to None but its annotation lacks Optional, so change the signature to
annotate `annot` as Optional[ast.expr] (or `ast.expr | None` if project uses
Python 3.10+), and add the corresponding `from typing import Optional` import
(or ensure `from __future__ import annotations` and appropriate imports) at the
top of the file so the type hint is valid per PEP 484.

Comment on lines +241 to +244
# ```
if is_var(orig_value) and not is_var(value):
tir.buffer_store(orig_value, value, 0)
return orig_value
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Buffer store statement is not emitted.

Line 243 calls tir.buffer_store() but doesn't enter the frame, so the store operation is dropped. Similar to the issue in eval(), statements must be wrapped with self.enter_frame().

Apply this diff:

         if is_var(orig_value) and not is_var(value):
-            tir.buffer_store(orig_value, value, 0)
+            self.enter_frame(tir.buffer_store(orig_value, value, 0))
             return orig_value
🤖 Prompt for AI Agents
In tilelang/language/v2/builder.py around lines 241 to 244, the call to
tir.buffer_store(orig_value, value, 0) is emitted outside of a frame and thus
dropped; wrap the buffer_store statement inside a self.enter_frame(...) context
(same pattern used in eval()) so the store is actually emitted, then return
orig_value as before.

Comment on lines +283 to +284
if isinstance(lval, Buffer):
tir.buffer_store(lval, value, sl)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Buffer store statement is not emitted.

Line 284 calls tir.buffer_store() but doesn't enter the frame, causing the store to be dropped.

Apply this diff:

         if isinstance(lval, Buffer):
-            tir.buffer_store(lval, value, sl)
+            self.enter_frame(tir.buffer_store(lval, value, sl))
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
if isinstance(lval, Buffer):
tir.buffer_store(lval, value, sl)
if isinstance(lval, Buffer):
self.enter_frame(tir.buffer_store(lval, value, sl))
🤖 Prompt for AI Agents
In tilelang/language/v2/builder.py around lines 283-284, the code calls
tir.buffer_store(lval, value, sl) but does not emit or append the produced
statement into the current IR frame, so the store is dropped; change the call to
emit the statement into the builder (e.g., replace the bare
tir.buffer_store(...) call with self.emit(tir.buffer_store(lval, value, sl)) or
otherwise append/emit the returned node into the current frame) so the buffer
store is actually recorded.

Comment on lines +296 to +300
def aug_assign_slice(self, op, target, sl, aug_value):
if isinstance(target, Buffer):
tir.buffer_store(target, eval_op(op, target[sl], aug_value), sl)
else:
return super().aug_assign_slice(op, target, sl, aug_value)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

Buffer store statement is not emitted.

Line 298 calls tir.buffer_store() but doesn't enter the frame, so the augmented assignment is not emitted.

Apply this diff:

     def aug_assign_slice(self, op, target, sl, aug_value):
         if isinstance(target, Buffer):
-            tir.buffer_store(target, eval_op(op, target[sl], aug_value), sl)
+            self.enter_frame(tir.buffer_store(target, eval_op(op, target[sl], aug_value), sl))
         else:
             return super().aug_assign_slice(op, target, sl, aug_value)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def aug_assign_slice(self, op, target, sl, aug_value):
if isinstance(target, Buffer):
tir.buffer_store(target, eval_op(op, target[sl], aug_value), sl)
else:
return super().aug_assign_slice(op, target, sl, aug_value)
def aug_assign_slice(self, op, target, sl, aug_value):
if isinstance(target, Buffer):
self.enter_frame(tir.buffer_store(target, eval_op(op, target[sl], aug_value), sl))
else:
return super().aug_assign_slice(op, target, sl, aug_value)
🤖 Prompt for AI Agents
In tilelang/language/v2/builder.py around lines 296 to 300, the code calls
tir.buffer_store(...) directly which does not enter the builder's emission frame
and thus the augmented-assignment statement is never emitted; wrap the
buffer_store call with the builder's emit method (e.g., call
self.emit(tir.buffer_store(...)) so the statement is recorded in the current TIR
frame) and leave the else branch unchanged.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

♻️ Duplicate comments (7)
tilelang/language/v2/builder.py (7)

114-120: Thread-local builder is not restored after prim_func context.

This issue was already identified in previous reviews. The thread_local_storage.builder is set but never restored, which will leak state across nested builders or threads.


183-198: PrimExpr evaluate is not emitted (frame not entered).

This critical issue was already identified in previous reviews. Line 190 calls tir.evaluate(val) but doesn't enter the frame, so the evaluation statement is dropped. It should use self.enter_frame(tir.evaluate(val)) like the int/bool branch does on line 192.


248-250: Buffer store statement is not emitted.

This critical issue was already identified in previous reviews. Line 249 calls tir.buffer_store() but doesn't enter the frame, so the store is dropped. Should use self.enter_frame(tir.buffer_store(orig_value, value, 0)).


285-292: Buffer store statement is not emitted.

This critical issue was already identified in previous reviews. Line 290 calls tir.buffer_store() but doesn't enter the frame, so the store is dropped. Should use self.enter_frame(tir.buffer_store(lval, value, sl)).


294-300: Aug-assignment on local scalar Buffer has multiple issues.

This critical issue was already identified in previous reviews. The code:

  1. Uses the Buffer directly in eval_op instead of loading its value
  2. Doesn't use enter_frame to emit the store
  3. Falls through to raise an error even for the local.var case

302-306: Buffer store statement is not emitted.

This critical issue was already identified in previous reviews. Line 304 calls tir.buffer_store() but doesn't enter the frame, so the augmented assignment is not emitted. Should use self.enter_frame(tir.buffer_store(...)).


447-456: prim_func decorator has fragile argument passing.

This issue was already identified in previous reviews. Line 456 ir_gen(builder)(**hints) relies on __annotations__ order and completeness, which is fragile and breaks with unannotated parameters.

🧹 Nitpick comments (2)
tilelang/language/v2/builder.py (2)

102-103: Unused parameter: arg_annot is stored but never used.

The arg_annot parameter is stored in self.arg_annot but is never accessed or used anywhere in the Builder class. Consider removing it if not needed, or document its intended purpose if it's for future use.

Apply this diff if the parameter is not needed:

-    def __init__(self, arg_annot: dict[str, Any] = None):
-        self.arg_annot = arg_annot
+    def __init__(self):
         self.frames: list[AnyFrame] = []
         self.ir_builder = IRBuilder()
         self.name_inside_frame: dict[str, AnyFrame] = {}

Alternatively, if it's intended for future use, add a docstring or comment explaining its purpose.


449-451: In-place mutation of function annotations may cause issues.

The decorator mutates func.__annotations__ directly by calling callable annotations and replacing them with their return values. This could cause problems if:

  1. The decorator is applied multiple times to the same function
  2. Other code relies on the original annotations
  3. The same function object is used in different contexts

Consider creating a copy of the annotations before mutating:

 def prim_func(func: Callable[_P, _T]) -> PrimFunc[_P, _T]:
-    hints = func.__annotations__
+    hints = func.__annotations__.copy()
     for k in hints:
         if callable(hints[k]):
             hints[k] = hints[k]()
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 09d8aec and 4c75e85.

📒 Files selected for processing (1)
  • tilelang/language/v2/builder.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/language/v2/builder.py (4)
tilelang/language/kernel.py (1)
  • KernelLaunchFrame (95-226)
tilelang/language/v2/ast.py (25)
  • BaseBuilder (164-238)
  • eval_op (87-114)
  • mutate (523-530)
  • ctx_if (170-171)
  • ctx_then (173-175)
  • ctx_else (177-179)
  • eval (181-182)
  • ctx_for (184-185)
  • ctx_continue (187-188)
  • ctx_break (190-191)
  • ctx_while (193-195)
  • bind (197-198)
  • get_parent_locals (167-168)
  • unwrap_value (200-201)
  • assign_slice (203-204)
  • aug_assign (206-207)
  • aug_assign_slice (209-210)
  • boolop (212-217)
  • ifexp (219-220)
  • ret (222-223)
  • ctx_with (225-226)
  • assert_expr (228-229)
  • rval (231-232)
  • arg (234-235)
  • override (237-238)
tilelang/language/v2/dtypes.py (3)
  • get_tvm_dtype (113-114)
  • get_tvm_dtype (117-122)
  • dtype (63-114)
tilelang/language/ast/ir.py (9)
  • meta_var (1731-1750)
  • If (1096-1112)
  • Then (1115-1123)
  • Else (1126-1134)
  • evaluate (1319-1331)
  • buffer_store (1263-1300)
  • alloc_buffer (441-508)
  • LetStmt (880-908)
  • target (1682-1713)
🪛 Ruff (0.14.1)
tilelang/language/v2/builder.py

41-41: Avoid specifying long messages outside the exception class

(TRY003)


102-102: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


124-126: Avoid specifying long messages outside the exception class

(TRY003)


198-198: Avoid specifying long messages outside the exception class

(TRY003)


206-208: Avoid specifying long messages outside the exception class

(TRY003)


213-213: Avoid specifying long messages outside the exception class

(TRY003)


216-216: Avoid specifying long messages outside the exception class

(TRY003)


218-218: Unused method argument: cond

(ARG002)


219-219: Avoid specifying long messages outside the exception class

(TRY003)


298-298: Prefer TypeError exception for invalid type

(TRY004)


298-298: Avoid specifying long messages outside the exception class

(TRY003)


316-316: Avoid specifying long messages outside the exception class

(TRY003)


366-368: Avoid specifying long messages outside the exception class

(TRY003)


385-385: Avoid specifying long messages outside the exception class

(TRY003)


390-390: Avoid specifying long messages outside the exception class

(TRY003)


393-393: Unused function argument: builder

(ARG001)

🔇 Additional comments (4)
tilelang/language/v2/builder.py (4)

122-131: LGTM: Macro context properly manages state.

Unlike the prim_func context manager, this correctly saves and restores name_inside_frame state. The validation check for BoolOpFrame usage is also good defensive programming.


136-151: LGTM: Frame management logic is well-structured.

The frame stack management with find_frame_idx, enter_frame, and with_frame provides a clean abstraction for managing the IR builder frame hierarchy.


156-181: LGTM: Clever handling of compile-time vs. runtime conditionals.

The use of _has_if_frame sentinel to distinguish between PrimExpr (runtime) and constant (compile-time) conditionals is elegant. This allows the builder to optimize away constant conditions while properly emitting TIR for runtime conditions.


328-347: LGTM: Excellent validation and user guidance.

The check for returns inside control flow within macros is great defensive programming. The detailed error message with example code will significantly help users understand and fix the issue.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (6)
tilelang/language/v2/builder.py (6)

116-122: Thread-local builder is not restored after prim_func context.

thread_local_storage.builder is set but never restored, leaking state across nested builders or threads. This issue was flagged in a previous review and remains unresolved.

Apply this diff to save and restore the previous builder:

 @contextmanager
 def prim_func(self, name):
+    prev = getattr(thread_local_storage, "builder", None)
     thread_local_storage.builder = self
-    with self.ir_builder, self.with_frame(tir.prim_func()):
-        tir.func_name(name)
-        yield
+    try:
+        with self.ir_builder, self.with_frame(tir.prim_func()):
+            tir.func_name(name)
+            yield
+    finally:
+        thread_local_storage.builder = prev

185-200: PrimExpr evaluate statement is not emitted.

Line 192 calls tir.evaluate(val) but doesn't enter the frame, so the PrimExpr evaluation is dropped. This is inconsistent with line 194, which correctly uses self.enter_frame() for int/bool values. This critical issue was flagged in a previous review and remains unresolved.

Apply this diff:

     elif isinstance(val, PrimExpr):
-        tir.evaluate(val)
+        self.enter_frame(tir.evaluate(val))

250-252: Buffer store statement is not emitted.

Line 251 calls tir.buffer_store() but doesn't enter the frame, so the store operation is dropped. This critical issue was flagged in a previous review and remains unresolved.

Apply this diff:

     if is_var(orig_value) and not is_var(value):
-        tir.buffer_store(orig_value, value, 0)
+        self.enter_frame(tir.buffer_store(orig_value, value, 0))
         return orig_value

287-294: Buffer store statement is not emitted.

Line 292 calls tir.buffer_store() but doesn't enter the frame, causing the store to be dropped. This critical issue was flagged in a previous review and remains unresolved.

Apply this diff:

     if isinstance(lval, Buffer):
-        tir.buffer_store(lval, value, sl)
+        self.enter_frame(tir.buffer_store(lval, value, sl))

296-302: Multiple critical issues in augmented assignment for local.var Buffer.

This code has several problems flagged in previous reviews that remain unresolved:

  1. Line 298: tir.buffer_store() is not wrapped with self.enter_frame(), so the store is dropped
  2. Line 298: eval_op(op, target, aug_value) uses the Buffer object directly instead of loading its scalar value first via tir.BufferLoad(target, [0])
  3. Missing return after line 298, causing execution to fall through to line 299 and incorrectly raise the "not supported" error
  4. Line 300: Should use TypeError instead of RuntimeError for invalid type

Apply this diff:

 def aug_assign(self, op, target, aug_value):
     if isinstance(target, Buffer) and target.scope() == 'local.var':
-        tir.buffer_store(target, eval_op(op, target, aug_value), 0)
+        cur = tir.BufferLoad(target, [0])
+        new_val = eval_op(op, cur, aug_value)
+        self.enter_frame(tir.buffer_store(target, new_val, 0))
+        return new_val
     if isinstance(target, Buffer):
-        raise RuntimeError("Augmented assignment is not supported for Buffer")
-    else:
-        return super().aug_assign(op, target, aug_value)
+        raise TypeError("Augmented assignment is not supported for Buffer")
+    return super().aug_assign(op, target, aug_value)

304-308: Buffer store statement is not emitted.

Line 306 calls tir.buffer_store() but doesn't enter the frame, so the augmented assignment is not emitted. This critical issue was flagged in a previous review and remains unresolved.

Apply this diff:

 def aug_assign_slice(self, op, target, sl, aug_value):
     if isinstance(target, Buffer):
-        tir.buffer_store(target, eval_op(op, target[sl], aug_value), sl)
+        self.enter_frame(tir.buffer_store(target, eval_op(op, target[sl], aug_value), sl))
     else:
         return super().aug_assign_slice(op, target, sl, aug_value)
🧹 Nitpick comments (1)
tilelang/language/v2/builder.py (1)

104-104: Type hint should be explicit Optional.

Static analysis flagged implicit Optional. Update the type hint to be explicit.

Apply this diff:

-    def __init__(self, arg_annot: dict[str, Any] = None):
+    def __init__(self, arg_annot: dict[str, Any] | None = None):
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 4c75e85 and e3815c6.

📒 Files selected for processing (1)
  • tilelang/language/v2/builder.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/language/v2/builder.py (4)
tilelang/language/kernel.py (1)
  • KernelLaunchFrame (95-226)
tilelang/language/v2/ast.py (25)
  • BaseBuilder (164-238)
  • eval_op (87-114)
  • mutate (523-530)
  • ctx_if (170-171)
  • ctx_then (173-175)
  • ctx_else (177-179)
  • eval (181-182)
  • ctx_for (184-185)
  • ctx_continue (187-188)
  • ctx_break (190-191)
  • ctx_while (193-195)
  • bind (197-198)
  • get_parent_locals (167-168)
  • unwrap_value (200-201)
  • assign_slice (203-204)
  • aug_assign (206-207)
  • aug_assign_slice (209-210)
  • boolop (212-217)
  • ifexp (219-220)
  • ret (222-223)
  • ctx_with (225-226)
  • assert_expr (228-229)
  • rval (231-232)
  • arg (234-235)
  • override (237-238)
tilelang/language/v2/dtypes.py (3)
  • get_tvm_dtype (113-114)
  • get_tvm_dtype (117-122)
  • dtype (63-114)
tilelang/language/ast/ir.py (9)
  • meta_var (1731-1750)
  • If (1096-1112)
  • Then (1115-1123)
  • Else (1126-1134)
  • evaluate (1319-1331)
  • buffer_store (1263-1300)
  • alloc_buffer (441-508)
  • LetStmt (880-908)
  • target (1682-1713)
🪛 Ruff (0.14.1)
tilelang/language/v2/builder.py

43-43: Avoid specifying long messages outside the exception class

(TRY003)


104-104: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


126-128: Avoid specifying long messages outside the exception class

(TRY003)


200-200: Avoid specifying long messages outside the exception class

(TRY003)


208-210: Avoid specifying long messages outside the exception class

(TRY003)


215-215: Avoid specifying long messages outside the exception class

(TRY003)


218-218: Avoid specifying long messages outside the exception class

(TRY003)


220-220: Unused method argument: cond

(ARG002)


221-221: Avoid specifying long messages outside the exception class

(TRY003)


300-300: Prefer TypeError exception for invalid type

(TRY004)


300-300: Avoid specifying long messages outside the exception class

(TRY003)


318-318: Avoid specifying long messages outside the exception class

(TRY003)


368-370: Avoid specifying long messages outside the exception class

(TRY003)


387-387: Avoid specifying long messages outside the exception class

(TRY003)


392-392: Avoid specifying long messages outside the exception class

(TRY003)


395-395: Unused function argument: builder

(ARG001)


464-471: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (1)
tilelang/language/v2/builder.py (1)

457-496: Previous issue appears to have been addressed.

A previous review comment flagged that the decorator was passing arguments incorrectly. The current implementation (lines 472-491) now correctly uses inspect.signature to build args and kwargs by parameter name and kind, then calls ir_gen.gen(builder)(*args, **kwargs). This matches the suggested fix from the previous review.

return tir.arg(name, buffer)


torch.Tensor.__tl_arg__ = __torch_tensor_tl_arg__
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Global monkey-patch of torch.Tensor may cause issues.

Assigning __tl_arg__ to torch.Tensor globally modifies the external Torch library's behavior for all code in the process. This could lead to unexpected interactions with other libraries or test isolation issues.

Consider alternative approaches such as:

  • Using a registration system (e.g., a dict mapping types to handler functions)
  • Creating a wrapper class instead of modifying torch.Tensor directly
  • Documenting this side effect prominently if the monkey-patch is intentional

@kurisu6912 kurisu6912 changed the title [Feature] Initial version of tilelang frontend v2 [Language] Initial version of tilelang frontend v2 Oct 27, 2025
@kurisu6912
Copy link
Collaborator Author

Syntax for alloc var:

def foo():
    a: tl.float32 = ...
    b: tl.float32 = 0
    c: tl.float32

macro with return value:

@T.macro
def foo(a, b):
    return a > 0 ? b : 0

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 0

🧹 Nitpick comments (1)
tilelang/jit/adapter/torch/metal.py (1)

50-50: Optional: Consider extracting error message to a constant.

The error message is clearer now that it references func_name. However, Ruff suggests avoiding inline messages for maintainability.

This is a minor style improvement and entirely optional:

+KERNEL_NOT_FOUND_MSG = "no kernel with name"
+
     else:
-        raise AssertionError(f'no kernel with name {func_name}')
+        raise AssertionError(f'{KERNEL_NOT_FOUND_MSG} {func_name}')
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 6f69f02 and 20feef2.

📒 Files selected for processing (1)
  • tilelang/jit/adapter/torch/metal.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
tilelang/jit/adapter/torch/metal.py (2)
tilelang/language/v2/builder.py (1)
  • PrimFunc (410-419)
tilelang/language/ast/ir.py (1)
  • func_name (206-214)
🪛 Ruff (0.14.1)
tilelang/jit/adapter/torch/metal.py

50-50: Avoid specifying long messages outside the exception class

(TRY003)

🔇 Additional comments (1)
tilelang/jit/adapter/torch/metal.py (1)

30-34: I need to examine the actual metal.py file to understand the full context:

Based on my verification of the codebase patterns and the review comment provided, I can now generate the rewritten review:

The review comment conflates distinct concerns and requires code inspection to resolve.

The codebase shows consistent patterns across all kernel adapters (ctypes, nvrtc, cython):

  • PrimFunc handling: ALL adapters use direct attrs["global_symbol"] access without defensive checks, suggesting this is a TVM contract
  • IRModule handling: ALL other adapters simply assign the IRModule directly without accessing a __name__ attribute

The concern about line 33 (func_or_mod.__name__) accessing a missing attribute on tvm.IRModule is potentially valid but needs verification of the actual metal.py implementation. The web search and TVM documentation show no __name__ attribute on IRModule, but the comment claims to address lines 30-34 which may use a different code pattern than other adapters.

Without seeing the actual metal.py file content to confirm the exact code being reviewed, I cannot definitively determine whether:

  1. The code truly accesses IRModule.__name__ (which would be an error), or
  2. The code follows the standard pattern seen in other adapters (direct assignment)
  3. Whether attrs["global_symbol"] requires defensive access

1. isinstance(tl.float32, tvm.DataType) == True
2. Allow `tl.float32` as function annotations
3. Allow `tl.float32` as argument to be passed to `tl.alloc` or other functions
Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 2

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tilelang/jit/kernel.py (1)

471-488: Blocker: export_library uses undefined self.rt_module.

This will raise AttributeError. Use the compiled runtime module on the artifact.

-        # Export the compiled kernel function to a shared library file.
-        self.rt_module.export_library(kernel_file)
+        # Export the compiled kernel function to a shared library file.
+        if not self.artifact or not getattr(self.artifact, "rt_mod", None):
+            raise RuntimeError("No runtime module to export; compile with a host runtime backend.")
+        self.artifact.rt_mod.export_library(kernel_file)
♻️ Duplicate comments (7)
tilelang/language/v2/builder.py (7)

116-120: Thread-local builder is not restored after prim_func context.

thread_local_storage.builder is set but never restored, potentially leaking state across nested builders or threads.

Apply:

-    def prim_func(self, name):
-        thread_local_storage.builder = self
-        with self.ir_builder, self.with_frame(tir.prim_func()):
-            tir.func_name(name)
-            yield
+    def prim_func(self, name):
+        prev = getattr(thread_local_storage, "builder", None)
+        thread_local_storage.builder = self
+        try:
+            with self.ir_builder, self.with_frame(tir.prim_func()):
+                tir.func_name(name)
+                yield
+        finally:
+            thread_local_storage.builder = prev

184-199: PrimExpr and BufferStore statements are not emitted.

Lines 191 and 197 call statement-producing functions but don't enter the frame, causing the statements to be dropped:

         elif isinstance(val, PrimExpr):
-            tir.evaluate(val)
+            self.enter_frame(tir.evaluate(val))
         elif isinstance(val, (int, bool)):
-            tir.evaluate(tvm.tir.const(val))
+            self.enter_frame(tir.evaluate(tvm.tir.const(val)))
         elif isinstance(val, str):
             pass
         elif isinstance(val, tvm.tir.stmt.BufferStore):
-            tir.buffer_store(val.buffer, val.value, val.indices, val.predicate)
+            self.enter_frame(tir.buffer_store(val.buffer, val.value, val.indices, val.predicate))

249-251: Buffer store statement is not emitted.

Line 250 calls tir.buffer_store() but doesn't enter the frame, so the store is dropped:

         if is_var(orig_value) and not is_var(value):
-            tir.buffer_store(orig_value, value, 0)
+            self.enter_frame(tir.buffer_store(orig_value, value, 0))
             return orig_value

287-294: Buffer store statement is not emitted.

Line 292 calls tir.buffer_store() but doesn't enter the frame, causing the store to be dropped:

         if isinstance(lval, Buffer):
-            tir.buffer_store(lval, value, sl)
+            self.enter_frame(tir.buffer_store(lval, value, sl))

296-302: Buffer store statement is not emitted in augmented assignment.

Line 298 calls tir.buffer_store() but doesn't enter the frame, so the augmented assignment is not emitted:

         if is_var(target):
-            tir.buffer_store(target, eval_op(op, target[0], aug_value), 0)
+            self.enter_frame(tir.buffer_store(target, eval_op(op, target[0], aug_value), 0))

304-308: Buffer store statement is not emitted in augmented slice assignment.

Line 306 calls tir.buffer_store() but doesn't enter the frame, so the augmented assignment is not emitted:

         if isinstance(target, Buffer):
-            tir.buffer_store(target, eval_op(op, target[sl], aug_value), sl)
+            self.enter_frame(tir.buffer_store(target, eval_op(op, target[sl], aug_value), sl))

391-397: Global monkey-patch of torch.Tensor may cause issues.

Assigning __tl_arg__ to torch.Tensor globally modifies the external library's behavior for the entire process, potentially affecting other code or tests.

Consider alternatives such as:

  • A registration system (dict mapping types to handlers)
  • A wrapper class instead of modifying torch.Tensor directly
  • Documenting this side effect prominently if intentional
🧹 Nitpick comments (17)
tilelang/jit/kernel.py (2)

51-53: Unify execution_backend typing with runtime support (include "torch").

Type hints exclude "torch" while runtime paths accept it. Extend the Literal(s) to include "torch" to avoid type-check drift.

-        execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
+        execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc", "torch"] = "cython",
...
-        assert execution_backend in [
+        assert execution_backend in [
             "dlpack",
             "ctypes",
             "cython",
             "nvrtc",
             "torch",
         ], f"Invalid execution backend. {execution_backend}"

Also applies to: 97-104


34-41: Tighten callable typing for better IDE/help.

Bind torch_function to the instance’s return type.

-    torch_function: Callable = None
+    torch_function: Callable[..., _T] = None
tilelang/autotuner/tuner.py (5)

68-81: Avoid writing logs to CWD; store under cache dir.

Write autotuner.log in TILELANG_CACHE_DIR/autotuner/logs to prevent repo pollution and parallel test clashes.

-    formatter = logging.Formatter('%(asctime)s %(levelname)s:%(message)s')
-    file_handler = logging.FileHandler('autotuner.log', mode='w')
+    formatter = logging.Formatter('%(asctime)s %(levelname)s:%(message)s')
+    logs_dir = Path(env.TILELANG_CACHE_DIR) / "autotuner" / "logs"
+    logs_dir.mkdir(parents=True, exist_ok=True)
+    file_handler = logging.FileHandler(logs_dir / 'autotuner.log', mode='w')

480-487: Safer CUDA device scoping inside threads.

Use a context manager instead of mutating global device in worker threads.

-        def cuda_device_wrapper(func, device):
-
-            def inner(**config_arg):
-                torch.cuda.set_device(device)
-                return func(**config_arg)
-            return inner
+        def cuda_device_wrapper(func, device):
+            def inner(**config_arg):
+                with torch.cuda.device(device):
+                    return func(**config_arg)
+            return inner

Also applies to: 491-499


612-649: Typo: get_tunner → get_tuner; update call sites.

Minor naming fix improves readability.

-    def get_tunner(self):
+    def get_tuner(self):
@@
-            autotuner = self.get_tunner()
+            autotuner = self.get_tuner()

503-517: Narrow exception handling and keep details at debug.

Current broad catches hide root causes. Catch TimeoutException explicitly (already), and log other exceptions at warning+debug as you do, but prefer Exception grouping with message plus type.

-            try:
-                result = future.result()
-                results_with_configs.append((result, config))
-            except Exception as e:
+            try:
+                result = future.result()
+                results_with_configs.append((result, config))
+            except Exception as e:  # keep broad, but include type in log
                 logger.debug(
-                    f"Compilation failed for config {config} at index {idx} with error: {e}")
+                    f"Compilation failed for config {config} at index {idx} with error: {type(e).__name__}: {e}")
                 continue
@@
-            except Exception:
+            except Exception:
                 logger.warning(
                     f"An error occurred while testing config {config}, checkout autotuner.log for more details"
                 )
                 logger.debug(f"Error: {traceback.format_exc()}")
                 continue

I can also separate CUDA OOM (torch.cuda.OutOfMemoryError) to degrade gracefully.

Also applies to: 528-536


637-651: Cache key can fail with unhashable args (e.g., tensors).

Consider hashing shapes/dtypes for tensors and converting mappings to tuples.

-        key_args_tuple = args
-        key_kwargs_tuple = tuple(sorted(kwargs.items()))
+        def _keyify(x):
+            import torch, numpy as np
+            if isinstance(x, dict):
+                return tuple(sorted((k, _keyify(v)) for k, v in x.items()))
+            if isinstance(x, (list, tuple)):
+                return tuple(_keyify(v) for v in x)
+            if hasattr(torch, "Tensor") and isinstance(x, torch.Tensor):
+                return ("tensor", tuple(x.shape), str(x.dtype), str(x.device))
+            try:
+                hash(x)
+                return x
+            except Exception:
+                return repr(x)
+        key_args_tuple = _keyify(args)
+        key_kwargs_tuple = _keyify(tuple(sorted(kwargs.items())))
tilelang/jit/__init__.py (5)

45-52: Align execution_backend Literals across API (add "torch", include "nvrtc" where missing).

Ensure types reflect supported backends end-to-end.

 def compile(
-    func: PrimFunc[_KP, _T] = None,
+    func: PrimFunc[_KP, _T] = None,
     out_idx: list[int] | int | None = None,
-    execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
+    execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc", "torch"] = "cython",
@@
-def par_compile(funcs: Iterable[PrimFunc[_KP, _T]],
+def par_compile(funcs: Iterable[PrimFunc[_KP, _T]],
                 out_idx: list[int] | int | None = None,
-                execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc"] = "cython",
+                execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc", "torch"] = "cython",
@@
 @dataclass
 class JITImpl(Generic[_P, _KP, _T]):
@@
-    execution_backend: Literal["dlpack", "ctypes", "cython"]
+    execution_backend: Literal["dlpack", "ctypes", "cython", "nvrtc", "torch"]

Also applies to: 94-104, 160-173


124-157: Prefer explicit keyword args for ThreadPoolExecutor and remove duplicate return.

Small clarity/compatibility tweaks.

-    with concurrent.futures.ThreadPoolExecutor(num_workers, 'tl-par-comp') as executor:
+    with concurrent.futures.ThreadPoolExecutor(
+        max_workers=num_workers, thread_name_prefix='tl-par-comp'
+    ) as executor:
@@
-        return results
-    return results
+        return results

171-182: Cache type and values mismatch (Kernel vs JITKernel).

The cache stores JITKernel, not param.Kernel. Fix the annotation to avoid confusion.

-        self._kernel_cache: dict[tuple, Kernel] = {}
+        self._kernel_cache: dict[tuple, JITKernel[_KP, _T]] = {}
@@
-        return self._kernel_cache[key]
+        return self._kernel_cache[key]

Also applies to: 268-272


339-356: Decorator typing: return JITImpl[_P, _KP, _T] and accept PrimFunc in callable path.

Brings the concrete implementation in line with the overloads.

-    def decorator(func: Callable[_P, _T]) -> JITImpl[_P, _T]:
-        if isinstance(func, PrimFunc):
+    def decorator(func: Callable[_P, _T] | PrimFunc) -> JITImpl[_P, _KP, _T]:
+        if isinstance(func, PrimFunc):
             orig_func = func.orig_func
         else:
             orig_func = func
-        return JITImpl(
+        try:
+            func_source = inspect.getsource(orig_func)
+        except Exception:
+            func_source = ""
+        return JITImpl(
             func,
@@
-            compile_flags=compile_flags,
-            func_source=inspect.getsource(orig_func),
-            signature=inspect.signature(orig_func),
+            compile_flags=compile_flags,
+            func_source=func_source,
+            signature=inspect.signature(orig_func),
         )

79-91: Metal target assertion: mention required backend in error message.

Minor UX: include target.kind.name in the assertion failure for clarity.

-    if is_metal_target(target):
-        assert execution_backend == 'torch', 'Currently metal target only support `tl.jit(execution_backend="torch")`'
+    if is_metal_target(target):
+        assert execution_backend == 'torch', (
+            f'Metal target requires execution_backend="torch", got "{execution_backend}".'
+        )
tilelang/language/__init__.py (1)

12-13: API transition to v2 looks correct.

The shift from .tir to .v2 for prim_func and macro aligns with the v2 frontend architecture. The wildcard import properly exposes PrimFunc, prim_func, macro, and all dtype symbols.

Consider removing the commented line 12 entirely if the migration to v2 is complete, or add a TODO comment if it's intentionally kept for reference during the transition period.

testing/python/language/test_tilelang_language_dtype.py (3)

44-50: Refine exception handling.

The test has two issues:

  1. The bare except Exception (line 49) is too broad and could mask unexpected errors
  2. The exception variables e are unused in both handlers

Consider this refactor to be more specific about expected exceptions:

         try:
             dtype(1.0)
             dtype()
-        except TypeError as e:
-            pass
-        except Exception as e:
+        except TypeError:
+            pass  # TypeError is expected for some dtypes
+        except (ValueError, AttributeError) as e:
             errors.append(name)

This makes it clearer which exceptions are expected vs. problematic, and removes the unused variable warnings.


57-57: Consider using underscore for unused context variables.

The Kernel context returns block indices that aren't used in this test. Consider using _ to make the intent explicit:

-        with T.Kernel(128, 128) as (bx, by):
+        with T.Kernel(128, 128) as _:

This silences the linter warning and clarifies that the bindings are intentionally unused.


110-110: Remove unnecessary f-string prefix.

Line 110 uses an f-string without placeholders:

-        assert f'tl.local_var_init' in s
+        assert 'tl.local_var_init' in s
tilelang/language/v2/dtypes.py (1)

82-102: Consider caching the fallback FFI call lookup.

The __dtype_call__ method reconstructs FFI call names on every invocation (lines 86-98). If dtypes not in _dtype_tvmstr2fficall are called frequently, this parsing overhead could add up.

Consider caching the computed call:

+_dtype_call_cache = {}
+
 def __dtype_call__(self: tvm.DataType, expr=None, is_size_var: bool = False) -> tir.Var:
     if self in _dtype_tvmstr2fficall:
         return _dtype_tvmstr2fficall[self](expr, is_size_var)
+    if self in _dtype_call_cache:
+        return _dtype_call_cache[self](expr, is_size_var)
     # try to construct the ffi call
     if self.startswith('uint'):
         val = 'UInt' + self[4:]
     # ... rest of logic ...
     call = getattr(tb_ffi, val, None)
     if call is None:
         raise TypeError(f'Convert to datatype `{self}` is not supported...')
+    _dtype_call_cache[self] = call
     return call(expr, is_size_var)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 20feef2 and a7e2027.

📒 Files selected for processing (11)
  • testing/python/jit/test_tilelang_jit_parcompile.py (1 hunks)
  • testing/python/language/test_tilelang_language_dtype.py (1 hunks)
  • testing/python/transform/test_tilelang_transform_multi_version_buffer.py (2 hunks)
  • tilelang/__init__.py (1 hunks)
  • tilelang/autotuner/tuner.py (5 hunks)
  • tilelang/jit/__init__.py (3 hunks)
  • tilelang/jit/kernel.py (3 hunks)
  • tilelang/language/__init__.py (1 hunks)
  • tilelang/language/v2/__init__.py (1 hunks)
  • tilelang/language/v2/builder.py (1 hunks)
  • tilelang/language/v2/dtypes.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (6)
tilelang/jit/__init__.py (7)
tilelang/language/v2/builder.py (1)
  • PrimFunc (411-420)
tilelang/jit/adapter/utils.py (1)
  • is_metal_target (63-64)
tilelang/jit/kernel.py (3)
  • JITKernel (24-488)
  • out_idx (456-457)
  • get_kernel_source (388-399)
tilelang/utils/target.py (1)
  • determine_target (62-123)
tilelang/cache/__init__.py (1)
  • cached (15-38)
tilelang/cache/kernel_cache.py (1)
  • cached (113-204)
tilelang/jit/param.py (2)
  • Kernel (20-31)
  • get_kernel_source (27-28)
tilelang/autotuner/tuner.py (2)
tilelang/jit/__init__.py (5)
  • jit (275-276)
  • jit (280-291)
  • jit (294-361)
  • JITImpl (161-271)
  • decorator (339-356)
tilelang/jit/kernel.py (2)
  • JITKernel (24-488)
  • out_idx (456-457)
tilelang/language/v2/__init__.py (1)
tilelang/language/v2/builder.py (5)
  • prim_func (116-120)
  • prim_func (476-503)
  • macro (123-132)
  • macro (448-449)
  • PrimFunc (411-420)
testing/python/jit/test_tilelang_jit_parcompile.py (2)
tilelang/jit/__init__.py (5)
  • jit (275-276)
  • jit (280-291)
  • jit (294-361)
  • par_compile (94-157)
  • par_compile (193-216)
tilelang/utils/tensor.py (1)
  • torch_assert_close (221-313)
tilelang/language/v2/builder.py (5)
tilelang/language/kernel.py (1)
  • KernelLaunchFrame (95-226)
tilelang/language/v2/ast.py (25)
  • BaseBuilder (164-238)
  • eval_op (87-114)
  • mutate (530-537)
  • ctx_if (170-171)
  • ctx_then (173-175)
  • ctx_else (177-179)
  • eval (181-182)
  • ctx_for (184-185)
  • ctx_continue (187-188)
  • ctx_break (190-191)
  • ctx_while (193-195)
  • bind (197-198)
  • get_parent_locals (167-168)
  • unwrap_value (200-201)
  • assign_slice (203-204)
  • aug_assign (206-207)
  • aug_assign_slice (209-210)
  • boolop (212-217)
  • ifexp (219-220)
  • ret (222-223)
  • ctx_with (225-226)
  • assert_expr (228-229)
  • rval (231-232)
  • arg (234-235)
  • override (237-238)
tilelang/language/v2/dtypes.py (1)
  • get_tvm_dtype (125-128)
tilelang/language/ast/ir.py (11)
  • meta_var (1731-1750)
  • If (1096-1112)
  • Then (1115-1123)
  • Else (1126-1134)
  • evaluate (1319-1331)
  • buffer_store (1263-1300)
  • alloc_buffer (441-508)
  • block_attr (430-438)
  • LetStmt (880-908)
  • target (1682-1713)
  • Assert (859-877)
tilelang/language/tir/op.py (1)
  • if_then_else (2907-2937)
testing/python/language/test_tilelang_language_dtype.py (4)
tilelang/language/v2/builder.py (2)
  • prim_func (116-120)
  • prim_func (476-503)
tilelang/language/v2/dtypes.py (20)
  • short (133-133)
  • long (135-135)
  • half (136-136)
  • int8 (139-139)
  • int16 (140-140)
  • int32 (141-141)
  • int64 (142-142)
  • uint8 (163-163)
  • uint16 (164-164)
  • uint32 (165-165)
  • uint64 (166-166)
  • float8_e4m3fn (229-229)
  • float8_e4m3fnuz (236-236)
  • float8_e5m2 (243-243)
  • float8_e5m2fnuz (250-250)
  • float8_e8m0fnu (257-257)
  • float16 (187-187)
  • bfloat16 (285-285)
  • float32 (188-188)
  • float64 (189-189)
tilelang/language/ast/ir.py (1)
  • alloc_buffer (441-508)
tilelang/jit/__init__.py (3)
  • jit (275-276)
  • jit (280-291)
  • jit (294-361)
🪛 Ruff (0.14.1)
tilelang/language/__init__.py

12-12: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


13-13: from .v2 import * used; unable to detect undefined names

(F403)


13-13: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

tilelang/jit/__init__.py

102-102: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


151-151: Do not catch blind exception: Exception

(BLE001)


190-190: Prefer TypeError exception for invalid type

(TRY004)


190-190: Avoid specifying long messages outside the exception class

(TRY003)


195-195: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


205-205: Prefer TypeError exception for invalid type

(TRY004)


205-205: Avoid specifying long messages outside the exception class

(TRY003)

tilelang/autotuner/tuner.py

532-532: Do not catch blind exception: Exception

(BLE001)

tilelang/language/v2/__init__.py

1-1: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)


2-2: from .dtypes import * used; unable to detect undefined names

(F403)

tilelang/language/v2/builder.py

42-42: Avoid specifying long messages outside the exception class

(TRY003)


103-103: PEP 484 prohibits implicit Optional

Convert to Optional[T]

(RUF013)


125-127: Avoid specifying long messages outside the exception class

(TRY003)


199-199: Avoid specifying long messages outside the exception class

(TRY003)


207-209: Avoid specifying long messages outside the exception class

(TRY003)


214-214: Avoid specifying long messages outside the exception class

(TRY003)


217-217: Avoid specifying long messages outside the exception class

(TRY003)


219-219: Unused method argument: cond

(ARG002)


220-220: Avoid specifying long messages outside the exception class

(TRY003)


300-300: Avoid specifying long messages outside the exception class

(TRY003)


318-318: Avoid specifying long messages outside the exception class

(TRY003)


367-370: Avoid specifying long messages outside the exception class

(TRY003)


383-383: Avoid specifying long messages outside the exception class

(TRY003)


388-388: Avoid specifying long messages outside the exception class

(TRY003)


391-391: Unused function argument: builder

(ARG001)


458-458: Avoid specifying long messages outside the exception class

(TRY003)

testing/python/language/test_tilelang_language_dtype.py

47-47: Local variable e is assigned to but never used

Remove assignment to unused variable e

(F841)


49-49: Do not catch blind exception: Exception

(BLE001)


49-49: Local variable e is assigned to but never used

Remove assignment to unused variable e

(F841)


57-57: Unpacked variable bx is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


57-57: Unpacked variable by is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


110-110: f-string without any placeholders

Remove extraneous f prefix

(F541)


115-115: Local variable buf_1 is assigned to but never used

Remove assignment to unused variable buf_1

(F841)


116-116: Local variable buf_2 is assigned to but never used

Remove assignment to unused variable buf_2

(F841)


117-117: Local variable buf_3 is assigned to but never used

Remove assignment to unused variable buf_3

(F841)


118-118: Local variable buf_4 is assigned to but never used

Remove assignment to unused variable buf_4

(F841)


119-119: Local variable buf_5 is assigned to but never used

Remove assignment to unused variable buf_5

(F841)


120-120: Local variable buf_6 is assigned to but never used

Remove assignment to unused variable buf_6

(F841)


121-121: Local variable buf_7 is assigned to but never used

Remove assignment to unused variable buf_7

(F841)


122-122: Local variable buf_8 is assigned to but never used

Remove assignment to unused variable buf_8

(F841)


123-123: Local variable buf_9 is assigned to but never used

Remove assignment to unused variable buf_9

(F841)


124-124: Local variable buf_10 is assigned to but never used

Remove assignment to unused variable buf_10

(F841)


125-125: Local variable buf_11 is assigned to but never used

Remove assignment to unused variable buf_11

(F841)


126-126: Local variable buf_12 is assigned to but never used

Remove assignment to unused variable buf_12

(F841)


127-127: Local variable buf_13 is assigned to but never used

Remove assignment to unused variable buf_13

(F841)


128-128: Local variable buf_14 is assigned to but never used

Remove assignment to unused variable buf_14

(F841)


129-129: Local variable buf_15 is assigned to but never used

Remove assignment to unused variable buf_15

(F841)


130-130: Local variable buf_16 is assigned to but never used

Remove assignment to unused variable buf_16

(F841)


131-131: Local variable buf_17 is assigned to but never used

Remove assignment to unused variable buf_17

(F841)


132-132: Local variable buf_18 is assigned to but never used

Remove assignment to unused variable buf_18

(F841)


133-133: Local variable buf_19 is assigned to but never used

Remove assignment to unused variable buf_19

(F841)


134-134: Local variable buf_20 is assigned to but never used

Remove assignment to unused variable buf_20

(F841)


135-135: Local variable buf_21 is assigned to but never used

Remove assignment to unused variable buf_21

(F841)


136-136: Local variable buf_22 is assigned to but never used

Remove assignment to unused variable buf_22

(F841)


137-137: Local variable buf_23 is assigned to but never used

Remove assignment to unused variable buf_23

(F841)


138-138: Local variable buf_24 is assigned to but never used

Remove assignment to unused variable buf_24

(F841)

tilelang/language/v2/dtypes.py

95-95: Avoid specifying long messages outside the exception class

(TRY003)


101-101: Avoid specifying long messages outside the exception class

(TRY003)


112-112: Avoid specifying long messages outside the exception class

(TRY003)


599-601: Consider [*_all_dtypes, 'AnyDType', 'get_tvm_dtype'] instead of concatenation

Replace with [*_all_dtypes, 'AnyDType', 'get_tvm_dtype']

(RUF005)

🔇 Additional comments (10)
tilelang/__init__.py (1)

7-7: LGTM! Good improvement for cross-environment compatibility.

Using tqdm.auto automatically selects the appropriate progress bar backend based on the execution environment (e.g., notebook-friendly display in Jupyter, standard console output otherwise). This enhances user experience without affecting functionality, as tqdm.write() on line 38 remains fully compatible.

tilelang/autotuner/tuner.py (1)

46-56: Portability note: signal.SIGALRM timeouts won’t work on Windows.

If Windows support matters, consider a watchdog thread or multiprocessing timeout fallback.

Would you like a small compatibility wrapper that picks SIGALRM on POSIX and a thread-based timer on Windows?

testing/python/transform/test_tilelang_transform_multi_version_buffer.py (1)

116-116: LGTM!

The removal of explicit type annotations aligns with the v2 frontend's type inference capabilities. Both test functions consistently rely on inference for the temporary variable, maintaining functional equivalence while simplifying the syntax.

Also applies to: 128-128

tilelang/language/v2/__init__.py (1)

1-2: LGTM!

The public API surface correctly exposes PrimFunc, prim_func, and macro from the builder module, along with all dtype symbols. The wildcard import from dtypes is appropriate for a comprehensive dtype API.

Note: The static analysis warning about wildcard imports (F403) is expected for public API modules and can be safely ignored in this context.

testing/python/language/test_tilelang_language_dtype.py (4)

7-35: LGTM!

This test validates that all dtype annotations are accepted in function signatures. The no-op body is appropriate for this signature validation test.


112-138: LGTM!

This test validates that buffer allocation succeeds with all supported dtypes. The unused variable warnings are expected for this validation test—the test verifies that allocation doesn't raise exceptions, not that the buffers are used.


140-194: LGTM!

The dtype equality mappings between TileLang and Torch are comprehensive and correctly validated.


197-210: LGTM!

Excellent end-to-end test of the v2 frontend's typed variable declaration and assignment semantics. The test correctly validates that reassignment creates new bindings while maintaining immutability.

tilelang/language/v2/dtypes.py (2)

125-128: LGTM!

The get_tvm_dtype helper provides a clean normalization interface and handles already-typed inputs correctly.


131-440: Extensive dtype coverage looks comprehensive.

The TYPE_CHECKING block and runtime definitions provide excellent coverage of scalar and vector dtypes across int, uint, float, and specialized float8/float6/float4 variants. This enables strong type hinting while maintaining runtime flexibility.

Comment on lines +6 to +24
@tilelang.jit(
out_idx=-1, # create the output tensor during runtime
verbose=True,
)
def matmul_kernel_jit(
M,
N,
K,
block_M,
block_N,
block_K,
trans_A=False,
trans_B=True,
in_dtype='float16',
out_dtype='float32',
accum_dtype='float32',
num_stages=2,
threads=128,
):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🛠️ Refactor suggestion | 🟠 Major

Make test portable across CUDA/MPS and select correct backend for metal.

Currently hard-codes CUDA and default backend; metal requires execution_backend="torch". Choose device/backend at import time and skip if none.

+import pytest
 import tilelang.testing
 import tilelang
 import torch
 
-@tilelang.jit(
-    out_idx=-1,  # create the output tensor during runtime
-    verbose=True,
-)
+# Device/backend selection
+USE_CUDA = torch.cuda.is_available()
+USE_MPS = hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
+DEVICE = "cuda" if USE_CUDA else ("mps" if USE_MPS else None)
+EXEC_BACKEND = "torch" if (USE_MPS and not USE_CUDA) else "cython"
+
+@tilelang.jit(
+    out_idx=-1,  # create the output tensor during runtime
+    verbose=True,
+    execution_backend=EXEC_BACKEND,
+)
 def matmul_kernel_jit(
@@
 def test_par_compile():
+    if DEVICE is None:
+        pytest.skip("No CUDA or MPS device available for JIT test.")
     configs = [
         (1024, 1024, 1024, 128, 128, 32),
         (2048, 2048, 2048, 256, 256, 64),
         (4096, 4096, 4096, 64, 64, 128),
     ]
     kernels = matmul_kernel_jit.par_compile(configs)
     for (M, N, K, _, _, _), kernel in zip(configs, kernels):
-        A = torch.randn(M, K, dtype=torch.float16).cuda()
-        B = torch.randn(N, K, dtype=torch.float16).cuda()
+        A = torch.randn(M, K, dtype=torch.float16, device=DEVICE)
+        B = torch.randn(N, K, dtype=torch.float16, device=DEVICE)
         ref = (A @ B.T).float()
         C = kernel(A, B)
         tilelang.testing.torch_assert_close(C, ref, rtol=1e-2, atol=1e-2)

Also applies to: 58-71

Comment on lines +117 to +122
tvm.DataType.__eq__ = __dtype_eq__
tvm.DataType.__req__ = __dtype_eq__
tvm.DataType.__ne__ = __dtype_ne__
tvm.DataType.__rne__ = __dtype_ne__
tvm.DataType.__call__ = __dtype_call__
tvm.DataType.__new__ = __dtype_new__
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Verify monkey-patched methods don't conflict with TVM's DataType.

The code assigns custom implementations to tvm.DataType.__eq__, __ne__, __call__, and __new__. Ensure these don't conflict with TVM's internal expectations or break existing functionality.

Run this script to check if TVM's DataType has existing implementations that might be overridden:


🏁 Script executed:

#!/bin/bash
# Check TVM DataType implementation for potential conflicts
python3 << 'EOF'
import tvm
from tvm import DataType
import inspect

# Check what methods DataType originally has
print("=== Original DataType methods ===")
for name in ['__eq__', '__ne__', '__call__', '__new__']:
    if hasattr(DataType, name):
        method = getattr(DataType, name)
        print(f"{name}: {method}")
        if hasattr(method, '__module__'):
            print(f"  Defined in: {method.__module__}")

# Check if DataType is a subclass of str (as assumed by the code)
print(f"\n=== DataType MRO ===")
print(f"DataType bases: {DataType.__bases__}")
print(f"Is subclass of str: {issubclass(DataType, str)}")
EOF

Length of output: 115


Let me examine the code directly and search for TVM documentation to understand the context better.




🏁 Script executed:

#!/bin/bash
# Examine the dtypes.py file to understand the monkey-patching context
cat tilelang/language/v2/dtypes.py | head -150 | tail -60

Length of output: 2062


🏁 Script executed:

#!/bin/bash
# Search for where these custom dtype methods are defined
rg -n "__dtype_eq__|__dtype_ne__|__dtype_call__|__dtype_new__" tilelang/language/v2/dtypes.py -B5 -A2

Length of output: 1765


🏁 Script executed:

#!/bin/bash
# Check if there are any tests or documentation about DataType usage
fd -e py -e md | xargs rg -l "tvm\.DataType|from tvm import DataType" | head -10

Length of output: 356


🏁 Script executed:

#!/bin/bash
# Look for imports and TVM usage patterns in the file
head -50 tilelang/language/v2/dtypes.py

Length of output: 2260


🌐 Web query:

TVM DataType class methods documentation

💡 Result:

Here are the main TVM C++ runtime DataType class methods (summary). For full signatures and details see the official API docs below.

Constructors

  • DataType(), DataType(DLDataType), DataType(int code, int bits, int lanes). [1][2]

Static constructors

  • Int/UInt(bits, lanes, is_scalable), Float(bits, lanes), BFloat(bits, lanes), Bool(lanes, is_scalable), Handle(bits, lanes), Void(), ShapeIndex(). [1][2]

Queryors / accessors

  • code(), bits(), bytes(), lanes(), vscale_factor(). [1][2]

Type predicates

  • is_scalar(), is_vector(), is_vector_bool(), is_uint(), is_void(), is_scalable_vector(), is_scalable_or_fixed_length_vector(). [1][2]

Transformers / converters

  • element_of(), with_bits(bits), with_lanes(lanes), with_scalable_vscale_factor(vscale_factor). [1][2]

Operators / conversions

  • operator DLDataType(), operator==, operator!=, operator= . [1][2]

Reference documentation

  • DataType class reference (current site): [1].
  • DataType class reference (v0.10/v0.12 archived page): [2].

Sources
[1] TVM DataType class reference (Apache TVM docs). (tvm.apache.org)
[2] TVM DataType class reference (v0.10/v0.12 pages). (tvm.apache.org)

Citations:





🏁 Script executed:

#!/bin/bash
# Search for __req__ and __rne__ in entire codebase
rg -n "__req__|__rne__" --type py

Length of output: 141


🏁 Script executed:

#!/bin/bash
# Search for test files related to dtypes
fd -e py | xargs rg -l "test.*[Dd]ata[Tt]ype|[Dd]ata[Tt]ype.*test" | head -5

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Check if there are any comparison operations used on DataType instances
rg -n "tvm\.DataType.*==" tilelang/ --type py -B2 -A2 | head -40

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Search for reverse comparison operations on DataType (right-hand side comparisons)
rg -n "(str|int|float).*==.*tvm\.DataType|dtype.*!=" tilelang/ --type py -B1 -A1

Length of output: 6062


🌐 Web query:

Python __req__ __rne__ comparison methods

💡 Result:

Short answer: Python does not provide magic methods named req or rne. Equality/inequality use eq and ne; there is no separate “reflected” equality because the rich-comparison protocol already defines how a right-hand operand can get a try (left’s method is tried first; if it returns NotImplemented the interpreter may call the right-hand object’s corresponding method). [1][2][3]

Why:

  • eq / ne are part of the rich-comparison protocol (PEP 207). The C-level rich-compare slot dispatches using op codes (EQ/NE/etc.). [2][3]
  • Reflected arithmetic methods (e.g. radd) exist because arithmetic operators aren’t defined on every object type by default; equality methods are present on object by default, so a dedicated req would almost never be invoked. The interpreter instead relies on returning NotImplemented (or type/subclass rules) to allow the right-hand object's eq to be tried. [1][2]

Minimal example (allow RHS to handle comparison):

  • Left.eq returns NotImplemented to let Right.eq be called.

    class Left:
    def eq(self, other):
    return NotImplemented

    class Right:
    def eq(self, other):
    return True

    Left() == Right() # -> True, because Left.eq returned NotImplemented so Right.eq was tried

References:

  • StackOverflow discussion on why req doesn’t exist. [1]
  • PEP 207 (rich comparisons). [2]
  • Python “what’s new” / docs describing eq/ne and rich comparison dispatch. [3]

Remove invalid monkey-patched method assignments on lines 118 and 120.

Python does not provide magic methods named __req__ or __rne__. Lines 118 and 120 attempt to assign these non-existent methods to tvm.DataType. They will be silently added as object attributes but never invoked by Python's comparison machinery. The rich-comparison protocol handles reflected operations by allowing the right-hand object's __eq__ and __ne__ to be tried if the left returns NotImplemented—which your __dtype_eq__ and __dtype_ne__ implementations already do correctly.

Remove lines 118 and 120 entirely:

tvm.DataType.__eq__ = __dtype_eq__
# tvm.DataType.__req__ = __dtype_eq__  # DELETE
tvm.DataType.__ne__ = __dtype_ne__
# tvm.DataType.__rne__ = __dtype_ne__  # DELETE
tvm.DataType.__call__ = __dtype_call__
tvm.DataType.__new__ = __dtype_new__

The remaining patches (__eq__, __ne__, __call__, __new__) appear functional and align with the codebase's usage patterns.

🤖 Prompt for AI Agents
In tilelang/language/v2/dtypes.py around lines 117 to 122, remove the two
invalid monkey-patched assignments that add non-existent magic methods (__req__
and __rne__) to tvm.DataType; delete the lines assigning tvm.DataType.__req__ =
__dtype_eq__ and tvm.DataType.__rne__ = __dtype_ne__, leaving the legitimate
assignments for __eq__, __ne__, __call__, and __new__ intact so Python's normal
rich-comparison fallback (using NotImplemented) handles reflected comparisons.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

♻️ Duplicate comments (1)
tilelang/language/v2/dtypes.py (1)

118-123: Remove invalid monkey-patched method assignments.

Lines 119 and 121 assign non-existent magic methods __req__ and __rne__ to tvm.DataType. Python's comparison protocol doesn't support these methods—reflected operations are handled by returning NotImplemented from __eq__ and __ne__, which your implementations already do correctly.

Apply this diff:

 tvm.DataType.__eq__ = __dtype_eq__
-tvm.DataType.__req__ = __dtype_eq__
 tvm.DataType.__ne__ = __dtype_ne__
-tvm.DataType.__rne__ = __dtype_ne__
 tvm.DataType.__call__ = __dtype_call__
 tvm.DataType.__new__ = __dtype_new__
🧹 Nitpick comments (4)
testing/python/language/test_tilelang_language_dtype.py (1)

60-60: Consider using dummy variables for unused kernel context.

The kernel context unpacks bx and by but never uses them. Consider using underscores to signal they're intentionally unused: with T.Kernel(128, 128) as (_, _):.

Apply this diff:

-        with T.Kernel(128, 128) as (bx, by):
+        with T.Kernel(128, 128) as (_, _):
tilelang/language/v2/dtypes.py (3)

82-103: Consider extracting exception messages to module-level constants (optional).

The function correctly implements dtype callable behavior with helpful error messages. The static analysis flags long exception strings, but they provide valuable debugging context for dtype construction failures.

If you prefer to follow the style guide strictly, extract the messages:

_INVALID_DTYPE_MSG = 'Invalid type {dtype}'
_UNSUPPORTED_CONVERSION_MSG = (
    "Convert to datatype `{dtype}` is not supported by tvm\n"
    "calling failed on `tvm.script.ir_builder.tir._ffi_api.{call}`"
)

# Then use in the function:
raise TypeError(_INVALID_DTYPE_MSG.format(dtype=self))
raise TypeError(_UNSUPPORTED_CONVERSION_MSG.format(dtype=self, call=val))

106-115: LGTM!

The custom __new__ implementation correctly handles dtype construction from strings and Python/Torch types. The comprehensive error message (line 113) aids debugging, though it triggers a style warning (TRY003).


908-911: Consider using spread operator for concatenation (optional).

The current concatenation works correctly but could use the more modern spread syntax suggested by Ruff.

Apply this diff:

-__all__ = _all_dtypes + [
+__all__ = [
+    *_all_dtypes,
     'AnyDType',
     'get_tvm_dtype',
 ]
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between a7e2027 and 4d0bc85.

📒 Files selected for processing (2)
  • testing/python/language/test_tilelang_language_dtype.py (1 hunks)
  • tilelang/language/v2/dtypes.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (1)
testing/python/language/test_tilelang_language_dtype.py (4)
tilelang/language/v2/builder.py (2)
  • prim_func (116-120)
  • prim_func (476-503)
tilelang/language/v2/dtypes.py (19)
  • short (137-138)
  • long (143-144)
  • half (146-147)
  • int8 (155-156)
  • int16 (158-159)
  • int32 (161-162)
  • int64 (164-165)
  • uint8 (227-228)
  • uint16 (230-231)
  • uint32 (233-234)
  • uint64 (236-237)
  • float8_e4m3fn (425-426)
  • float8_e4m3fnuz (446-447)
  • float8_e5m2 (467-468)
  • float8_e5m2fnuz (488-489)
  • float8_e8m0fnu (509-510)
  • float16 (299-300)
  • float32 (302-303)
  • float64 (305-306)
tilelang/language/ast/ir.py (1)
  • alloc_buffer (441-508)
tilelang/jit/__init__.py (3)
  • jit (275-276)
  • jit (280-291)
  • jit (294-361)
🪛 Ruff (0.14.1)
testing/python/language/test_tilelang_language_dtype.py

51-51: Do not catch blind exception: Exception

(BLE001)


60-60: Unpacked variable bx is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)


60-60: Unpacked variable by is never used

Prefix it with an underscore or any other dummy variable pattern

(RUF059)

tilelang/language/v2/dtypes.py

95-95: Avoid specifying long messages outside the exception class

(TRY003)


101-102: Avoid specifying long messages outside the exception class

(TRY003)


113-113: Avoid specifying long messages outside the exception class

(TRY003)


908-911: Consider [*_all_dtypes, 'AnyDType', 'get_tvm_dtype'] instead of concatenation

Replace with [*_all_dtypes, 'AnyDType', 'get_tvm_dtype']

(RUF005)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Test for Python 3.12 with Metal (on macos-latest)
🔇 Additional comments (14)
testing/python/language/test_tilelang_language_dtype.py (7)

1-6: LGTM!

The imports are appropriate for testing dtype functionality and integration with PyTorch and TVM.


8-37: LGTM!

This test validates that all 24 dtype aliases can be used as type annotations in prim_func signatures, which is a good smoke test for the new dtype system.


40-53: LGTM!

The test correctly validates dtype construction and callable behavior. The broad exception catch on line 51 is intentional—it collects unexpected errors while allowing expected TypeError to pass silently.


61-113: LGTM!

The test thoroughly validates the variable declaration sugar syntax for all 24 dtypes, including both initialization and reassignment patterns. The assertions correctly verify the compiled output contains the expected variable declarations.


116-143: LGTM!

This smoke test validates that all 24 dtypes work correctly in buffer allocation contexts and have proper string representations for TVM's buffer API.


203-218: LGTM!

This test correctly validates the variable assignment semantics: b captures the initial value of a (1), and d captures the reassigned value (2). The assertions confirm correct behavior.


221-222: LGTM!

Standard test runner invocation.

tilelang/language/v2/dtypes.py (7)

1-11: LGTM!

Imports are appropriate. VoidPtr serves as a marker type for void pointer handling in the dtype system.


14-48: LGTM!

The AnyDType union and _dtype_cvt mapping table provide comprehensive cross-representation dtype support between Python, PyTorch, TVM, ctypes, and CFFI. The use of None for unsupported mappings (e.g., float16 in ctypes) is appropriate.


51-63: LGTM!

The _create_type_mapper factory cleanly generates bidirectional dtype mappings from the central _dtype_cvt table, filtering out unsupported conversions appropriately.


66-79: LGTM!

The custom equality/inequality implementations correctly handle string and Python/Torch dtype comparisons. Returning NotImplemented for unsupported types allows Python's comparison protocol to try the reverse operation.


126-129: LGTM!

The get_tvm_dtype helper provides a clean API for dtype resolution, handling both TVM dtypes and IR types directly.


132-749: LGTM!

The extensive dtype definitions use the TYPE_CHECKING guard appropriately to provide type hints without runtime overhead, while the runtime instantiation creates the actual tvm.DataType instances. This comprehensive set covers scalar and vector dtypes across int, uint, float, and exotic float8 variants.


751-906: LGTM!

The _all_dtypes list provides a complete registry of all supported dtypes, enabling runtime iteration and validation as seen in the tests.

Comment on lines +147 to +200
dtypes = [
T.bool,
T.short,
T.int,
T.long,
T.half,
T.float,
T.long,
T.int8,
T.int16,
T.int32,
T.int64,
T.uint8,
T.uint16,
T.uint32,
T.uint64,
T.float8_e4m3fn,
T.float8_e4m3fnuz,
T.float8_e5m2,
T.float8_e5m2fnuz,
T.float8_e8m0fnu,
T.float16,
T.bfloat16,
T.float32,
T.float64,
]
torch_dtypes = [
torch.bool,
torch.short,
torch.int,
torch.long,
torch.half,
torch.float,
torch.long,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
torch.uint16,
torch.uint32,
torch.uint64,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e5m2,
torch.float8_e5m2fnuz,
torch.float8_e8m0fnu,
torch.float16,
torch.bfloat16,
torch.float32,
torch.float64,
]
for a, b in zip(dtypes, torch_dtypes):
assert a == b, f"{a} and {b} are not equal"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Duplicate dtype comparison at indices 3 and 6.

Both T.long and torch.long appear twice in the test lists (lines 151/154 and 177/180). This appears to be unintentional duplication.

Apply this diff to test a different dtype at position 6:

         T.half,
         T.float,
-        T.long,
+        T.double,
         T.int8,
         torch.half,
         torch.float,
-        torch.long,
+        torch.double,
         torch.int8,
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
dtypes = [
T.bool,
T.short,
T.int,
T.long,
T.half,
T.float,
T.long,
T.int8,
T.int16,
T.int32,
T.int64,
T.uint8,
T.uint16,
T.uint32,
T.uint64,
T.float8_e4m3fn,
T.float8_e4m3fnuz,
T.float8_e5m2,
T.float8_e5m2fnuz,
T.float8_e8m0fnu,
T.float16,
T.bfloat16,
T.float32,
T.float64,
]
torch_dtypes = [
torch.bool,
torch.short,
torch.int,
torch.long,
torch.half,
torch.float,
torch.long,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
torch.uint16,
torch.uint32,
torch.uint64,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e5m2,
torch.float8_e5m2fnuz,
torch.float8_e8m0fnu,
torch.float16,
torch.bfloat16,
torch.float32,
torch.float64,
]
for a, b in zip(dtypes, torch_dtypes):
assert a == b, f"{a} and {b} are not equal"
dtypes = [
T.bool,
T.short,
T.int,
T.long,
T.half,
T.float,
T.double,
T.int8,
T.int16,
T.int32,
T.int64,
T.uint8,
T.uint16,
T.uint32,
T.uint64,
T.float8_e4m3fn,
T.float8_e4m3fnuz,
T.float8_e5m2,
T.float8_e5m2fnuz,
T.float8_e8m0fnu,
T.float16,
T.bfloat16,
T.float32,
T.float64,
]
torch_dtypes = [
torch.bool,
torch.short,
torch.int,
torch.long,
torch.half,
torch.float,
torch.double,
torch.int8,
torch.int16,
torch.int32,
torch.int64,
torch.uint8,
torch.uint16,
torch.uint32,
torch.uint64,
torch.float8_e4m3fn,
torch.float8_e4m3fnuz,
torch.float8_e5m2,
torch.float8_e5m2fnuz,
torch.float8_e8m0fnu,
torch.float16,
torch.bfloat16,
torch.float32,
torch.float64,
]
for a, b in zip(dtypes, torch_dtypes):
assert a == b, f"{a} and {b} are not equal"
🤖 Prompt for AI Agents
In testing/python/language/test_tilelang_language_dtype.py around lines 147 to
200, the dtype pair at the 7th position is duplicated (T.long / torch.long
appear twice); replace the second occurrence (the entries at lines ~154 and
~180) with the intended dtype so both lists remain identical (update both dtypes
arrays consistently to the correct dtype and re-run tests to confirm).

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant