Skip to content

Commit 92497dc

Browse files
ZihengJiang and hypercubestart
authored and committed
TE Integration (apache#36)
* Init. * Proof of concept. * Rebase on the newest branch * Move to emit_te * Update emit_te * Make RXPlaceholderOpNode as a subclass of PlaceholderOpNode * Update * run vm test_te * Update argument conversion * Reset create_primfunc * Update doc * Update test * Add error message * Update * Update * Address comment * unit test check structural and validate_te_args * raise ValueError when multiple outputs * address comments * example usage emit_te * Rename to context_mod * Handle multiple call * Address comments * Address comments * Use unique name * remove * rename args to te_args * address comments * fix TVMscript manually * spelling Co-authored-by: Andrew Liu <andrewlliu@gmail.com>
1 parent fdb3d71 commit 92497dc

File tree

9 files changed

+377
-10
lines changed

9 files changed

+377
-10
lines changed

include/tvm/te/operation.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,7 @@ class PlaceholderOpNode : public OperationNode {
182182
}
183183

184184
static constexpr const char* _type_key = "PlaceholderOp";
185-
TVM_DECLARE_FINAL_OBJECT_INFO(PlaceholderOpNode, OperationNode);
185+
TVM_DECLARE_BASE_OBJECT_INFO(PlaceholderOpNode, OperationNode);
186186
};
187187

