Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 106 additions & 0 deletions tilelang/language/tir/ir.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
from typing import TypeVar, Literal
from tvm.tir.expr import Span, PrimExpr, BufferLoad, Var, IntImm

_T = TypeVar('_T')

def abs(x: _T, span: Span | None=None) -> _T: ...
def acos(x: _T) -> _T: ...
def acosh(x: _T) -> _T: ...
def address_of(buffer_load: BufferLoad, span: Span | None=None) -> PrimExpr: ...
def asin(x: _T) -> _T: ...
def asinh(x: _T) -> _T: ...
def atan(x: _T) -> _T: ...
def atan2(x1: _T, x2: _T) -> _T: ...
def atanh(x: _T) -> _T: ...
def bitwise_and(x: _T, y: _T, span: Span | None=None) -> _T: ...
def bitwise_not(x: _T, span: Span | None=None) -> _T: ...
def bitwise_or(x: _T, y: _T, span: Span | None=None) -> _T: ...
def bitwise_xor(x: _T, y: _T, span: Span | None=None) -> _T: ...
def ceil(x: _T, span: Span | None=None) -> _T: ...
def clz(x: _T) -> _T: ...
def copysign(x1: _T, x2: _T) -> _T: ...
def cos(x: _T) -> _T: ...
def cosh(x: _T) -> _T: ...
def erf(x: _T) -> _T: ...
def exp(x: _T) -> _T: ...
def exp2(x: _T) -> _T: ...
def exp10(x: _T) -> _T: ...
def floor(x: _T, span: Span | None=None) -> _T: ...
def ceildiv(lhs: _T, rhs: _T, span: Span | None=None) -> _T: ...
def floordiv(a: _T, b: _T, span: Span | None=None) -> _T: ...
def floormod(a: _T, b: _T, span: Span | None=None) -> _T: ...
def fmod(x: _T, y: _T) -> _T: ...
def hypot(x1: _T, x2: _T) -> _T: ...
def if_then_else(cond: PrimExpr, t: _T, f: _T, span: Span | None=None) -> _T: ...
def infinity(dtype: _T, span: Span | None=None) -> _T: ...
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Incorrect type annotation for dtype parameter.

The dtype parameter is typed as _T, but based on the parameter name and typical usage, it should be str.

Apply this diff:

-def infinity(dtype: _T, span: Span | None=None) -> _T: ...
+def infinity(dtype: str, span: Span | None=None) -> PrimExpr: ...
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def infinity(dtype: _T, span: Span | None=None) -> _T: ...
def infinity(dtype: str, span: Span | None=None) -> PrimExpr: ...
🤖 Prompt for AI Agents
In tilelang/language/tir/ir.pyi around line 35, the dtype parameter is
incorrectly annotated as _T; change the dtype parameter's type annotation from
_T to str (i.e., make the signature use dtype: str) while keeping the rest of
the signature intact so callers see the correct parameter type.

def isfinite(x: _T, span: Span | None=None) -> _T: ...
def isinf(x: _T, span: Span | None=None) -> _T: ...
def isnan(x: _T, span: Span | None=None) -> _T: ...
def isnullptr(x: _T, span: Span | None=None) -> _T: ...
def ldexp(x1: _T, x2: _T) -> _T: ...
def likely(cond: _T, span: Span | None=None) -> _T: ...
def log(x: _T) -> _T: ...
def log1p(x: _T) -> _T: ...
def log2(x: _T) -> _T: ...
def log10(x: _T) -> _T: ...
def lookup_param(param_name: str, span: Span | None=None) -> PrimExpr: ...
def max_value(dtype: str, span: Span | None=None) -> PrimExpr: ...
def min_value(dtype: str, span: Span | None=None) -> PrimExpr: ...
def nearbyint(x: _T, span: Span | None=None) -> _T: ...
def nextafter(x1: _T, x2: _T) -> _T: ...
def popcount(x: _T) -> _T: ...
def pow(x: _T, y: _T, span: Span | None=None) -> _T: ...
def q_multiply_shift(x: _T, y: _T, q: _T, s: _T) -> _T: ...
def q_multiply_shift_per_axis(x: _T, y: _T, ls: _T, rs: _T, q: IntImm, is_lshift_required: IntImm, is_rshift_required: IntImm) -> PrimExpr: ...
def ret(val: _T) -> _T: ...
def round(x: _T, span: Span | None=None) -> _T: ...
def rsqrt(x: _T) -> _T: ...
def shift_left(x: _T, y: _T, span=None) -> _T: ...
def shift_right(x: _T, y: _T, span=None) -> _T: ...
Comment on lines +58 to +59
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟡 Minor

