
Commit 4c4888b

[ManifestAlloc] Handle TupleType inputs in CheckReshapeOnly (#6776)
* Change CheckReshapeOnly to support TupleType inputs. This arises in the ManifestAlloc pass inside relay.vm.compile.
* [ManifestAlloc] Handle TupleType inputs in CheckReshapeOnly
1 parent ad92efd commit 4c4888b
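
Per the commit message, the problem surfaces when relay.vm.compile runs the ManifestAlloc pass over a fused function whose parameter is a tuple. Below is a minimal sketch of that scenario, mirroring the new test added in this commit; the target choice and the assumption that fusion produces a tuple-parameter primitive function here are illustrative, not taken from this commit.

# Sketch of the scenario from the commit message: a reshape fed by a
# tuple-typed parameter, compiled through the Relay VM (which runs ManifestAlloc).
import tvm
from tvm import relay

x_shape, y_shape = (1, 4, 2), (1, 2, 10)
tup = relay.var(
    "tup",
    type_annotation=relay.TupleType([relay.TensorType(x_shape), relay.TensorType(y_shape)]),
)
body = relay.reshape(relay.TupleGetItem(tup, 0), (1, -1))
mod = tvm.IRModule.from_expr(relay.Function([tup], body))

# With this fix, CheckReshapeOnly sees the TupleType parameter and marks the
# function as not reshape-only instead of assuming a TensorType input.
exe = relay.vm.compile(mod, target="llvm")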

2 files changed: +21 −0 lines changed

python/tvm/relay/transform/memory_alloc.py

Lines changed: 5 additions & 0 deletions
@@ -84,6 +84,11 @@ def visit_call(self, call):
         for arg in call.args:
             self.visit(arg)
 
+    def visit_var(self, var):
+        var_type = var.checked_type
+        if not isinstance(var_type, ty.TensorType):
+            self.reshape_only = False
+
 
 def is_reshape_only(func):
     """Check if the primitive function contains only reshape ops."""

tests/python/relay/test_vm.py

Lines changed: 16 additions & 0 deletions
@@ -754,5 +754,21 @@ def test_vm_reshape_tensor():
         check_result([x_np, y_np], x_np.reshape([8, 2, 8]), mod)
 
 
+def test_vm_reshape_tuple(x_shape=(1, 4, 2), y_shape=(1, 2, 10)):
+    tup = relay.var(
+        "tup",
+        type_annotation=relay.TupleType([relay.TensorType(x_shape), relay.TensorType(y_shape)]),
+    )
+    out = relay.reshape(relay.TupleGetItem(tup, 0), (1, -1))
+    f = relay.Function([tup], out)
+
+    x_data = np.random.uniform(size=x_shape).astype("float32")
+    y_data = np.random.uniform(size=y_shape).astype("float32")
+
+    for tgt, ctx in tvm.testing.enabled_targets():
+        res = veval(f, (x_data, y_data), ctx=ctx, target=tgt)
+        tvm.testing.assert_allclose(res.asnumpy(), np.reshape(x_data, (1, -1)))
+
+
 if __name__ == "__main__":
     pytest.main([__file__])
