Skip to content

Commit 837b1e5

Browse files
committed
fix tensor array ops
1 parent 7b458bf commit 837b1e5

File tree

1 file changed

+11
-16
lines changed

1 file changed

+11
-16
lines changed

tests/python/relay/test_adt.py

Lines changed: 11 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
from tvm import relay
1919
from tvm.relay.backend.interpreter import ConstructorValue
2020
from tvm.relay import create_executor
21-
from tvm.relay.prelude import Prelude, TensorArrayOps
21+
from tvm.relay.prelude import Prelude
2222
from tvm.relay.testing import add_nat_definitions, count as count_, make_nat_value, make_nat_expr
2323

2424
import numpy as np
@@ -708,8 +708,7 @@ def run(dtype):
708708
x = relay.var('x')
709709
mod = relay.Module()
710710
p = Prelude(mod)
711-
tensor_array_ops = TensorArrayOps(p, dtype)
712-
tensor_array = tensor_array_ops.get_var('tensor_array')
711+
tensor_array = p.get_var('tensor_array', dtype)
713712
mod["main"] = relay.Function([x], tensor_array(x))
714713
for kind in ["debug"]:
715714
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
@@ -724,11 +723,10 @@ def test_tensor_array_read():
724723
def run(dtype):
725724
mod = relay.Module()
726725
p = Prelude(mod)
727-
tensor_array_ops = TensorArrayOps(p, dtype)
728726
l = relay.var('l')
729727
i = relay.var('i')
730-
read_func = tensor_array_ops.get_var('tensor_array_read')
731-
tensor_array = tensor_array_ops.get_var('tensor_array')
728+
read_func = p.get_var('tensor_array_read', dtype)
729+
tensor_array = p.get_var('tensor_array', dtype)
732730
mod["main"] = relay.Function([l, i], read_func(tensor_array(l), i))
733731
for kind in ["debug"]:
734732
ex = relay.create_executor(kind, mod=mod, ctx=tvm.cpu(), target="llvm")
@@ -770,11 +768,10 @@ def test_tensor_array_stack():
770768
def run(dtype):
771769
mod = relay.Module()
772770
p = Prelude(mod)
773-
tensor_array_ops = TensorArrayOps(p, dtype)
774-
tensor_array = tensor_array_ops.get_var('tensor_array')
775-
tensor1 = tensor_array_ops.get_var('tensor1')
776-
write = tensor_array_ops.get_var('tensor_array_write')
777-
stack = tensor_array_ops.get_var('tensor_array_stack')
771+
tensor_array = p.get_var('tensor_array', dtype)
772+
tensor1 = p.get_var('tensor1', dtype)
773+
write = p.get_var('tensor_array_write', dtype)
774+
stack = p.get_var('tensor_array_stack', dtype)
778775
l = relay.var('l')
779776
v = relay.var('v')
780777
init_tensor_array = tensor_array(relay.const(3))
@@ -797,8 +794,7 @@ def test_tensor_array_unstack():
797794
def run(dtype):
798795
mod = relay.Module()
799796
p = Prelude(mod)
800-
tensor_array_ops = TensorArrayOps(p, dtype)
801-
unstack_tensor1 = tensor_array_ops.get_var('tensor_array_unstack_tensor1')
797+
unstack_tensor1 = p.get_var('tensor_array_unstack_tensor1', dtype)
802798
v = relay.var('v')
803799
mod["main"] = relay.Function([v], unstack_tensor1(v))
804800
for kind in ["debug"]:
@@ -814,9 +810,8 @@ def test_tensor_take():
814810
def run(dtype):
815811
mod = relay.Module()
816812
p = Prelude(mod)
817-
tensor_array_ops = TensorArrayOps(p, dtype)
818-
take = tensor_array_ops.get_var('tensor_take')
819-
tensor2 = tensor_array_ops.get_var('tensor2')
813+
take = p.get_var('tensor_take', dtype)
814+
tensor2 = p.get_var('tensor2', dtype)
820815
v = relay.var('v')
821816
lower = relay.var('lower')
822817
upper = relay.var('upper')

0 commit comments

Comments
 (0)