188188
/*!

python/tvm/relax/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
# helper functions
5050
const = expr.const
5151
extern = expr.extern
52+
te_tensor = expr.te_tensor
5253

5354
# Type
5455
Type = ty.Type

python/tvm/relax/block_builder.py

Lines changed: 156 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -15,11 +15,14 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
"""Developer API of constructing Relax AST."""
18-
from typing import List, Optional, Union, Dict
18+
import typing
19+
from typing import List, Optional, Union, Dict, Any, Callable
1920
from tvm.relay.expr import Tuple
2021
from tvm.runtime import Object
2122
from tvm import relax as rx
23+
from tvm import tir
2224
from .expr import *
25+
from .op.base import call_dps
2326
from tvm._ffi.base import _LIB, check_call
2427
from . import _ffi_api
2528

@@ -72,7 +75,7 @@ class BlockBuilder(Object):
7275
dtype1 = rx.DynTensorType(rank=1, dtype="float16")
7376
x = rx.Var("x", [m, n], dtype0)
7477
y = rx.Var("y", [n], dtype1)
75-
ib = rx.IRBuilder()
78+
ib = rx.BlockBuilder()
7679
with ib.function([x, y], "func"):
7780
with ib.dataflow() as df:
7881
lv0 = ib.emit(rx.add(x, y))
@@ -84,17 +87,69 @@ class BlockBuilder(Object):
8487

8588
def __init__(self):
8689
self._blocks = []
90+
self._context_mod = tvm.IRModule()
8791
self.__init_handle_by_constructor__(_ffi_api.BlockBuilderCreate)
8892

8993
def _begin_dataflow_block(self) -> None:
9094
_ffi_api.BlockBuilderBeginDataflowBlock(self)
9195

9296
def _begin_binding_block(self) -> None:
9397
_ffi_api.BlockBuilderBeginBindingBlock(self)
94-
98+
9599
def _end_block(self) -> BindingBlock:
96100
return _ffi_api.BlockBuilderEndBlock(self)
97101

102+
def _convert_te_arg(self,
103+
te_args: Any
104+
) -> typing.Tuple[Any, List[tvm.te.Tensor]]:
105+
"""Helper function to convert Relax expressions to te tensor.
106+
In the common case, the type of te_args is a Relax expression and is converted into a te tensor.
107+
If te_args is a nested or recursive datatype (i.e list, dict, tvm.ir.Map, tvm.ir.Array),
108+
we recursive and convert any value of type Relax expression into a te tensor.
109+
Common values of type int, float, and str are preserved.
110+
111+
Parameters
112+
----------
113+
te_args : Any
114+
Argument to convert to te
115+
116+
Returns
117+
-------
118+
ret : (Any, [tvm.te.Tensor])
119+
A tuple of the converted te_args, and a list of te tensors for each converted Relax expression
120+
"""
121+
te_args_list = []
122+
123+
def _convert_te_arg_helper(arg):
124+
if isinstance(arg, Expr):
125+
arg = te_tensor(arg)
126+
te_args_list.append(arg)
127+
return arg
128+
elif isinstance(arg, (list, tvm.ir.Array)):
129+
return [_convert_te_arg_helper(x) for x in arg]
130+
elif isinstance(arg, tuple):
131+
return tuple([_convert_te_arg_helper(x) for x in arg])
132+
elif isinstance(arg, (dict, tvm.ir.Map)):
133+
for key in arg:
134+
assert isinstance(key, str), "emit_te only supports dict with string as the key currently"
135+
return {k: _convert_te_arg_helper(arg[k]) for k in arg}
136+
elif isinstance(arg, (int, float, str)):
137+
return arg
138+
else:
139+
raise TypeError("not supported type in emit_te: {}".format(type(arg)))
140+
141+
new_arg = _convert_te_arg_helper(te_args)
142+
return new_arg, te_args_list
143+
144+
def _check_te_args(self, args: List[tvm.te.Tensor]):
145+
"""check te arguments."""
146+
#TODO(hypercubestart, ziheng) support full dynamic shape in the future
147+
for x in args:
148+
for s in x.shape:
149+
if not isinstance(s, (tir.Var, tir.IntImm)):
150+
raise ValueError("emit_te not support symbolic shape"
151+
"contains expression now: {}".format(x.shape))
152+
98153
def function(self,
99154
params: Optional[Union[Var, Tuple, List[Var]]] = None,
100155
name: Optional[str] = "") -> FunctionScope:
@@ -139,7 +194,7 @@ def emit(self, call: relay.Call) -> Var:
139194
140195
Parameters
141196
----------
142-
call : tvm.relay.Call
197+
call : tvm.relax.Call
143198
The call node to be emitted.
144199
145200
Returns
@@ -149,12 +204,97 @@ def emit(self, call: relay.Call) -> Var:
149204
"""
150205
return _ffi_api.BlockBuilderEmit(self, call)
151206

207+
def emit_te(self, func: Callable, *args: Any, **kwargs: Any) -> Var:
208+
"""Emit a call node according to the te function.
209+
This function converts arguments from relax expression to te tensor,
210+
The callback func should return a te tensor.
211+
212+
Parameters
213+
----------
214+
func : Callable
215+
A function that return a te tensor.
216+
217+
Returns
218+
-------
219+
ret : tvm.relax.Var
220+
A newly created variable that gets binded to the call code.
221+
222+
Example
223+
-------
224+
225+
.. code-block:: python
226+
227+
bb = rx.BlockBuilder()
228+
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
229+
type_anno = rx.DynTensorType(2, "float32")
230+
x = rx.Var("x", [n, m], type_anno)
231+
y = rx.Var("y", [n, m], type_anno)
232+
233+
def te_func(args, args_dict, msg):
234+
A = args[0]
235+
B = args_dict["B"]
236+
return te.compute((128, 128), lambda i, j: A[i, j] + B[i, j])
237+
238+
with bb.function([x, y], "rx_func"):
239+
out = bb.emit_te(te_func, [x], {"B": y}, msg="hello")
240+
bb.emit_func_output(out)
241+
242+
will result in TVMScript
243+
244+
.. code-block:: python
245+
246+
@tvm.script.ir_module
247+
class Module:
248+
@T.prim_func
249+
def te_func(var_rxplaceholder: T.handle, var_rxplaceholder_1: T.handle, var_compute: T.handle) -> None:
250+
# function attr dict
251+
T.func_attr({"global_symbol": "te_func"})
252+
m = T.var("int64")
253+
n = T.var("int64")
254+
rxplaceholder = T.match_buffer(var_rxplaceholder, [n, m], dtype="float32")
255+
rxplaceholder_1 = T.match_buffer(var_rxplaceholder_1, [n, m], dtype="float32")
256+
compute = T.match_buffer(var_compute, [128, 128], dtype="float32")
257+
# body
258+
# with T.block("root")
259+
for i0, i1 in T.grid(128, 128):
260+
with T.block("compute"):
261+
i, j = T.axis.remap("SS", [i0, i1])
262+
T.reads([rxplaceholder[i, j], rxplaceholder_1[i, j]])
263+
T.writes([compute[i, j]])
264+
compute[i, j] = rxplaceholder[i, j] + rxplaceholder_1[i, j]
265+
266+
@R.function
267+
def rx_func(x: Tensor[(n, m), "float32"], y: Tensor[(n, m), "float32"]) -> Tensor:
268+
# block 0
269+
gv = relax.call_dps((128, 128), "te_func", (x, y))
270+
return gv
271+
"""
272+
new_args, te_arg_list = self._convert_te_arg(args)
273+
new_kwargs, te_kwarg_list = self._convert_te_arg(kwargs)
274+
275+
te_args = te_arg_list + te_kwarg_list
276+
self._check_te_args(te_args)
277+
278+
# TODO(hypercubestart, ziheng) handle multiple output case
279+
te_out = func(*new_args, **new_kwargs)
280+
assert isinstance(te_out, tvm.te.tensor.Tensor), "only support te tensor as function output"
281+
282+
inputs = [*te_args, te_out]
283+
tir_func = tvm.te.create_prim_func(inputs)
284+
func_name = _ffi_api.BlockBuilderGetUniqueName(self, func.__name__)
285+
tir_func = tir_func.with_attr("global_symbol", func_name)
286+
gvar = GlobalVar(func_name)
287+
self._context_mod[gvar] = tir_func
288+
call = call_dps(inputs[-1].shape, gvar, [x.op.value for x in inputs[:-1]])
289+
return _ffi_api.BlockBuilderEmit(self, call)
290+
291+
152292
def match_shape(self, value: Expr, pattern: List[PrimExpr]) -> Var:
153293
"""Emit a MatchShape.
154294
155295
Parameters
156296
----------
157-
value : tvm.relay.Expr
297+
value : tvm.relax.Expr
158298
The value of the MatchShape to be emitted.
159299
160300
pattern : List[PrimExpr]
@@ -224,8 +364,19 @@ def get(self) -> Function:
224364
ret : tvm.relax.Function
225365
A Relax function node being built.
226366
"""
367+
# TODO(hyoercubestart, ziheng) get should return IRModule with relax + TIR functions
227368
seqe = rx.SeqExpr(self._blocks, self._func_ret)
228369
func = rx.Function(
229370
self._func_params, seqe, rx.DynTensorType(-1, "float32"), rx.GlobalVar(self._func_name)
230371
)
231372
return func
373+
374+
def context_mod(self):
375+
"""Return the context module that might contain tir functions.
376+
377+
Returns
378+
-------
379+
mod : tvm.IRModule
380+
The context module that contains tir functions during emit.
381+
"""
382+
return self._context_mod

python/tvm/relax/expr.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -169,5 +169,11 @@ def __init__(self, global_symbol: String, span: Span = None) -> None:
169169
self.__init_handle_by_constructor__(_ffi_api.ExternFunc, global_symbol, span)
170170

171171

172-
def extern(name, span: Span = None):
172+
def extern(name: str, span: Span = None):
173+
"""Create extern function."""
173174
return ExternFunc(name, span)
175+
176+
177+
def te_tensor(value: Expr, name: str = "rxplaceholder"):
178+
"""Create te tensor from relax expression."""
179+
return _ffi_api.TETensor(value, name)

python/tvm/relax/op/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
1414
# KIND, either express or implied. See the License for the
1515
# specific language governing permissions and limitations
16-
from ...ir import BaseFunc
16+
from ...ir import BaseFunc, Array
1717
from ..expr import Expr, ShapeExpr, Tuple, Call
1818
from . import _ffi_api
1919
from typing import Union, List
@@ -41,7 +41,7 @@ def call_dps(
4141
ret: Call
4242
A call node for the call_dps operator.
4343
"""
44-
if isinstance(shape, (list, tuple)):
44+
if isinstance(shape, (list, tuple, Array)):
4545
shape = ShapeExpr(shape)
4646
if isinstance(args, (list, tuple)):
4747
args = Tuple(args)

src/relax/ir/block_builder.cc

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -550,5 +550,10 @@ TVM_REGISTER_GLOBAL("relax.BlockBuilderNormalize")
550550
return builder->Normalize(expr);
551551
});
552552

553+
TVM_REGISTER_GLOBAL("relax.BlockBuilderGetUniqueName")
554+
.set_body_typed([](BlockBuilder builder, String name_hint) {
555+
return builder->name_table()->GetUniqueName(name_hint);
556+
});
557+
553558
} // namespace relax
554559
} // namespace tvm

src/relax/ir/emit_te.cc

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
/*!
21+
* \file relax/src/ir/emit_te.cc
22+
*/
23+
#include <tvm/relax/type.h>
24+
#include "./emit_te.h"
25+
26+
namespace tvm {
27+
namespace relax {
28+
29+
// RXPlaceholderOpNode
30+
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
31+
.set_dispatch<RXPlaceholderOpNode>([](const ObjectRef& node, ReprPrinter* p) {
32+
auto* op = static_cast<const RXPlaceholderOpNode*>(node.get());
33+
p->stream << "rxplaceholder(" << op->name << ", " << op << ")";
34+
});
35+
36+
TVM_REGISTER_NODE_TYPE(RXPlaceholderOpNode);
37+
38+
te::Tensor TETensor(Expr value, std::string name) {
39+
auto n = make_object<RXPlaceholderOpNode>();
40+
n->name = name;
41+
n->value = value;
42+
43+
Expr shape_expr = value->shape();
44+
CHECK(shape_expr->IsInstance<ShapeExprNode>())
45+
<< "ValueError: Expression does not have an known symbolic shape, please consider use match_shape "
46+
<< "to constrain the shape before passing into te_tensor";
47+
Array<PrimExpr> shape = Downcast<ShapeExpr>(shape_expr)->values;
48+
n->shape = shape;
49+
Type type = value->checked_type();
50+
ICHECK(type->IsInstance<DynTensorTypeNode>())
51+
<< "ValueError: Expression should have a inferred DynTensorType: "
52+
<< type->GetTypeKey();
53+
DataType dtype = Downcast<DynTensorType>(type)->dtype;
54+
n->dtype = dtype;
55+
return te::PlaceholderOp(n).output(0);
56+
}
57+
58+
TVM_REGISTER_GLOBAL("relax.TETensor")
59+
.set_body_typed([](Expr value, std::string name) {
60+
return TETensor(value, name);
61+
});
62+
63+
} // namespace relax
64+
} // namespace tvm

0 commit comments

Comments (0)