 class FunctionScope(object):
     """Auxiliary scope for function"""

-    def __init__(self, irbuilder):
-        self._ib = irbuilder
+    def __init__(self, block_builder, name, params):
+        self._bb = block_builder
+        self._name = name
+        self._params = params

     def __enter__(self):
-        _ffi_api.BlockBuilderBeginBindingBlock(self._ib)
+        self._bb._enter_function_scope(self._name, self._params)

-    def __exit__(self, ptype, value, trace):
-        block = _ffi_api.BlockBuilderEndBlock(self._ib)
-        if len(block.bindings) > 0:
-            self._ib._blocks.append(block)
-        seqe = rx.SeqExpr(self._ib._blocks, self._ib._func_ret)
-        func = rx.Function(
-            self._ib._func_params, seqe, rx.DynTensorType(-1, "float32"), rx.GlobalVar(self._ib._func_name)
-        )
-        gvar = rx.GlobalVar(self._ib._func_name)
-        self._ib._context_mod[gvar] = func
-        return func
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        # __exit__ must also handle the case where the with block exits with an exception:
+        # before raising anything here, check whether an exception was already raised
+        # inside the with block, and let it propagate instead of masking it.
+        self._bb._exit_function_scope(exc_type, exc_val, exc_tb)


 class DataflowScope(object):
     """Auxiliary scope for Dataflow block"""

-    def __init__(self, irbuilder):
-        self._ib = irbuilder
+    def __init__(self, block_builder):
+        self._bb = block_builder

     def __enter__(self):
-        block = _ffi_api.BlockBuilderEndBlock(self._ib)
+        block = self._bb._end_block()
         if len(block.bindings) > 0:
-            self._ib._blocks.append(block)
-        _ffi_api.BlockBuilderBeginDataflowBlock(self._ib)
+            self._bb._blocks.append(block)
+        self._bb._begin_dataflow_block()

     def __exit__(self, ptype, value, trace):
-        block = _ffi_api.BlockBuilderEndBlock(self._ib)
+        block = self._bb._end_block()
         if len(block.bindings) > 0:
-            self._ib._blocks.append(block)
-        _ffi_api.BlockBuilderBeginBindingBlock(self._ib)
+            self._bb._blocks.append(block)
+        self._bb._begin_binding_block()


 @tvm._ffi.register_object("relax.BlockBuilder")
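# Illustrative sketch of the new scope behavior (assumed usage based on the API in
# this patch, not taken from the patch itself). FunctionScope now delegates to the
# BlockBuilder, and its __exit__ forwards exc_type/exc_val/exc_tb to
# _exit_function_scope, so an exception raised inside the `with` block propagates
# unchanged instead of being replaced by the "emit_func_output must be called"
# error. The imports and the m/n/x/y variables below mirror the BlockBuilder
# docstring example and are assumptions.
import tvm
from tvm import relax as rx, tir

m = tir.Var("m", "int64")
n = tir.Var("n", "int64")
x = rx.Var("x", [m, n], rx.DynTensorType(rank=2, dtype="float32"))
y = rx.Var("y", [n], rx.DynTensorType(rank=1, dtype="float32"))

bb = rx.BlockBuilder()
try:
    with bb.function("func", [x, y]):
        raise ValueError("user error inside the function scope")
except ValueError:
    pass  # the original exception is re-raised; __exit__ does not mask it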
@@ -82,19 +77,55 @@ class BlockBuilder(Object):
         dtype1 = rx.DynTensorType(rank=1, dtype="float16")
         x = rx.Var("x", [m, n], dtype0)
         y = rx.Var("y", [n], dtype1)
-        ib = rx.BlockBuilder()
-        with ib.function([x, y], "func"):
-            with ib.dataflow() as df:
-                lv0 = ib.emit(rx.add(x, y))
-                lv1 = ib.emit(rx.multiply(lv0, y))
-                gv0 = ib.emit_output(lv1)
-            ib.emit_func_output(gv0)
-        mod = ib.get()
+        bb = rx.BlockBuilder()
+        with bb.function("func", [x, y]):
+            with bb.dataflow() as df:
+                lv0 = bb.emit(rx.add(x, y))
+                lv1 = bb.emit(rx.multiply(lv0, y))
+                gv0 = bb.emit_output(lv1)
+            bb.emit_func_output(gv0)
+        mod = bb.get()
+
+    BlockBuilder can also be used to construct neural networks with the nn.Module API
+
+    .. code-block:: python
+
+        from tvm.relax.testing import nn
+
+        n = tir.Var("n", "int64")
+        input_size = 784
+        hidden_sizes = [128, 32]
+        output_size = 10
+        bb = rx.BlockBuilder()
+
+        with bb.function("main"):
+            model = nn.Sequential(
+                nn.Linear(input_size, hidden_sizes[0]),
+                nn.ReLU(),
+                nn.Linear(hidden_sizes[0], hidden_sizes[1]),
+                nn.ReLU(),
+                nn.Linear(hidden_sizes[1], output_size),
+                nn.LogSoftmax(),
+            )
+            data = nn.Placeholder((n, input_size), name="data")
+            output = model(data)
+            params = [data] + model.parameters()
+            bb.emit_func_output(output, params=params)
+        mod = bb.get()
     """

+    _current = None
+
+    @staticmethod
+    def current():
+        """Returns the current BlockBuilder."""
+        return BlockBuilder._current
+
     def __init__(self):
         self._blocks = []
         self._context_mod = tvm.IRModule()
+        # a boolean flag that tracks if emit_func_output has been called
+        self._is_emit_func_output_called = False
         self.__init_handle_by_constructor__(_ffi_api.BlockBuilderCreate)

     def _begin_dataflow_block(self) -> None:
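# Illustrative sketch of the new BlockBuilder.current() hook (assumed usage): while
# a function scope is open, the active builder is reachable globally, which is what
# lets helper code such as the nn.Module layers in the docstring above emit
# bindings without the builder being passed around explicitly. Reuses the assumed
# rx import and x variable from the earlier sketch.
bb = rx.BlockBuilder()
assert rx.BlockBuilder.current() is None

with bb.function("main", [x]):
    assert rx.BlockBuilder.current() is bb  # set by _enter_function_scope
    lv = bb.emit(rx.add(x, x))
    bb.emit_func_output(lv)

assert rx.BlockBuilder.current() is None    # cleared by _exit_function_scope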
@@ -105,6 +136,22 @@ def _begin_binding_block(self) -> None:

     def _end_block(self) -> BindingBlock:
         return _ffi_api.BlockBuilderEndBlock(self)
+
+    def _enter_function_scope(self, name, params):
+        if BlockBuilder.current() is not None:
+            raise RuntimeError("BlockBuilder does not allow nested functions.")
+        BlockBuilder._current = self
+        self._func_name = name
+        self._func_params = params
+        self._begin_binding_block()
+
+    def _exit_function_scope(self, exc_type, exc_val, exc_tb):
+        if exc_type is None:
+            if not self._is_emit_func_output_called:
+                raise RuntimeError("emit_func_output must be called in a relax function.")
+
+        self._is_emit_func_output_called = False
+        BlockBuilder._current = None

     def _convert_te_arg(self,
                         te_args: Any
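# Illustrative sketch of the two guards added above (assumed usage, reusing the
# assumed x/y variables from the earlier sketches): opening a second function scope
# while one is active raises, and leaving a scope without calling emit_func_output
# also raises on a clean exit.
bb = rx.BlockBuilder()
with bb.function("outer", [x]):
    try:
        with bb.function("inner", [y]):
            pass
    except RuntimeError:
        pass  # "BlockBuilder does not allow nested functions."
    lv = bb.emit(rx.add(x, x))
    bb.emit_func_output(lv)

bb2 = rx.BlockBuilder()
try:
    with bb2.function("forgetful", [x]):
        lv = bb2.emit(rx.add(x, x))  # emit_func_output never called
except RuntimeError:
    pass  # raised from _exit_function_scope because the flag was never set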
@@ -173,31 +220,36 @@ def _populate_used_vars(expr):


     def function(self,
-                 params: Optional[Union[Var, Tuple, List[Var]]] = None,
-                 name: Optional[str] = "") -> FunctionScope:
+                 name: str,
+                 params: Optional[Union[Var, Tuple, List[Var]]] = None) -> FunctionScope:
         """Annotate a Relax function.

         Parameters
         ----------
+        name : str
+            The name of the function.
+
         params : tvm.relax.Var | Tuple | List[tvm.relax.Var], optional
             The parameters of the function.
-
-        name : str, optional
-            The name of the function. If provided, the function is global, otherwise local.
+            If params is None, initialization of the function parameters is deferred
+            until emit_func_output.

         Returns
         -------
         ret: FunctionScope
             A FunctionScope for building a Relax function node.
         """
         if not params:
-            params = []
-        if not isinstance(params, (list, tuple)):
+            params = None
+        elif isinstance(params, rx.Var):
             params = [params]
+        elif isinstance(params, (list, tuple)):
+            for param in params:
+                if not isinstance(param, rx.Var):
+                    raise TypeError("each element of function parameters must be of type tvm.relax.Var, "
+                                    "but got: {}".format(type(param)))

-        self._func_params = params
-        self._func_name = name
-        return FunctionScope(self)
+        name = self.get_unique_name(name)
+        return FunctionScope(self, name, params)

     def dataflow(self) -> DataflowScope:
         """Annotate a Relax dataflow block.
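# Illustrative sketch of the reworked function() signature (assumed usage): the
# name now comes first and the parameters are optional, so they can either be given
# when the scope is opened or deferred to emit_func_output, as in the nn.Module
# docstring example. bb_a/bb_b are assumptions; x/y are reused from above.
bb_a = rx.BlockBuilder()
with bb_a.function("func_a", [x, y]):      # parameters known up front
    lv = bb_a.emit(rx.add(x, y))
    bb_a.emit_func_output(lv)

bb_b = rx.BlockBuilder()
with bb_b.function("func_b"):              # parameters deferred
    lv = bb_b.emit(rx.add(x, y))
    bb_b.emit_func_output(lv, params=[x, y])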
@@ -304,12 +356,12 @@ def rx_func(x: Tensor[(n, m), "float32"], y: Tensor[(n, m), "float32"]) -> Tenso

         inputs = [*te_args, te_out]
         tir_func = tvm.te.create_prim_func(inputs)
-        func_name = _ffi_api.BlockBuilderGetUniqueName(self, func.__name__)
+        func_name = self.get_unique_name(func.__name__)
         tir_func = tir_func.with_attr("global_symbol", func_name)
         gvar = GlobalVar(func_name)
         self._context_mod[gvar] = tir_func
         call = call_dps(inputs[-1].shape, gvar, [x.op.value for x in inputs[:-1]])
-        return _ffi_api.BlockBuilderEmit(self, call)
+        return self.emit(call)


     def match_shape(self, value: Expr, pattern: List[PrimExpr]) -> Var:
@@ -347,22 +399,54 @@ def emit_output(self, output: Union[Expr, Tuple, List[Expr]]) -> None:
             output = Tuple(output)
         return _ffi_api.BlockBuilderEmitOutput(self, output)

-    def emit_func_output(self, output: Union[Expr, Tuple, List[Expr]]) -> None:
+    def emit_func_output(self,
+                         output: Union[Expr, Tuple, List[Expr]],
+                         params: Optional[Union[Var, Tuple, List[Var]]] = None) -> None:
         """Emit output for the function.

         Parameters
         ----------
         output : Expr | Tuple | List[Expr]
             The output of the current block/function.
+
+        params : tvm.relax.Var | Tuple | List[tvm.relax.Var], optional
+            The parameters of the function to be built.
+            If params is None, the parameters have already been initialized in the
+            function's with scope.

         Returns
         -------
         ret : tvm.relax.Var
             The return variable which gets bound to the output.
         """
+        if self._is_emit_func_output_called:
+            raise RuntimeError("emit_func_output must be called exactly once in a relax function.")
+        self._is_emit_func_output_called = True
+
+        if self._func_params is not None and params is not None:
+            raise RuntimeError("function parameters have been initialized in the function with scope.")
+
+        if self._func_params is None and params is None:
+            raise RuntimeError("A Relax function must have parameters.")
+
+        if self._func_params is None:
+            self._func_params = params
+
+        if BlockBuilder.current() is not self:
+            raise RuntimeError("BlockBuilder._current must be self.")
+
         if isinstance(output, (list, tuple)):
             output = Tuple(output)
         self._func_ret = output
+
+        block = self._end_block()
+        if len(block.bindings) > 0:
+            self._blocks.append(block)
+        seqe = rx.SeqExpr(self._blocks, self._func_ret)
+        func = rx.Function(
+            self._func_params, seqe, rx.DynTensorType(-1), rx.GlobalVar(self._func_name)
+        )
+        gvar = rx.GlobalVar(self._func_name)
+        self._context_mod[gvar] = func

     def normalize(self, expr: Expr) -> Expr:
         """Normalize an Expr to complete its shape and type.
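# Illustrative sketch of the "exactly once" check added to emit_func_output
# (assumed usage, with the assumed x/y from earlier sketches): a second call inside
# the same function scope is rejected.
bb = rx.BlockBuilder()
with bb.function("func_c", [x, y]):
    lv = bb.emit(rx.add(x, y))
    bb.emit_func_output(lv)
    try:
        bb.emit_func_output(lv)
    except RuntimeError:
        pass  # "emit_func_output must be called exactly once in a relax function."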
@@ -388,3 +472,19 @@ def get(self) -> tvm.IRModule:
             An IRModule with Relax and TIR functions being built.
         """
         return self._context_mod
+
+
+    def get_unique_name(self, name_prefix: str) -> str:
+        """Generate a unique name with a specified prefix.
+
+        Parameters
+        ----------
+        name_prefix : str
+            The name prefix.
+
+        Returns
+        -------
+        ret : str
+            The generated name.
+        """
+        return _ffi_api.BlockBuilderGetUniqueName(self, name_prefix)
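# Illustrative sketch of get_unique_name (assumed usage): repeated requests for the
# same prefix come back deduplicated, which is how the TE lowering path above avoids
# clashing global symbols for generated PrimFuncs. The exact suffixing scheme is
# decided on the C++ side and is an assumption here.
bb = rx.BlockBuilder()
print(bb.get_unique_name("te_func"))   # e.g. "te_func"
print(bb.get_unique_name("te_func"))   # e.g. "te_func1" (assumed suffix format)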