Add type annotation for span parameter.

The span parameter lacks a type annotation, unlike all other functions with the same parameter. This breaks consistency and reduces type safety.

Apply this diff:

-def shift_left(x: _T, y: _T, span=None) -> _T: ...
-def shift_right(x: _T, y: _T, span=None) -> _T: ...
+def shift_left(x: _T, y: _T, span: Span | None=None) -> _T: ...
+def shift_right(x: _T, y: _T, span: Span | None=None) -> _T: ...
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def shift_left(x: _T, y: _T, span=None) -> _T: ...
def shift_right(x: _T, y: _T, span=None) -> _T: ...
def shift_left(x: _T, y: _T, span: Span | None=None) -> _T: ...
def shift_right(x: _T, y: _T, span: Span | None=None) -> _T: ...
🤖 Prompt for AI Agents
In tilelang/language/tir/ir.pyi around lines 58-59, the span parameter is
missing a type annotation; update both function signatures to add an explicit
optional Span type (e.g. span: Span | None = None or span: Optional[Span] = None
to match project typing style) so they match other functions with span and
restore type consistency.

def sigmoid(x: _T) -> _T: ...
def sin(x: _T) -> _T: ...
def sinh(x: _T) -> _T: ...
def sqrt(x: _T) -> _T: ...
def tan(x: _T) -> _T: ...
def tanh(x: _T) -> _T: ...
def trunc(x: _T, span: Span | None=None) -> _T: ...
def truncdiv(a: _T, b: _T, span: Span | None=None) -> _T: ...
def truncmod(a: _T, b: _T, span: Span | None=None) -> _T: ...
def tvm_access_ptr(ptype: PrimExpr, data, offset: int, extent: int, rw_mask: int) -> PrimExpr: ...
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Add type annotation for the data parameter.

The data parameter lacks a type annotation, reducing type safety and consistency with other parameters in the file.

Verify the expected type from the runtime implementation to add the correct annotation (likely PrimExpr or Var).


🏁 Script executed:

#!/bin/bash
# Find the runtime implementation to determine the correct type for the data parameter
ast-grep --pattern 'def tvm_access_ptr($$$)'

Length of output: 1513



---

🏁 Script executed:

