-
Notifications
You must be signed in to change notification settings - Fork 332
[Language] Add type stubs for tir op #1239
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||||||
|---|---|---|---|---|---|---|---|---|---|---|
| @@ -0,0 +1,106 @@ | ||||||||||
| from typing import TypeVar, Literal | ||||||||||
| from tvm.tir.expr import Span, PrimExpr, BufferLoad, Var, IntImm | ||||||||||
|
|
||||||||||
| _T = TypeVar('_T') | ||||||||||
|
|
||||||||||
| def abs(x: _T, span: Span | None=None) -> _T: ... | ||||||||||
| def acos(x: _T) -> _T: ... | ||||||||||
| def acosh(x: _T) -> _T: ... | ||||||||||
| def address_of(buffer_load: BufferLoad, span: Span | None=None) -> PrimExpr: ... | ||||||||||
| def asin(x: _T) -> _T: ... | ||||||||||
| def asinh(x: _T) -> _T: ... | ||||||||||
| def atan(x: _T) -> _T: ... | ||||||||||
| def atan2(x1: _T, x2: _T) -> _T: ... | ||||||||||
| def atanh(x: _T) -> _T: ... | ||||||||||
| def bitwise_and(x: _T, y: _T, span: Span | None=None) -> _T: ... | ||||||||||
| def bitwise_not(x: _T, span: Span | None=None) -> _T: ... | ||||||||||
| def bitwise_or(x: _T, y: _T, span: Span | None=None) -> _T: ... | ||||||||||
| def bitwise_xor(x: _T, y: _T, span: Span | None=None) -> _T: ... | ||||||||||
| def ceil(x: _T, span: Span | None=None) -> _T: ... | ||||||||||
| def clz(x: _T) -> _T: ... | ||||||||||
| def copysign(x1: _T, x2: _T) -> _T: ... | ||||||||||
| def cos(x: _T) -> _T: ... | ||||||||||
| def cosh(x: _T) -> _T: ... | ||||||||||
| def erf(x: _T) -> _T: ... | ||||||||||
| def exp(x: _T) -> _T: ... | ||||||||||
| def exp2(x: _T) -> _T: ... | ||||||||||
| def exp10(x: _T) -> _T: ... | ||||||||||
| def floor(x: _T, span: Span | None=None) -> _T: ... | ||||||||||
| def ceildiv(lhs: _T, rhs: _T, span: Span | None=None) -> _T: ... | ||||||||||
| def floordiv(a: _T, b: _T, span: Span | None=None) -> _T: ... | ||||||||||
| def floormod(a: _T, b: _T, span: Span | None=None) -> _T: ... | ||||||||||
| def fmod(x: _T, y: _T) -> _T: ... | ||||||||||
| def hypot(x1: _T, x2: _T) -> _T: ... | ||||||||||
| def if_then_else(cond: PrimExpr, t: _T, f: _T, span: Span | None=None) -> _T: ... | ||||||||||
| def infinity(dtype: _T, span: Span | None=None) -> _T: ... | ||||||||||
| def isfinite(x: _T, span: Span | None=None) -> _T: ... | ||||||||||
| def isinf(x: _T, span: Span | None=None) -> _T: ... | ||||||||||
| def isnan(x: _T, span: Span | None=None) -> _T: ... | ||||||||||
| def isnullptr(x: _T, span: Span | None=None) -> _T: ... | ||||||||||
| def ldexp(x1: _T, x2: _T) -> _T: ... | ||||||||||
| def likely(cond: _T, span: Span | None=None) -> _T: ... | ||||||||||
| def log(x: _T) -> _T: ... | ||||||||||
| def log1p(x: _T) -> _T: ... | ||||||||||
| def log2(x: _T) -> _T: ... | ||||||||||
| def log10(x: _T) -> _T: ... | ||||||||||
| def lookup_param(param_name: str, span: Span | None=None) -> PrimExpr: ... | ||||||||||
| def max_value(dtype: str, span: Span | None=None) -> PrimExpr: ... | ||||||||||
| def min_value(dtype: str, span: Span | None=None) -> PrimExpr: ... | ||||||||||
| def nearbyint(x: _T, span: Span | None=None) -> _T: ... | ||||||||||
| def nextafter(x1: _T, x2: _T) -> _T: ... | ||||||||||
| def popcount(x: _T) -> _T: ... | ||||||||||
| def pow(x: _T, y: _T, span: Span | None=None) -> _T: ... | ||||||||||
| def q_multiply_shift(x: _T, y: _T, q: _T, s: _T) -> _T: ... | ||||||||||
| def q_multiply_shift_per_axis(x: _T, y: _T, ls: _T, rs: _T, q: IntImm, is_lshift_required: IntImm, is_rshift_required: IntImm) -> PrimExpr: ... | ||||||||||
| def ret(val: _T) -> _T: ... | ||||||||||
| def round(x: _T, span: Span | None=None) -> _T: ... | ||||||||||
| def rsqrt(x: _T) -> _T: ... | ||||||||||
| def shift_left(x: _T, y: _T, span=None) -> _T: ... | ||||||||||
| def shift_right(x: _T, y: _T, span=None) -> _T: ... | ||||||||||
|
Comment on lines
+58
to
+59
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add type annotation for span parameter. The Apply this diff: -def shift_left(x: _T, y: _T, span=None) -> _T: ...
-def shift_right(x: _T, y: _T, span=None) -> _T: ...
+def shift_left(x: _T, y: _T, span: Span | None=None) -> _T: ...
+def shift_right(x: _T, y: _T, span: Span | None=None) -> _T: ...📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||
| def sigmoid(x: _T) -> _T: ... | ||||||||||
| def sin(x: _T) -> _T: ... | ||||||||||
| def sinh(x: _T) -> _T: ... | ||||||||||
| def sqrt(x: _T) -> _T: ... | ||||||||||
| def tan(x: _T) -> _T: ... | ||||||||||
| def tanh(x: _T) -> _T: ... | ||||||||||
| def trunc(x: _T, span: Span | None=None) -> _T: ... | ||||||||||
| def truncdiv(a: _T, b: _T, span: Span | None=None) -> _T: ... | ||||||||||
| def truncmod(a: _T, b: _T, span: Span | None=None) -> _T: ... | ||||||||||
| def tvm_access_ptr(ptype: PrimExpr, data, offset: int, extent: int, rw_mask: int) -> PrimExpr: ... | ||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chainAdd type annotation for the The Verify the expected type from the runtime implementation to add the correct annotation (likely 🏁 Script executed: #!/bin/bash
# Find the runtime implementation to determine the correct type for the data parameter
ast-grep --pattern 'def tvm_access_ptr($$$)'Length of output: 1513 Length of output: 15734 🏁 Script executed: #!/bin/bash
# Check for type imports and similar buffer/pointer-like types in the file
rg -n 'Buffer|Var|Handle' tilelang/language/tir/ir.pyi | head -20Length of output: 1180 🏁 Script executed: #!/bin/bash
# Look at other function signatures in the stub file for patterns with similar parameters
rg -n 'def.*data.*:' tilelang/language/tir/ir.pyi -A 1 | head -30Length of output: 395 Add type annotation for the The 🤖 Prompt for AI Agents |
||||||||||
| def tvm_throw_last_error() -> _T: ... | ||||||||||
| def tvm_stack_alloca(dtype_str: str, num: int) -> PrimExpr: ... | ||||||||||
| def tvm_stack_make_shape(*args) -> _T: ... | ||||||||||
| def tvm_stack_make_array(data: PrimExpr, shape: PrimExpr, strides: PrimExpr, ndim: PrimExpr, arr_dtype: PrimExpr, elem_offset) -> PrimExpr: ... | ||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chainAdd type annotation for the The Verify the expected type from the runtime implementation. 🏁 Script executed: #!/bin/bash
# Find the runtime implementation
ast-grep --pattern 'def tvm_stack_make_array($$$)'Length of output: 1753 Length of output: 100 🏁 Script executed: #!/bin/bash
# Search for PrimExpr and Expr type definitions/imports
rg -n "(class PrimExpr|class Expr|from.*import.*PrimExpr|from.*import.*Expr)" tilelang/language/tir/ir.pyi -A2 -B2Length of output: 128 🏁 Script executed: #!/bin/bash
# Check what type annotations are used in similar functions
rg -n "def.*Expr.*Expr.*->" tilelang/language/tir/ir.pyi | head -20Length of output: 1026 Add type annotation The def tvm_stack_make_array(data: PrimExpr, shape: PrimExpr, strides: PrimExpr, ndim: PrimExpr, arr_dtype: PrimExpr, elem_offset: PrimExpr) -> PrimExpr: ...🤖 Prompt for AI Agents |
||||||||||
| def tvm_check_return(expected: int, return_unexpected: int, nested_call: PrimExpr) -> PrimExpr: ... | ||||||||||
| def call_packed(*args, span=None) -> _T: ... | ||||||||||
| def call_cpacked(*args, span=None) -> _T: ... | ||||||||||
| def call_packed_lowered(*args, span=None) -> _T: ... | ||||||||||
| def call_cpacked_lowered(*args, span=None) -> _T: ... | ||||||||||
| def tvm_tuple(*value) -> _T: ... | ||||||||||
| def tvm_struct_set(arr, index: int, field: int, value: PrimExpr) -> PrimExpr: ... | ||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chainAdd type annotation for the The Verify the expected type from the runtime implementation. 🏁 Script executed: #!/bin/bash
# Find the runtime implementation
ast-grep --pattern 'def tvm_struct_set($$$)'Length of output: 1309 Length of output: 5583 🏁 Script executed: #!/bin/bash
# Inspect the top of the stub file for imports and type definitions
head -100 tilelang/language/tir/ir.pyiLength of output: 5583 🏁 Script executed: #!/bin/bash
# Search for StructType or similar types in ir.pyi
rg -n "StructType|class.*Type" tilelang/language/tir/ir.pyi | head -20Length of output: 0 🏁 Script executed: #!/bin/bash
# Check imports at the top of ir.pyi
rg -n "^from|^import" tilelang/language/tir/ir.pyi | head -20Length of output: 104 🏁 Script executed: #!/bin/bash
# Search for StructType definition or usage in codebase
rg -n "StructType" tilelang/language/tir/ --type py --type pyiLength of output: 27 🏁 Script executed: #!/bin/bash
# Examine the implementation context of tvm_struct_set in op.py
sed -n '439,465p' tilelang/language/tir/op.pyLength of output: 556 🏁 Script executed: #!/bin/bash
# Search for struct-related functions in op.py to find usage patterns
rg -n "struct" tilelang/language/tir/op.py -i | head -20Length of output: 1053 🏁 Script executed: #!/bin/bash
# Look for tvm_struct_get (companion function) to see if it has type info
rg -n "tvm_struct_get" tilelang/language/tir/ -A 10Length of output: 2115 Add type annotation for the The def tvm_struct_set(arr: PrimExpr, index: int, field: int, value: PrimExpr) -> PrimExpr: ...🤖 Prompt for AI Agents |
||||||||||
| def tvm_thread_invariant(cond: _T) -> _T: ... | ||||||||||
| def tvm_thread_allreduce(*freduce_args) -> _T: ... | ||||||||||
| def tvm_load_matrix_sync(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, buffer_ptr: PrimExpr, stride: PrimExpr, layout: Literal['row_major', 'column_major']) -> PrimExpr: ... | ||||||||||
| def tvm_mma_sync(fragment_d: Var, index_d: PrimExpr, fragment_a: Var, index_a: PrimExpr, fragment_b: Var, index_b: PrimExpr, fragment_c: Var, index_c: PrimExpr) -> PrimExpr: ... | ||||||||||
| def tvm_bmma_sync(fragment_d: Var, index_d: PrimExpr, fragment_a: Var, index_a: PrimExpr, fragment_b: Var, index_b: PrimExpr, fragment_c: Var, index_c: PrimExpr) -> PrimExpr: ... | ||||||||||
| def tvm_fill_fragment(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, value: PrimExpr) -> PrimExpr: ... | ||||||||||
| def tvm_store_matrix_sync(fragment: Var, m: IntImm, n: IntImm, k: IntImm, index: PrimExpr, buffer_ptr: PrimExpr, stride: PrimExpr, layout: Literal['row_major', 'column_major']) -> PrimExpr: ... | ||||||||||
| def ptx_wait_group(num: int) -> PrimExpr: ... | ||||||||||
| def ptx_commit_group() -> _T: ... | ||||||||||
| def ptx_cp_async_barrier(barrier_id: int) -> PrimExpr: ... | ||||||||||
| def ptx_init_barrier_thread_count(barrier_id: int, thread_count: int) -> PrimExpr: ... | ||||||||||
| def ptx_arrive_barrier(barrier_id: int) -> PrimExpr: ... | ||||||||||
| def ptx_arrive_barrier_expect_tx(barrier_id: int, byte_count: int) -> PrimExpr: ... | ||||||||||
| def ptx_wait_barrier(barrier_id: int) -> PrimExpr: ... | ||||||||||
| def create_barriers(barrier_count: int) -> PrimExpr: ... | ||||||||||
| def assume(cond: _T=None) -> _T: ... | ||||||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Problematic default value with generic TypeVar. Using Consider one of these fixes: -def assume(cond: _T=None) -> _T: ...
+def assume(cond: _T | None=None) -> _T: ...Or use +from typing import Optional
-def assume(cond: _T=None) -> _T: ...
+def assume(cond: Optional[_T]=None) -> _T: ...📝 Committable suggestion
Suggested change
🤖 Prompt for AI Agents |
||||||||||
| def undef() -> _T: ... | ||||||||||
| def TVMBackendAllocWorkspace(device_type: int, device_id: int, nbytes: int, dtype_code_hint: int, dtype_bits_hint: int) -> PrimExpr: ... | ||||||||||
| def TVMBackendFreeWorkspace(device_type: int, device_id: int, ptr: Var) -> PrimExpr: ... | ||||||||||
| def start_profile_intrinsic(id: int) -> PrimExpr: ... | ||||||||||
| def end_profile_intrinsic(id: int) -> PrimExpr: ... | ||||||||||
| def anylist_getitem(list_handle, index) -> PrimExpr: ... | ||||||||||
| def anylist_resetitem(list_handle, index) -> PrimExpr: ... | ||||||||||
| def anylist_setitem_call_packed(list_handle, index, func_name, *args) -> PrimExpr: ... | ||||||||||
| def anylist_setitem_call_cpacked(list_handle, index, func_name, *args) -> PrimExpr: ... | ||||||||||
|
Comment on lines
+102
to
+105
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 🧩 Analysis chainAdd type annotations for parameters in anylist functions. Multiple parameters ( Verify the expected types from the runtime implementation. 🏁 Script executed: #!/bin/bash
# Find the runtime implementations
ast-grep --pattern 'def anylist_$_($$$ )'Length of output: 0 🏁 Script executed: #!/bin/bash
# Search for anylist function definitions in Python files
rg -n 'def anylist_' --type=pyLength of output: 358 🏁 Script executed: #!/bin/bash
# Search for anylist references to understand usage patterns
rg -n 'anylist_' --type=py -C 2 | head -50Length of output: 2975 🏁 Script executed: #!/bin/bash
# Check the exact structure of the .pyi file around the functions
fd -e pyi -name 'ir.pyi' -exec cat {} \;Length of output: 230 🏁 Script executed: #!/bin/bash
# Get full docstrings from op.py implementations
sed -n '3331,3343p' tilelang/language/tir/op.py
sed -n '3345,3357p' tilelang/language/tir/op.py
sed -n '3359,3375p' tilelang/language/tir/op.py
sed -n '3377,3393p' tilelang/language/tir/op.pyLength of output: 1551 🏁 Script executed: #!/bin/bash
# View the .pyi stub file
cat tilelang/language/tir/ir.pyiLength of output: 5950 Add type annotations for parameters in anylist functions. The runtime implementations in def anylist_getitem(list_handle: Var, index: int) -> PrimExpr: ...
def anylist_resetitem(list_handle: Var, index: int) -> PrimExpr: ...
def anylist_setitem_call_packed(list_handle: Var, index: int, func_name: str, *args) -> PrimExpr: ...
def anylist_setitem_call_cpacked(list_handle: Var, index: int, func_name: str, *args) -> PrimExpr: ...🤖 Prompt for AI Agents |
||||||||||
| def vscale() -> _T: ... | ||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Incorrect type annotation for
dtypeparameter.The
dtypeparameter is typed as_T, but based on the parameter name and typical usage, it should bestr.Apply this diff:
📝 Committable suggestion
🤖 Prompt for AI Agents