1919import numpy as np
2020import tvm
2121from tvm .relay import Call
22- from tvm import relax , tir
22+ from tvm import relax , tir , te
2323from tvm .runtime import container
2424import numpy as np
2525
26+ from tvm .ir .base import assert_structural_equal
2627import tvm .script
2728from tvm .script import tir as T , relax as R
2829
@@ -425,6 +426,98 @@ def test_vm_emit_te_extern():
425426 expected = np .dot (data .asnumpy (), weight .asnumpy ())
426427 np .testing .assert_allclose (expected , res .asnumpy (), rtol = 1e-4 , atol = 1e-4 )
427428
429+ def test_vm_emit_te_concat ():
430+ # concatenate of two vectors of size (n,) and (m,)
431+ bb = relax .BlockBuilder ()
432+ n , m = tir .Var ("n" , "int64" ), tir .Var ("m" , "int64" )
433+ type_anno = relax .DynTensorType (1 , "float32" )
434+ x = relax .Var ("x" , [n ], type_anno )
435+ y = relax .Var ("y" , [m ], type_anno )
436+
437+ def te_func (A , B ):
438+ C = te .compute ((n + m ), lambda i : tvm .tir .if_then_else (i < n , A [i ], B [i - n ]))
439+ return C
440+
441+ with bb .function ([x , y ], "rx_func" ):
442+ x1 = bb .emit_te (te_func , x , y )
443+ bb .emit_func_output (x1 )
444+
445+ mod = bb .get ()
446+
447+ target = tvm .target .Target ("llvm" )
448+ target_host = tvm .target .Target ("llvm" )
449+ ex , lib = relax .vm .build (mod , target , target_host )
450+
451+ vm = relax .VirtualMachine (ex , tvm .cpu (), mod = lib )
452+ inp = tvm .nd .array (np .random .rand (1 , ).astype (np .float32 ))
453+ inp2 = tvm .nd .array (np .random .rand (2 , ).astype (np .float32 ))
454+ res = vm ["rx_func" ](inp , inp2 )
455+
456+ np .testing .assert_allclose (res .asnumpy (), np .append (inp .asnumpy (), inp2 .asnumpy ()))
457+
458+ def test_vm_emit_te_floor_symbolic_shape ():
459+ bb = relax .BlockBuilder ()
460+ n = tir .Var ("n" , "int64" )
461+ type_anno = relax .DynTensorType (1 , "float32" )
462+ x = relax .Var ("x" , [n ], type_anno )
463+
464+ def te_func (A ):
465+ C = te .compute ((tir .floordiv (n , 2 ),), lambda i : A [i ] + 1 )
466+ return C
467+
468+ with bb .function ([x ], "rx_func" ):
469+ x1 = bb .emit_te (te_func , x )
470+ bb .emit_func_output (x1 )
471+
472+ mod = bb .get ()
473+
474+ target = tvm .target .Target ("llvm" )
475+ target_host = tvm .target .Target ("llvm" )
476+ ex , lib = relax .vm .build (mod , target , target_host )
477+
478+ vm = relax .VirtualMachine (ex , tvm .cpu (), mod = lib )
479+ shape = (9 , )
480+ inp = tvm .nd .array (np .random .rand (* shape ).astype (np .float32 ))
481+ res = vm ["rx_func" ](inp )
482+
483+ def expected_output ():
484+ output_shape = (shape [0 ] // 2 , )
485+ return inp .asnumpy ()[:output_shape [0 ]] + 1
486+
487+ np .testing .assert_allclose (res .asnumpy (), expected_output ())
488+
489+ def test_vm_relax_symbolic_shape ():
490+ bb = relax .BlockBuilder ()
491+ n = tir .Var ("n" , "int64" )
492+ type_anno = relax .DynTensorType (1 , "float32" )
493+ x = relax .Var ("x" , [n ], type_anno )
494+ y = relax .Var ("y" , [(n // 2 ) + 1 ], type_anno )
495+
496+ def te_func (A , B ):
497+ C = te .compute ((n , ), lambda i : A [i ] + B [i // 2 ])
498+ return C
499+
500+ with bb .function ([x , y ], "rx_func" ):
501+ x1 = bb .emit_te (te_func , x , y )
502+ bb .emit_func_output (x1 )
503+
504+ mod = bb .get ()
505+
506+ target = tvm .target .Target ("llvm" )
507+ target_host = tvm .target .Target ("llvm" )
508+ ex , lib = relax .vm .build (mod , target , target_host )
509+
510+ vm = relax .VirtualMachine (ex , tvm .cpu (), mod = lib )
511+ shape1 = (5 , )
512+ shape2 = (3 , )
513+ inp = tvm .nd .array (np .random .rand (* shape1 ).astype (np .float32 ))
514+ inp2 = tvm .nd .array (np .random .rand (* shape2 ).astype (np .float32 ))
515+ res = vm ["rx_func" ](inp , inp2 )
516+
517+ def expected_output ():
518+ return inp .asnumpy () + np .repeat (inp2 .asnumpy (), 2 )[:5 ]
519+
520+ np .testing .assert_allclose (res .asnumpy (), expected_output ())
428521
429522if __name__ == "__main__" :
430523 test_vm_execute ()
@@ -443,3 +536,6 @@ def test_vm_emit_te_extern():
443536 test_vm_compile_e2e ()
444537 test_vm_compile_e2e_func_param_with_shape ()
445538 test_vm_emit_te_extern ()
539+ test_vm_emit_te_concat ()
540+ test_vm_emit_te_floor_symbolic_shape ()
541+ test_vm_relax_symbolic_shape ()
0 commit comments