2020from tvm import tir , te
2121from tvm import relay
2222from tvm import relax as rx
23- import numpy as np
2423
2524from tvm .ir .base import assert_structural_equal
2625from tvm .relax import op
@@ -61,7 +60,7 @@ def test_function_single_block():
6160 y = rx .Var ("y" , [n ], dtype1 )
6261 ib = rx .BlockBuilder ()
6362
64- with ib .function ([x , y ]):
63+ with ib .function ([x , y ], "func" ):
6564 with ib .dataflow () as df :
6665 lv0 = ib .emit (rx .op .add (x , y ))
6766 assert lv0 .name_hint == "lv"
@@ -71,7 +70,7 @@ def test_function_single_block():
7170 assert gv0 .name_hint == "gv"
7271 ib .emit_func_output (gv0 )
7372
74- func = ib .get ()
73+ func = ib .get ()[ "func" ]
7574 assert func .params [0 ] == x
7675 assert func .params [1 ] == y
7776 assert func .body .body == gv0
@@ -106,7 +105,7 @@ def test_function_multi_blocks():
106105 gv2 = ib .emit_output (gv1 )
107106 ib .emit_func_output (gv2 )
108107
109- func = ib .get ()
108+ func = ib .get ()[ "func" ]
110109 assert gv2 .shape [0 ] == m
111110 assert gv2 .shape [1 ] == n
112111 assert gv2 .checked_type .rank == 2
@@ -121,6 +120,40 @@ def test_function_multi_blocks():
121120 assert len (func .body .blocks [2 ].bindings ) == 2
122121
123122
123+ def test_multi_functions ():
124+ m = tir .Var ("m" , "int32" )
125+ n = tir .Var ("n" , "int32" )
126+ dtype0 = rx .DynTensorType (rank = 2 , dtype = "float16" )
127+ dtype1 = rx .DynTensorType (rank = 1 , dtype = "float16" )
128+ x = rx .Var ("x" , [m , n ], dtype0 )
129+ y = rx .Var ("y" , [n ], dtype1 )
130+ ib = rx .BlockBuilder ()
131+
132+ with ib .function ([x , y ], "func1" ):
133+ with ib .dataflow () as df :
134+ lv0 = ib .emit (rx .op .add (x , y ))
135+ assert lv0 .name_hint == "lv"
136+ gv0 = ib .emit_output (lv0 )
137+ ib .emit_func_output (gv0 )
138+
139+ with ib .function ([x , y ], "func2" ):
140+ with ib .dataflow () as df :
141+ lv0 = ib .emit (rx .op .add (x , y ))
142+ assert lv0 .name_hint == "lv"
143+ gv0 = ib .emit_output (lv0 )
144+ ib .emit_func_output (gv0 )
145+
146+ mod = ib .get ()
147+ func1 = mod ["func1" ]
148+ assert func1 .params [0 ] == x
149+ assert func1 .params [1 ] == y
150+ assert func1 .name .name_hint == "func1"
151+ func2 = mod ["func2" ]
152+ assert func2 .params [0 ] == x
153+ assert func2 .params [1 ] == y
154+ assert func2 .name .name_hint == "func2"
155+
156+
124157def test_binary_shape_type_deduction ():
125158 m = tir .Var ("m" , "int32" )
126159 n = tir .Var ("n" , "int32" )
@@ -177,7 +210,7 @@ def test_emit_match_shape():
177210 y = rx .Var ("shape_value" , type_annotation = rx .ShapeType (), shape_annotation = shape_anno )
178211 ib = rx .BlockBuilder ()
179212
180- with ib .function ([x , y ]):
213+ with ib .function ([x , y ], "func" ):
181214 with ib .dataflow () as df :
182215 # lv0: Tensor[(m, n), "float32"] =
183216 # match_shape(x: Tensor[_, "float32"], [m, n])
@@ -194,7 +227,7 @@ def test_emit_match_shape():
194227 gv0 = ib .emit_output (lv1 )
195228
196229 ib .emit_func_output (gv0 )
197- func = ib .get ()
230+ func = ib .get ()[ "func" ]
198231 block = func .body .blocks [0 ]
199232 b0 , b1 = block .bindings [:2 ]
200233 assert isinstance (b0 , rx .MatchShape )
@@ -248,11 +281,8 @@ def te_func(args, args_dict, msg):
248281 out = bb .emit_te (te_func , [x , y ], {"C" : z }, msg = "hello" )
249282 bb .emit_func_output (out )
250283
251- func = bb .get ()
252- mod = bb .context_mod ()
253-
254- gvar = tvm .relay .GlobalVar ("rx_func" )
255- mod [gvar ] = func
284+ mod = bb .get ()
285+ rx_func = mod ["rx_func" ]
256286
257287 def get_tir_func ():
258288 A = te .placeholder ((n , m ), dtype = "float32" , name = "A" )
@@ -265,20 +295,20 @@ def get_tir_func():
265295 assert_structural_equal (mod ["te_func" ].body , get_tir_func ().body )
266296
267297 # check Relax function calls TIR function with call_dps call
268- assert func .params [0 ] == x
269- assert func .params [1 ] == y
270- assert func .params [2 ] == z
271- assert func .name .name_hint == "rx_func"
272- assert func .body .body == out
273- assert len (func .body .blocks ) == 1
274- assert len (func .body .blocks [0 ].bindings ) == 1
275- assert isinstance (func .body .blocks [0 ].bindings [0 ].value , rx .Call )
276- assert func .body .blocks [0 ].bindings [0 ].value .op == relay .op .get ("relax.call_dps" )
277- assert len (func .body .blocks [0 ].bindings [0 ].value .args ) == 3
278- assert func .body .blocks [0 ].bindings [0 ].value .args [1 ].name_hint == "te_func"
279- assert func .body .blocks [0 ].bindings [0 ].value .args [2 ][0 ] == x
280- assert func .body .blocks [0 ].bindings [0 ].value .args [2 ][1 ] == y
281- assert func .body .blocks [0 ].bindings [0 ].value .args [2 ][2 ] == z
298+ assert rx_func .params [0 ] == x
299+ assert rx_func .params [1 ] == y
300+ assert rx_func .params [2 ] == z
301+ assert rx_func .name .name_hint == "rx_func"
302+ assert rx_func .body .body == out
303+ assert len (rx_func .body .blocks ) == 1
304+ assert len (rx_func .body .blocks [0 ].bindings ) == 1
305+ assert isinstance (rx_func .body .blocks [0 ].bindings [0 ].value , rx .Call )
306+ assert rx_func .body .blocks [0 ].bindings [0 ].value .op == relay .op .get ("relax.call_dps" )
307+ assert len (rx_func .body .blocks [0 ].bindings [0 ].value .args ) == 3
308+ assert rx_func .body .blocks [0 ].bindings [0 ].value .args [1 ].name_hint == "te_func"
309+ assert rx_func .body .blocks [0 ].bindings [0 ].value .args [2 ][0 ] == x
310+ assert rx_func .body .blocks [0 ].bindings [0 ].value .args [2 ][1 ] == y
311+ assert rx_func .body .blocks [0 ].bindings [0 ].value .args [2 ][2 ] == z
282312
283313
284314def test_emit_te_multiple ():
@@ -297,16 +327,45 @@ def te_func(A):
297327 y1 = bb .emit_te (te_func , y )
298328 bb .emit_func_output (y1 )
299329
300- func = bb .get ()
330+ func = bb .get ()[ "rx_func" ]
301331 assert func .body .blocks [0 ].bindings [0 ].value .args [1 ].name_hint == "te_func"
302332 assert func .body .blocks [0 ].bindings [1 ].value .args [1 ].name_hint == "te_func1"
303333
334+
335+ def test_emit_te_extern ():
336+ bb = rx .BlockBuilder ()
337+ n , m = tir .Var ("n" , "int64" ), tir .Var ("m" , "int64" )
338+ type_anno = rx .DynTensorType (2 , "float32" )
339+ x = rx .Var ("x" , [n , m ], type_anno )
340+ y = rx .Var ("y" , [m , n ], type_anno )
341+
342+ with bb .function ([x , y ], "rx_cblas_matmul" ):
343+ out = bb .emit_te (tvm .contrib .cblas .matmul , x , y , transa = False , transb = False )
344+ bb .emit_func_output (out )
345+
346+ mod = bb .get ()
347+ rx_func = mod ["rx_cblas_matmul" ]
348+
349+ # check Relax function calls TIR function with call_dps call
350+ assert rx_func .params [0 ] == x
351+ assert rx_func .params [1 ] == y
352+ assert len (rx_func .body .blocks ) == 1
353+ assert isinstance (rx_func .body .blocks [0 ].bindings [0 ].value , rx .Call )
354+ assert rx_func .body .blocks [0 ].bindings [0 ].value .op == relay .op .get ("relax.call_dps" )
355+ assert len (rx_func .body .blocks [0 ].bindings [0 ].value .args ) == 3
356+ assert rx_func .body .blocks [0 ].bindings [0 ].value .args [1 ].name_hint == "matmul"
357+ assert rx_func .body .blocks [0 ].bindings [0 ].value .args [2 ][0 ] == x
358+ assert rx_func .body .blocks [0 ].bindings [0 ].value .args [2 ][1 ] == y
359+
360+
304361if __name__ == "__main__" :
305362 test_block_builder ()
306363 test_function_single_block ()
307364 test_function_multi_blocks ()
365+ test_multi_functions ()
308366 test_binary_shape_type_deduction ()
309367 test_emit_match_shape ()
310368 test_normalize ()
311369 test_emit_te ()
312370 test_emit_te_multiple ()
371+ test_emit_te_extern ()
0 commit comments