Skip to content

Commit 49de0f2

Browse files
committed
* test cases
1 parent 1c4f076 commit 49de0f2

File tree

2 files changed: +18 additions, −30 deletions

python/tvm/relay/frontend/onnx.py

Lines changed: 2 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ def _impl_v1(cls, inputs, attr, params):
344344
m.run()
345345
params_new = m.get_output(0)
346346
inputs.pop(1)
347-
out = _op.reshape(inputs[0], tuple(params_new.asnumpy().flatten()))
347+
out = _op.reshape(inputs[0], tuple(params_new.asnumpy().astype('int32').flatten()))
348348

349349
return out
350350

@@ -483,20 +483,7 @@ class Shape(OnnxOpConverter):
483483

484484
@classmethod
485485
def _impl_v1(cls, inputs, attr, params):
486-
from topi.util import get_const_tuple
487-
try:
488-
out_type = ir_pass.infer_type(inputs[0])
489-
out_shape = get_const_tuple(out_type.checked_type.shape)
490-
except ValueError as e:
491-
raise ImportError(
492-
"Please pass graph level shapes to compute shape node properly {}".format(e))
493-
494-
node_name = attr['tvm_custom']['name']
495-
params[node_name] = _nd.array(np.asarray(out_shape, dtype='int64'))
496-
497-
return _expr.var(node_name,
498-
shape=params[node_name].shape,
499-
dtype=params[node_name].dtype)
486+
return _op.shape_of(inputs[0])
500487

501488
class Cast(OnnxOpConverter):
502489
""" Operator converter for Cast.

tests/python/frontend/onnx/test_forward.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -113,35 +113,36 @@ def test_reshape():
113113

114114
tvm.testing.assert_allclose(ref_shape, tvm_out.shape)
115115

116-
def test_reshape_like():
116+
def test_shape():
117117
in_shape = (4, 3, 3, 4)
118-
ref_shape = (3, 4, 4, 3)
118+
ref_shape = (6, 2, 4, 3)
119119

120-
ref_array = np.random.uniform(size=ref_shape).astype('float32')
120+
ref_array = np.array(ref_shape)
121121
ref_node = onnx.helper.make_node('Constant',
122122
inputs=[],
123123
outputs=['ref_in'],
124124
value=onnx.helper.make_tensor(name = 'const_tensor',
125-
data_type = onnx.TensorProto.FLOAT,
125+
data_type = onnx.TensorProto.INT32,
126126
dims = ref_array.shape,
127-
vals = ref_array.flatten().astype(float)))
128-
copy_node = helper.make_node("Identity", ["ref_in"], ["copy_in"])
129-
reshape_node = helper.make_node("Reshape", ["in", "copy_in"], ["out"])
127+
vals = ref_array.flatten().astype(int)))
128+
reshape_node = helper.make_node("Reshape", ["in", "ref_in"], ["out"])
129+
130+
shape_node = helper.make_node("Shape", ['out'], ['final_out'])
130131

131-
graph = helper.make_graph([ref_node, copy_node, reshape_node],
132-
"reshape_like_test",
132+
graph = helper.make_graph([ref_node, reshape_node, shape_node],
133+
"shape_test",
133134
inputs = [helper.make_tensor_value_info("in",
134135
TensorProto.FLOAT, list(in_shape))],
135-
outputs = [helper.make_tensor_value_info("out",
136+
outputs = [helper.make_tensor_value_info("final_out",
136137
TensorProto.FLOAT, list(ref_shape))])
137138

138-
model = helper.make_model(graph, producer_name='reshape_like_test')
139+
model = helper.make_model(graph, producer_name='shape_test')
139140

140141
for target, ctx in ctx_list():
141-
x = np.random.uniform(size=in_shape).astype('float32')
142-
tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'float32')
142+
x = np.random.uniform(size=in_shape).astype('int32')
143+
tvm_out = get_tvm_output(model, x, target, ctx, ref_shape, 'int32')
143144

144-
tvm.testing.assert_allclose(ref_shape, tvm_out.shape)
145+
tvm.testing.assert_allclose(ref_shape, tvm_out)
145146

146147
def _test_power_iteration(x_shape, y_shape):
147148
if isinstance(y_shape, int):
@@ -995,7 +996,7 @@ def test_LogSoftmax():
995996

996997
if __name__ == '__main__':
997998
test_reshape()
998-
test_reshape_like()
999+
test_shape()
9991000
test_power()
10001001
test_squeeze()
10011002
test_unsqueeze()

0 commit comments

Comments (0)