@@ -54,6 +54,7 @@ def check_result(
54
54
for kind in ["debug" , "vm" ]:
55
55
targets = targets or tvm .testing .enabled_targets ()
56
56
for tgt , ctx in targets :
57
+ print (tgt )
57
58
if disable_targets and tgt in disable_targets :
58
59
continue
59
60
if kind == "debug" and (only_vm or ctx .device_type != tvm .cpu ().device_type ):
@@ -199,6 +200,15 @@ def test_any_concat():
199
200
ref = np .concatenate ([x_np - 3.0 , y_np * 5.0 ], axis = 0 )
200
201
check_result ([x_np , y_np ], mod , ref )
201
202
203
+ num_inputs = 25
204
+ x = [relay .var ("x" , shape = (relay .Any (),), dtype = "float32" ) for _ in range (num_inputs )]
205
+ z = relay .op .concatenate (x , axis = 0 )
206
+ mod = tvm .IRModule ()
207
+ mod ["main" ] = relay .Function (x , z )
208
+ x_np = [np .random .uniform (size = (1 ,)).astype ("float32" ) for _ in range (num_inputs )]
209
+ ref = np .concatenate (x_np , axis = 0 )
210
+ check_result (x_np , mod , ref )
211
+
202
212
203
213
def verify_any_reshape (x_shape , newshape , x_np_shape , out_shape , variable_newshape = False ):
204
214
x = relay .var ("x" , shape = x_shape , dtype = "float32" )
@@ -572,9 +582,7 @@ def verify_any_conv2d_transpose_nchw(
572
582
mod ["main" ] = relay .Function ([data , kernel ], y )
573
583
data_np = np .random .uniform (size = static_data_shape ).astype (dtype )
574
584
kernel_np = np .random .uniform (size = kernel_shape ).astype (dtype )
575
- check_result (
576
- [data_np , kernel_np ], mod , ref_out_shape , assert_shape = True , targets = [("llvm" , tvm .cpu ())]
577
- )
585
+ check_result ([data_np , kernel_np ], mod , ref_out_shape , assert_shape = True )
578
586
579
587
580
588
# TODO(@kevinthesun): Support dynamic input height and width.
@@ -1430,6 +1438,21 @@ def test_non_max_suppression():
1430
1438
disable_targets = ["nvptx" ],
1431
1439
)
1432
1440
1441
+ np_data = np .zeros ((1 , 0 , 6 )).astype ("float32" )
1442
+ np_valid_count = np .array ([0 ]).astype ("int32" )
1443
+ np_indices = np .zeros ((1 , 0 )).astype ("int32" )
1444
+ np_max_output_size = - 1
1445
+ np_indices_result = np .zeros ((1 , 0 ))
1446
+ np_valid_box_count = np .array ([[0 ]]).astype ("int32" )
1447
+
1448
+ check_result (
1449
+ [np_data , np_valid_count , np_indices , np_max_output_size ],
1450
+ mod ,
1451
+ [np_indices_result , np_valid_box_count ],
1452
+ only_vm = False ,
1453
+ disable_targets = ["nvptx" ],
1454
+ )
1455
+
1433
1456
1434
1457
if __name__ == "__main__" :
1435
1458
pytest .main ([__file__ ])
0 commit comments