```shell
#!/bin/bash
# Search for calls to tvm_access_ptr to see what types are passed for the data parameter
rg -n 'tvm_access_ptr\s*\(' -A 2 -B 2 --type py

Length of output: 15734


🏁 Script executed:

#!/bin/bash
# Check for type imports and similar buffer/pointer-like types in the file
rg -n 'Buffer|Var|Handle' tilelang/language/tir/ir.pyi | head -20

Length of output: 1180


🏁 Script executed:

#!/bin/bash
# Look at other function signatures in the stub file for patterns with similar parameters
rg -n 'def.*data.*:' tilelang/language/tir/ir.pyi -A 1 | head -30

Length of output: 395


Add type annotation for the data parameter: use PrimExpr.

The data parameter lacks a type annotation. Based on the docstring in the implementation and usage patterns across test files where buffer .data attributes are passed (e.g., A_shared.data), the type should be PrimExpr, consistent with similar functions in the file like tvm_stack_make_array.

🤖 Prompt for AI Agents
In tilelang/language/tir/ir.pyi around line 69, the tvm_access_ptr declaration
is missing a type annotation for the data parameter; update the function
signature so the data parameter is annotated as PrimExpr (matching usage of
buffer .data and other functions like tvm_stack_make_array) so the stub
accurately reflects expected types.

def tvm_throw_last_error() -> _T: ...
def tvm_stack_alloca(dtype_str: str, num: int) -> PrimExpr: ...
def tvm_stack_make_shape(*args) -> _T: ...
def tvm_stack_make_array(data: PrimExpr, shape: PrimExpr, strides: PrimExpr, ndim: PrimExpr, arr_dtype: PrimExpr, elem_offset) -> PrimExpr: ...
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Add type annotation for the elem_offset parameter.

The elem_offset parameter lacks a type annotation.

Verify the expected type from the runtime implementation.


🏁 Script executed:

#!/bin/bash
# Find the runtime implementation
ast-grep --pattern 'def tvm_stack_make_array($$$)'

Length of output: 1753



---

🏁 Script executed:

```shell
#!/bin/bash
# Check imports and type definitions in ir.pyi
head -100 tilelang/language/tir/ir.pyi | grep -E "^(import|from|class PrimExpr|class Expr)"

Length of output: 100


🏁 Script executed:

#!/bin/bash
# Search for PrimExpr and Expr type definitions/imports
rg -n "(class PrimExpr|class Expr|from.*import.*PrimExpr|from.*import.*Expr)" tilelang/language/tir/ir.pyi -A2 -B2

Length of output: 128


🏁 Script executed:

#!/bin/bash
# Check what type annotations are used in similar functions
rg -n "def.*Expr.*Expr.*->" tilelang/language/tir/ir.pyi | head -20

Length of output: 1026


Add type annotation PrimExpr for the elem_offset parameter.

The elem_offset parameter lacks a type annotation. Based on the stub file's import scope (only PrimExpr is imported from tvm.tir.expr) and the consistent pattern used across all other parameters in this function and similar functions in the file, the type should be PrimExpr.

def tvm_stack_make_array(data: PrimExpr, shape: PrimExpr, strides: PrimExpr, ndim: PrimExpr, arr_dtype: PrimExpr, elem_offset: PrimExpr) -> PrimExpr: ...
🤖 Prompt for AI Agents
In tilelang/language/tir/ir.pyi around line 73, the function
tvm_stack_make_array is missing a type annotation for the elem_offset parameter;
add the annotation elem_offset: PrimExpr to match the other parameters and the
file's import usage so the signature becomes def tvm_stack_make_array(data:
PrimExpr, shape: PrimExpr, strides: PrimExpr, ndim: PrimExpr, arr_dtype:
PrimExpr, elem_offset: PrimExpr) -> PrimExpr: ...

def tvm_check_return(expected: int, return_unexpected: int, nested_call: PrimExpr) -> PrimExpr: ...
def call_packed(*args, span=None) -> _T: ...
def call_cpacked(*args, span=None) -> _T: ...
def call_packed_lowered(*args, span=None) -> _T: ...
def call_cpacked_lowered(*args, span=None) -> _T: ...
def tvm_tuple(*value) -> _T: ...
def tvm_struct_set(arr, index: int, field: int, value: PrimExpr) -> PrimExpr: ...
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Add type annotation for the arr parameter.

The arr parameter lacks a type annotation.

Verify the expected type from the runtime implementation.


🏁 Script executed:

#!/bin/bash
# Find the runtime implementation
ast-grep --pattern 'def tvm_struct_set($$$)'

Length of output: 1309



Wait, that won't work without cloning. Let me check properly:

---

🏁 Script executed:

