Skip to content

Commit

Permalink
[ManifestAlloc] Handle TupleType inputs in CheckReshapeOnly (apache#6776
Browse files Browse the repository at this point in the history
)

* Changes in CheckReshapeOnly to support TupleTypes as input

This arises insed ManifestAllocPass inside relay.vm.compile

* [ManifestAlloc] Handle TupleType inputs in CheckReshapeOnly
  • Loading branch information
rohanmukh authored and Trevor Morris committed Oct 28, 2020
1 parent 4f90394 commit 7e2daaf
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 0 deletions.
5 changes: 5 additions & 0 deletions python/tvm/relay/transform/memory_alloc.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,11 @@ def visit_call(self, call):
for arg in call.args:
self.visit(arg)

def visit_var(self, var):
var_type = var.checked_type
if not isinstance(var_type, ty.TensorType):
self.reshape_only = False


def is_reshape_only(func):
"""Check if the primitive function contains only reshape ops."""
Expand Down
16 changes: 16 additions & 0 deletions tests/python/relay/test_vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,5 +754,21 @@ def test_vm_reshape_tensor():
check_result([x_np, y_np], x_np.reshape([8, 2, 8]), mod)


def test_vm_reshape_tuple(x_shape=(1, 4, 2), y_shape=(1, 2, 10)):
tup = relay.var(
"tup",
type_annotation=relay.TupleType([relay.TensorType(x_shape), relay.TensorType(y_shape)]),
)
out = relay.reshape(relay.TupleGetItem(tup, 0), (1, -1))
f = relay.Function([tup], out)

x_data = np.random.uniform(size=x_shape).astype("float32")
y_data = np.random.uniform(size=y_shape).astype("float32")

for tgt, ctx in tvm.testing.enabled_targets():
res = veval(f, (x_data, y_data), ctx=ctx, target=tgt)
tvm.testing.assert_allclose(res.asnumpy(), np.reshape(x_data, (1, -1)))


if __name__ == "__main__":
pytest.main([__file__])

0 comments on commit 7e2daaf

Please sign in to comment.