Skip to content

Commit 5a53d0f

Browse files
junrushaoylc
authored andcommitted
[TensorIR][Minor] Allow Tuple/Array in TE lowering (apache#8916)
1 parent c14b853 commit 5a53d0f

File tree

1 file changed

+4
-3
lines changed

1 file changed

+4
-3
lines changed

python/tvm/te/operation.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,13 @@
2222
import tvm._ffi
2323
import tvm.tir
2424
import tvm.tir._ffi_api
25-
2625
from tvm._ffi.base import string_types
26+
from tvm.ir import Array
2727
from tvm.runtime import convert
2828

29+
from . import _ffi_api
2930
from . import tag as _tag
3031
from . import tensor as _tensor
31-
from . import _ffi_api
3232

3333

3434
def placeholder(shape, dtype=None, name="placeholder"):
@@ -431,6 +431,7 @@ def reduce_axis(dom, name="rv", thread_tag="", span=None):
431431

432432
def create_prim_func(ops: List[_tensor.Tensor]) -> tvm.tir.PrimFunc:
433433
"""Create a TensorIR PrimFunc from tensor expression
434+
434435
Parameters
435436
----------
436437
ops : List[Tensor]
@@ -473,6 +474,6 @@ def tir_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
473474
func : tir.PrimFunc
474475
The created function.
475476
"""
476-
if not isinstance(ops, list):
477+
if not isinstance(ops, (list, tuple, Array)):
477478
ops = [ops]
478479
return _ffi_api.CreatePrimFunc(ops)

0 commit comments

Comments
 (0)