```shell
#!/bin/bash
# Inspect imports and related type definitions in the stub file
head -100 tilelang/language/tir/ir.pyi | cat

Length of output: 5583


🏁 Script executed:

#!/bin/bash
# Inspect the top of the stub file for imports and type definitions
head -100 tilelang/language/tir/ir.pyi

Length of output: 5583


🏁 Script executed:

#!/bin/bash
# Search for StructType or similar types in ir.pyi
rg -n "StructType|class.*Type" tilelang/language/tir/ir.pyi | head -20

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Check imports at the top of ir.pyi
rg -n "^from|^import" tilelang/language/tir/ir.pyi | head -20

Length of output: 104


🏁 Script executed:

#!/bin/bash
# Search for StructType definition or usage in codebase
rg -n "StructType" tilelang/language/tir/ --type py --type pyi

Length of output: 27


🏁 Script executed:

#!/bin/bash
# Examine the implementation context of tvm_struct_set in op.py
sed -n '439,465p' tilelang/language/tir/op.py

Length of output: 556


🏁 Script executed:

#!/bin/bash
# Search for struct-related functions in op.py to find usage patterns
rg -n "struct" tilelang/language/tir/op.py -i | head -20

Length of output: 1053


🏁 Script executed:

#!/bin/bash
# Look for tvm_struct_get (companion function) to see if it has type info
rg -n "tvm_struct_get" tilelang/language/tir/ -A 10

Length of output: 2115


Add type annotation for the arr parameter.

The arr parameter lacks a type annotation. Based on the runtime implementation and function signature, it should be typed as PrimExpr:

def tvm_struct_set(arr: PrimExpr, index: int, field: int, value: PrimExpr) -> PrimExpr: ...
🤖 Prompt for AI Agents
In tilelang/language/tir/ir.pyi around line 80, the tvm_struct_set function
signature is missing a type annotation for the arr parameter; update the
declaration to annotate arr as PrimExpr so the signature reads with arr:
PrimExpr, e.g. change the parameter list to include arr: PrimExpr and keep the
rest unchanged.

def tvm_thread_invariant(cond: _T) -> _T: ...
def tvm_thread_allreduce(*freduce_args) -> _T: ...
def tvm_load_matrix_sync(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, buffer_ptr: PrimExpr, stride: PrimExpr, layout: Literal['row_major', 'column_major']) -> PrimExpr: ...
def tvm_mma_sync(fragment_d: Var, index_d: PrimExpr, fragment_a: Var, index_a: PrimExpr, fragment_b: Var, index_b: PrimExpr, fragment_c: Var, index_c: PrimExpr) -> PrimExpr: ...
def tvm_bmma_sync(fragment_d: Var, index_d: PrimExpr, fragment_a: Var, index_a: PrimExpr, fragment_b: Var, index_b: PrimExpr, fragment_c: Var, index_c: PrimExpr) -> PrimExpr: ...
def tvm_fill_fragment(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, value: PrimExpr) -> PrimExpr: ...
def tvm_store_matrix_sync(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, buffer_ptr: PrimExpr, stride: PrimExpr, layout: Literal['row_major', 'column_major']) -> PrimExpr: ...
def ptx_wait_group(num: int) -> PrimExpr: ...
def ptx_commit_group() -> _T: ...
def ptx_cp_async_barrier(barrier_id: int) -> PrimExpr: ...
def ptx_init_barrier_thread_count(barrier_id: int, thread_count: int) -> PrimExpr: ...
def ptx_arrive_barrier(barrier_id: int) -> PrimExpr: ...
def ptx_arrive_barrier_expect_tx(barrier_id: int, byte_count: int) -> PrimExpr: ...
def ptx_wait_barrier(barrier_id: int) -> PrimExpr: ...
def create_barriers(barrier_count: int) -> PrimExpr: ...
def assume(cond: _T=None) -> _T: ...
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major

Problematic default value with generic TypeVar.

Using None as a default value with an unbounded _T is problematic because None may not be a valid instance of _T.

Consider one of these fixes:

-def assume(cond: _T=None) -> _T: ...
+def assume(cond: _T | None=None) -> _T: ...

Or use Optional:

+from typing import Optional
-def assume(cond: _T=None) -> _T: ...
+def assume(cond: Optional[_T]=None) -> _T: ...
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
def assume(cond: _T=None) -> _T: ...
def assume(cond: _T | None=None) -> _T: ...
🤖 Prompt for AI Agents
In tilelang/language/tir/ir.pyi around line 96, the signature def assume(cond:
_T=None) -> _T: uses None as a default for an unconstrained TypeVar which is
invalid; change the signature so the types match — either remove the default and
keep it typed as def assume(cond: _T) -> _T:, or make the parameter and return
optional by importing typing.Optional and using def assume(cond: Optional[_T] =
None) -> Optional[_T]: depending on whether None is a valid value for callers.

def undef() -> _T: ...
def TVMBackendAllocWorkspace(device_type: int, device_id: int, nbytes: int, dtype_code_hint: int, dtype_bits_hint: int) -> PrimExpr: ...
def TVMBackendFreeWorkspace(device_type: int, device_id: int, ptr: Var) -> PrimExpr: ...
def start_profile_intrinsic(id: int) -> PrimExpr: ...
def end_profile_intrinsic(id: int) -> PrimExpr: ...
def anylist_getitem(list_handle, index) -> PrimExpr: ...
def anylist_resetitem(list_handle, index) -> PrimExpr: ...
def anylist_setitem_call_packed(list_handle, index, func_name, *args) -> PrimExpr: ...
def anylist_setitem_call_cpacked(list_handle, index, func_name, *args) -> PrimExpr: ...
Comment on lines +102 to +105
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

Add type annotations for parameters in anylist functions.

Multiple parameters (list_handle, index, func_name) lack type annotations, reducing type safety.

Verify the expected types from the runtime implementation.


🏁 Script executed:

#!/bin/bash
# Find the runtime implementations
ast-grep --pattern 'def anylist_$_($$$ )'

Length of output: 0


🏁 Script executed:

#!/bin/bash
# Search for anylist function definitions in Python files
rg -n 'def anylist_' --type=py

Length of output: 358


🏁 Script executed:

#!/bin/bash
# Search for anylist references to understand usage patterns
rg -n 'anylist_' --type=py -C 2 | head -50

Length of output: 2975


🏁 Script executed:

#!/bin/bash
# Check the exact structure of the .pyi file around the functions
fd -e pyi -name 'ir.pyi' -exec cat {} \;

Length of output: 230


🏁 Script executed:

#!/bin/bash
# Get full docstrings from op.py implementations
sed -n '3331,3343p' tilelang/language/tir/op.py
sed -n '3345,3357p' tilelang/language/tir/op.py
sed -n '3359,3375p' tilelang/language/tir/op.py
sed -n '3377,3393p' tilelang/language/tir/op.py

Length of output: 1551


🏁 Script executed:

#!/bin/bash
# View the .pyi stub file
cat tilelang/language/tir/ir.pyi

Length of output: 5950


Add type annotations for parameters in anylist functions.

The runtime implementations in tilelang/language/tir/op.py document the expected parameter types via docstrings: list_handle: Var, index: int, and func_name: str. These should be reflected in the stub file. The Var type is already imported. Add the following annotations:

def anylist_getitem(list_handle: Var, index: int) -> PrimExpr: ...
def anylist_resetitem(list_handle: Var, index: int) -> PrimExpr: ...
def anylist_setitem_call_packed(list_handle: Var, index: int, func_name: str, *args) -> PrimExpr: ...
def anylist_setitem_call_cpacked(list_handle: Var, index: int, func_name: str, *args) -> PrimExpr: ...
🤖 Prompt for AI Agents
In tilelang/language/tir/ir.pyi around lines 102 to 105, the anylist_* function
stubs lack parameter type annotations; update each signature to annotate
list_handle as Var, index as int, and func_name as str where applicable (keep
*args untyped) so they match the runtime docstrings and existing imports; ensure
the return type remains PrimExpr.

def vscale() -> _T: ...
Loading