@@ -113,35 +113,36 @@ def test_reshape():
113113
114114 tvm .testing .assert_allclose (ref_shape , tvm_out .shape )
115115
116- def test_reshape_like ():
116+ def test_shape ():
117117 in_shape = (4 , 3 , 3 , 4 )
118- ref_shape = (3 , 4 , 4 , 3 )
118+ ref_shape = (6 , 2 , 4 , 3 )
119119
120- ref_array = np .random . uniform ( size = ref_shape ). astype ( 'float32' )
120+ ref_array = np .array ( ref_shape )
121121 ref_node = onnx .helper .make_node ('Constant' ,
122122 inputs = [],
123123 outputs = ['ref_in' ],
124124 value = onnx .helper .make_tensor (name = 'const_tensor' ,
125- data_type = onnx .TensorProto .FLOAT ,
125+ data_type = onnx .TensorProto .INT32 ,
126126 dims = ref_array .shape ,
127- vals = ref_array .flatten ().astype (float )))
128- copy_node = helper .make_node ("Identity" , ["ref_in" ], ["copy_in" ])
129- reshape_node = helper .make_node ("Reshape" , ["in" , "copy_in" ], ["out" ])
127+ vals = ref_array .flatten ().astype (int )))
128+ reshape_node = helper .make_node ("Reshape" , ["in" , "ref_in" ], ["out" ])
129+
130+ shape_node = helper .make_node ("Shape" , ['out' ], ['final_out' ])
130131
131- graph = helper .make_graph ([ref_node , copy_node , reshape_node ],
132- "reshape_like_test " ,
132+ graph = helper .make_graph ([ref_node , reshape_node , shape_node ],
133+ "shape_test " ,
133134 inputs = [helper .make_tensor_value_info ("in" ,
134135 TensorProto .FLOAT , list (in_shape ))],
135- outputs = [helper .make_tensor_value_info ("out " ,
136+ outputs = [helper .make_tensor_value_info ("final_out " ,
136137 TensorProto .FLOAT , list (ref_shape ))])
137138
138- model = helper .make_model (graph , producer_name = 'reshape_like_test ' )
139+ model = helper .make_model (graph , producer_name = 'shape_test ' )
139140
140141 for target , ctx in ctx_list ():
141- x = np .random .uniform (size = in_shape ).astype ('float32 ' )
142- tvm_out = get_tvm_output (model , x , target , ctx , ref_shape , 'float32 ' )
142+ x = np .random .uniform (size = in_shape ).astype ('int32 ' )
143+ tvm_out = get_tvm_output (model , x , target , ctx , ref_shape , 'int32 ' )
143144
144- tvm .testing .assert_allclose (ref_shape , tvm_out . shape )
145+ tvm .testing .assert_allclose (ref_shape , tvm_out )
145146
146147def _test_power_iteration (x_shape , y_shape ):
147148 if isinstance (y_shape , int ):
@@ -995,7 +996,7 @@ def test_LogSoftmax():
995996
996997if __name__ == '__main__' :
997998 test_reshape ()
998- test_reshape_like ()
999+ test_shape ()
9991000 test_power ()
10001001 test_squeeze ()
10011002 test_unsqueeze ()
0 commit comments