# specific language governing permissions and limitations
# under the License.
"""Developer API of constructing Relax AST."""
import typing
from typing import Any, Callable, Dict, List, Optional, Union

import tvm
from tvm import relax as rx
from tvm import tir
from tvm._ffi.base import _LIB, check_call
from tvm.relay.expr import Tuple
from tvm.runtime import Object

from . import _ffi_api
from .expr import *
from .op.base import call_dps
2528
@@ -72,7 +75,7 @@ class BlockBuilder(Object):
7275 dtype1 = rx.DynTensorType(rank=1, dtype="float16")
7376 x = rx.Var("x", [m, n], dtype0)
7477 y = rx.Var("y", [n], dtype1)
75- ib = rx.IRBuilder ()
78+ ib = rx.BlockBuilder ()
7679 with ib.function([x, y], "func"):
7780 with ib.dataflow() as df:
7881 lv0 = ib.emit(rx.add(x, y))
@@ -84,17 +87,69 @@ class BlockBuilder(Object):
8487
8588 def __init__ (self ):
8689 self ._blocks = []
90+ self ._context_mod = tvm .IRModule ()
8791 self .__init_handle_by_constructor__ (_ffi_api .BlockBuilderCreate )
8892
8993 def _begin_dataflow_block (self ) -> None :
9094 _ffi_api .BlockBuilderBeginDataflowBlock (self )
9195
9296 def _begin_binding_block (self ) -> None :
9397 _ffi_api .BlockBuilderBeginBindingBlock (self )
94-
98+
9599 def _end_block (self ) -> BindingBlock :
96100 return _ffi_api .BlockBuilderEndBlock (self )
97101
102+ def _convert_te_arg (self ,
103+ te_args : Any
104+ ) -> typing .Tuple [Any , List [tvm .te .Tensor ]]:
105+ """Helper function to convert Relax expressions to te tensor.
106+ In the common case, the type of te_args is a Relax expression and is converted into a te tensor.
107+ If te_args is a nested or recursive datatype (i.e list, dict, tvm.ir.Map, tvm.ir.Array),
108+ we recursive and convert any value of type Relax expression into a te tensor.
109+ Common values of type int, float, and str are preserved.
110+
111+ Parameters
112+ ----------
113+ te_args : Any
114+ Argument to convert to te
115+
116+ Returns
117+ -------
118+ ret : (Any, [tvm.te.Tensor])
119+ A tuple of the converted te_args, and a list of te tensors for each converted Relax expression
120+ """
121+ te_args_list = []
122+
123+ def _convert_te_arg_helper (arg ):
124+ if isinstance (arg , Expr ):
125+ arg = te_tensor (arg )
126+ te_args_list .append (arg )
127+ return arg
128+ elif isinstance (arg , (list , tvm .ir .Array )):
129+ return [_convert_te_arg_helper (x ) for x in arg ]
130+ elif isinstance (arg , tuple ):
131+ return tuple ([_convert_te_arg_helper (x ) for x in arg ])
132+ elif isinstance (arg , (dict , tvm .ir .Map )):
133+ for key in arg :
134+ assert isinstance (key , str ), "emit_te only supports dict with string as the key currently"
135+ return {k : _convert_te_arg_helper (arg [k ]) for k in arg }
136+ elif isinstance (arg , (int , float , str )):
137+ return arg
138+ else :
139+ raise TypeError ("not supported type in emit_te: {}" .format (type (arg )))
140+
141+ new_arg = _convert_te_arg_helper (te_args )
142+ return new_arg , te_args_list
143+
144+ def _check_te_args (self , args : List [tvm .te .Tensor ]):
145+ """check te arguments."""
146+ #TODO(hypercubestart, ziheng) support full dynamic shape in the future
147+ for x in args :
148+ for s in x .shape :
149+ if not isinstance (s , (tir .Var , tir .IntImm )):
150+ raise ValueError ("emit_te not support symbolic shape"
151+ "contains expression now: {}" .format (x .shape ))
152+
98153 def function (self ,
99154 params : Optional [Union [Var , Tuple , List [Var ]]] = None ,
100155 name : Optional [str ] = "" ) -> FunctionScope :
@@ -139,7 +194,7 @@ def emit(self, call: relay.Call) -> Var:
139194
140195 Parameters
141196 ----------
142- call : tvm.relay .Call
197+ call : tvm.relax .Call
143198 The call node to be emitted.
144199
145200 Returns
@@ -149,12 +204,97 @@ def emit(self, call: relay.Call) -> Var:
149204 """
150205 return _ffi_api .BlockBuilderEmit (self , call )
151206
207+ def emit_te (self , func : Callable , * args : Any , ** kwargs : Any ) -> Var :
208+ """Emit a call node according to the te function.
209+ This function converts arguments from relax expression to te tensor,
210+ The callback func should return a te tensor.
211+
212+ Parameters
213+ ----------
214+ func : Callable
215+ A function that return a te tensor.
216+
217+ Returns
218+ -------
219+ ret : tvm.relax.Var
220+ A newly created variable that gets binded to the call code.
221+
222+ Example
223+ -------
224+
225+ .. code-block:: python
226+
227+ bb = rx.BlockBuilder()
228+ n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
229+ type_anno = rx.DynTensorType(2, "float32")
230+ x = rx.Var("x", [n, m], type_anno)
231+ y = rx.Var("y", [n, m], type_anno)
232+
233+ def te_func(args, args_dict, msg):
234+ A = args[0]
235+ B = args_dict["B"]
236+ return te.compute((128, 128), lambda i, j: A[i, j] + B[i, j])
237+
238+ with bb.function([x, y], "rx_func"):
239+ out = bb.emit_te(te_func, [x], {"B": y}, msg="hello")
240+ bb.emit_func_output(out)
241+
242+ will result in TVMScript
243+
244+ .. code-block:: python
245+
246+ @tvm.script.ir_module
247+ class Module:
248+ @T.prim_func
249+ def te_func(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_compute: T.handle) -> None:
250+ # function attr dict
251+ T.func_attr({"global_symbol": "te_func"})
252+ m = T.var("int64")
253+ n = T.var("int64")
254+ rxplaceholder = T.match_buffer(var_rxplaceholder, [n, m], dtype="float32")
255+ rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [n, m], dtype="float32")
256+ compute = T.match_buffer(var_compute, [128, 128], dtype="float32")
257+ # body
258+ # with T.block("root")
259+ for i0, i1 in T.grid(128, 128):
260+ with T.block("compute"):
261+ i, j = T.axis.remap("SS", [i0, i1])
262+ T.reads([rxplaceholder[i, j], rxplaceholder_1[i, j]])
263+ T.writes([compute[i, j]])
264+ compute[i, j] = rxplaceholder[i, j] + rxplaceholder_1[i, j]
265+
266+ @R.function
267+ def rx_func(x: Tensor[(n, m), "float32"], y: Tensor[(n, m), "float32"]) -> Tensor:
268+ # block 0
269+ gv = relax.call_dps((128, 128), "te_func", (x, y))
270+ return gv
271+ """
272+ new_args , te_arg_list = self ._convert_te_arg (args )
273+ new_kwargs , te_kwarg_list = self ._convert_te_arg (kwargs )
274+
275+ te_args = te_arg_list + te_kwarg_list
276+ self ._check_te_args (te_args )
277+
278+ # TODO(hypercubestart, ziheng) handle multiple output case
279+ te_out = func (* new_args , ** new_kwargs )
280+ assert isinstance (te_out , tvm .te .tensor .Tensor ), "only support te tensor as function output"
281+
282+ inputs = [* te_args , te_out ]
283+ tir_func = tvm .te .create_prim_func (inputs )
284+ func_name = _ffi_api .BlockBuilderGetUniqueName (self , func .__name__ )
285+ tir_func = tir_func .with_attr ("global_symbol" , func_name )
286+ gvar = GlobalVar (func_name )
287+ self ._context_mod [gvar ] = tir_func
288+ call = call_dps (inputs [- 1 ].shape , gvar , [x .op .value for x in inputs [:- 1 ]])
289+ return _ffi_api .BlockBuilderEmit (self , call )
290+
291+
152292 def match_shape (self , value : Expr , pattern : List [PrimExpr ]) -> Var :
153293 """Emit a MatchShape.
154294
155295 Parameters
156296 ----------
157- value : tvm.relay .Expr
297+ value : tvm.relax .Expr
158298 The value of the MatchShape to be emitted.
159299
160300 pattern : List[PrimExpr]
@@ -224,8 +364,19 @@ def get(self) -> Function:
224364 ret : tvm.relax.Function
225365 A Relax function node being built.
226366 """
367+ # TODO(hyoercubestart, ziheng) get should return IRModule with relax + TIR functions
227368 seqe = rx .SeqExpr (self ._blocks , self ._func_ret )
228369 func = rx .Function (
229370 self ._func_params , seqe , rx .DynTensorType (- 1 , "float32" ), rx .GlobalVar (self ._func_name )
230371 )
231372 return func
373+
374+ def context_mod (self ):
375+ """Return the context module that might contain tir functions.
376+
377+ Returns
378+ -------
379+ mod : tvm.IRModule
380+ The context module that contains tir functions during emit.
381+ """
382+ return self ._context_mod
0 commit comments