Skip to content

Commit 304aa1e

Browse files
author
Krzysztof Parzyszek
authored
[TIR] Allow starred expressions in TIR script (#15404)
Small change in the evaluator to allow it to handle starred expressions (i.e. list/tuple splicing).
1 parent 9ff74fb commit 304aa1e

File tree

2 files changed

+29
-0
lines changed

2 files changed

+29
-0
lines changed

python/tvm/script/parser/core/evaluator.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,17 @@ def _visit(self, node: doc.AST) -> Any:
221221
return node
222222
if isinstance(node, doc.Lambda):
223223
return self._eval_lambda(node)
224+
if isinstance(node, doc.Starred):
225+
value = self._visit(node.value)
226+
return doc.Starred(
227+
value=value,
228+
ctx=node.ctx,
229+
lineno=node.lineno,
230+
col_offset=node.col_offset,
231+
end_lineno=node.end_lineno,
232+
end_col_offset=node.end_col_offset,
233+
)
234+
224235
fields = {}
225236
for field in node.__class__._FIELDS: # pylint: disable=protected-access
226237
attr = getattr(node, field)

tests/python/unittest/test_tvmscript_parser_tir.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,5 +212,23 @@ def expected_non_hygienic(A: T.Buffer((1024,), "int32"), B: T.Buffer((), "int32"
212212
tvm.ir.assert_structural_equal(use_non_hygienic, expected_non_hygienic)
213213

214214

215+
def test_tir_starred_expression():
216+
dims = (128, 128)
217+
218+
@T.prim_func(private=True)
219+
def starred(a: T.handle) -> None:
220+
A = T.match_buffer(a, [128, *dims], "int32")
221+
for i, j, k in T.grid(128, *dims):
222+
A[i, j, k] = T.int32(1)
223+
224+
@T.prim_func(private=True)
225+
def non_starred(a: T.handle) -> None:
226+
A = T.match_buffer(a, [128, 128, 128], "int32")
227+
for i, j, k in T.grid(128, 128, 128):
228+
A[i, j, k] = T.int32(1)
229+
230+
tvm.ir.assert_structural_equal(starred, non_starred)
231+
232+
215233
if __name__ == "__main__":
216234
tvm.testing.main()

0 commit comments

Comments
 (0)