Skip to content

Commit

Permalink
[Fix] Refactor the roundtrip test. (apache#10592)
Browse files Browse the repository at this point in the history
This is a tiny fix on the roundtrip test, the case test I introduced in apache#10370 doesn't use `tvm.testing.parameter`.
  • Loading branch information
yzh119 authored Mar 13, 2022
1 parent ce2f81a commit 5775f64
Showing 1 changed file with 26 additions and 26 deletions.
52 changes: 26 additions & 26 deletions tests/python/unittest/test_tvmscript_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -3177,6 +3177,31 @@ def ctpop(A: T.Buffer[(16,), "uint8"], B: T.Buffer[(16,), "uint8"]) -> None:
return ctpop


def parse_bufferslice_as_range_bound():
@T.prim_func
def segment_sum(
A_ptr: T.handle, B_ptr: T.handle, indptr_ptr: T.handle, n: T.int32, m: T.int32
) -> None:
A = T.match_buffer(A_ptr, [m], dtype="float32")
B = T.match_buffer(B_ptr, [n], dtype="float32")
indptr = T.match_buffer(indptr_ptr, [n + 1], dtype="int32")
for i in T.serial(n):
with T.block("outer"):
vi = T.axis.spatial(n, i)
T.reads(indptr[i : i + 2], B[vi], A[indptr[i] : indptr[i + 1]])
T.writes(B[vi])
for j in T.serial(indptr[i], indptr[i + 1]):
with T.block("inner"):
vj = T.axis.reduce(m, j)
T.reads(B[vi], A[vj])
T.writes(B[vi])
with T.init():
B[vi] = T.float32(0)
B[vi] = B[vi] + A[vj]

return segment_sum


ir_generator = tvm.testing.parameter(
opt_gemm_normalize,
opt_gemm_lower,
Expand Down Expand Up @@ -3208,6 +3233,7 @@ def ctpop(A: T.Buffer[(16,), "uint8"], B: T.Buffer[(16,), "uint8"]) -> None:
func_T_ptr_let_statement,
func_T_ptr_allocate,
llvm_intrin_call,
parse_bufferslice_as_range_bound,
)


Expand All @@ -3217,31 +3243,5 @@ def test_roundtrip(ir_generator):
tvm.ir.assert_structural_equal(original, after_roundtrip, True)


@T.prim_func
def segment_sum(
A_ptr: T.handle, B_ptr: T.handle, indptr_ptr: T.handle, n: T.int32, m: T.int32
) -> None:
A = T.match_buffer(A_ptr, [m], dtype="float32")
B = T.match_buffer(B_ptr, [n], dtype="float32")
indptr = T.match_buffer(indptr_ptr, [n + 1], dtype="int32")
for i in T.serial(n):
with T.block("outer"):
vi = T.axis.spatial(n, i)
T.reads(indptr[i : i + 2], B[vi], A[indptr[i] : indptr[i + 1]])
T.writes(B[vi])
for j in T.serial(indptr[i], indptr[i + 1]):
with T.block("inner"):
vj = T.axis.reduce(m, j)
T.reads(B[vi], A[vj])
T.writes(B[vi])
with T.init():
B[vi] = T.float32(0)
B[vi] = B[vi] + A[vj]


def test_parse_bufferslice_as_range_bound():
tvm.ir.assert_structural_equal(segment_sum, tvm.script.from_source(segment_sum.script()))


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))

0 comments on commit 5775f64

Please sign in to comment.