Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,8 @@ def match_buffer(
raise ValueError("Shape must be specified when binding input param")
shape = (shape,) if isinstance(shape, (PrimExpr, Integral)) else shape
if strides is not None:
strides = [Var(s, "int32") if isinstance(s, str) else s for s in strides]
idx_dtype = shape[0].dtype if isinstance(shape[0], PrimExpr) else "int32"
strides = [Var(s, idx_dtype) if isinstance(s, str) else s for s in strides]
else:
strides = []
return _ffi_api.MatchBuffer( # type: ignore[attr-defined] # pylint: disable=no-member
Expand Down
1 change: 0 additions & 1 deletion python/tvm/tir/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1062,7 +1062,6 @@ def ptx_mma(
saturate : bool
The optional saturation at the output.


operator : Optional[Literal["xor", "and"]]
The 1-bit operator.

Expand Down
4 changes: 3 additions & 1 deletion src/tir/ir/expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,9 @@ Ramp::Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span) {
ICHECK(base.dtype().is_scalar());
ICHECK(stride.dtype().is_scalar());
ICHECK_GT(lanes, 1);
ICHECK_EQ(stride.dtype(), base.dtype());
if (stride.dtype() != base.dtype()) {
stride = cast(base.dtype(), stride);
}

ObjectPtr<RampNode> node = make_object<RampNode>();
node->dtype = base.dtype().with_lanes(lanes);
Expand Down
26 changes: 26 additions & 0 deletions tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3331,6 +3331,14 @@ def buffer_ramp_access(a: T.handle, b: T.handle, c: T.handle) -> None:
return buffer_ramp_access


def ramp_int64():
@T.prim_func
def func() -> None:
T.evaluate(T.Ramp(T.int64(0), 1, 3))

return func


def let_expression():
@T.prim_func
def func():
Expand All @@ -3346,6 +3354,7 @@ def test_void_ptr_vs_handle():
In the future, perhaps these should be de-duplicated by forbidding
one of the two C++ representations.
"""

# Generates PointerType(PrimType(DataType::Void()))
@T.prim_func
def void_ptr(out_ret_value: T.handle("void")):
Expand Down Expand Up @@ -3622,6 +3631,21 @@ def main(a: T.handle, b: T.handle):
return main


def string_stride_int64():
@T.prim_func
def main(a: T.handle, b: T.handle):
T.func_attr({"from_legacy_te_schedule": True, "global_symbol": "main", "tir.noalias": True})
n = T.int64()
A_s0 = T.int64()
B_s0 = T.int64()
A = T.match_buffer(a, (n,), strides=(A_s0,), buffer_type="auto")
B = T.match_buffer(b, (n,), strides=(B_s0,), buffer_type="auto")
for i in range(n):
B[i] = A[i]

return main


def merge_shape_var_def():
@T.prim_func
def main(A: T.handle, B: T.handle):
Expand Down Expand Up @@ -4013,6 +4037,7 @@ def func():
pointer_type,
buffer_axis_separator,
buffer_ramp_access_as_slice_index,
ramp_int64,
let_expression,
void_ptr,
decl_buffer,
Expand All @@ -4035,6 +4060,7 @@ def func():
let_stmt_var,
let_stmt_value,
string_stride,
string_stride_int64,
merge_shape_var_def,
if_then_else_var,
tvm_shfl_builtins,
Expand Down