@@ -631,6 +631,40 @@ def get_graph(x_shape, new_shape):
     run_and_verify_func(get_graph((1, 1, 2, 3), (1, 6)))


+def test_dynamic_reshape():
+    if skip_codegen_test():
+        return
+
+    def test_run(batches_to_test, x_shape, new_shape):
+        x_data = np.ones([max(batches_to_test)] + list(x_shape)[1:]).astype("float32")
+        result_arr = [{} for _ in range(len(batches_to_test))]
+        for use_trt in [True]:
+            x = relay.var("x", shape=x_shape, dtype="float32")
+            out = relay.reshape(x, new_shape)
+            f = relay.Function([x], out)
+            mod = tvm.IRModule()
+            mod["main"] = f
+            if use_trt:
+                mod, _ = tensorrt.partition_for_tensorrt(mod, params={})
+            print(mod)
+            if not skip_runtime_test():
+                with relay.build_config(opt_level=3):
+                    relay_exec = relay.create_executor("vm", mod=mod, ctx=tvm.cpu(0), target="llvm")
+
+                for i, batch_size in enumerate(batches_to_test):
+                    result_arr[i][use_trt] = relay_exec.evaluate()(x_data[:batch_size, ...])
+                    print(x_data[:batch_size, ...].shape, result_arr[i][use_trt].shape)
+
+        if not skip_runtime_test():
+            for i in range(len(batches_to_test)):
+                assert_result_dict_holds(result_arr[i])
+
+    batches_to_test = [1, 1, 0, 2, 3, 0, 1, 3, 2]
+    x_shape = (relay.Any(), 3, 2, 3)
+    new_shape = (-1, 1, 2, 3)
+    test_run(batches_to_test, x_shape, new_shape)
+
+
 def test_transpose():
     def get_graph(x_shape, order):
         x = relay.var("x", shape=(x_shape), dtype="float32")