Skip to content

Commit 7f7eff8

Browse files
YuchenJinyongwww
authored andcommitted
Call topi and external library through emit_te and add MLP example (apache#50)
1 parent 02bcb25 commit 7f7eff8

File tree

8 files changed

+208
-61
lines changed

8 files changed

+208
-61
lines changed

apps/relax_examples/mlp.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# Licensed to the Apache Software Foundation (ASF) under one
2+
# or more contributor license agreements. See the NOTICE file
3+
# distributed with this work for additional information
4+
# regarding copyright ownership. The ASF licenses this file
5+
# to you under the Apache License, Version 2.0 (the
6+
# "License"); you may not use this file except in compliance
7+
# with the License. You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing,
12+
# software distributed under the License is distributed on an
13+
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
# KIND, either express or implied. See the License for the
15+
# specific language governing permissions and limitations
16+
# under the License.
17+
18+
# Example code on creating, compiling, and running an MLP model in relax
19+
20+
21+
import tvm
22+
from tvm.relay import Call
23+
from tvm import relax, tir, topi
24+
import numpy as np
25+
26+
27+
def build_mlp(data, weight):
28+
bb = relax.BlockBuilder()
29+
30+
with bb.function([data, weight], "mlp"):
31+
gv0 = bb.emit_te(tvm.contrib.cblas.matmul, data, weight, transa=False, transb=False)
32+
gv1 = bb.emit_te(topi.nn.relu, gv0)
33+
bb.emit_func_output(gv1)
34+
35+
mod = bb.get()
36+
return mod
37+
38+
39+
if __name__ == "__main__":
40+
# symbolic dimensions
41+
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
42+
# create data and weight variables
43+
data = relax.Var("data", [n, m], relax.DynTensorType(2, "float32"))
44+
weight = relax.Var("weight", [m, n], relax.DynTensorType(2, "float32"))
45+
46+
# construct a mlp model
47+
mod = build_mlp(data, weight)
48+
49+
# build and create vm executor
50+
target = tvm.target.Target("llvm")
51+
target_host = tvm.target.Target("llvm")
52+
ex, lib = relax.vm.build(mod, target, target_host)
53+
vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)
54+
55+
# run the mlp model on relax vm
56+
data = tvm.nd.array(np.random.rand(16, 32).astype(np.float32))
57+
weight = tvm.nd.array(np.random.rand(32, 16).astype(np.float32))
58+
res = vm["mlp"](data, weight)
59+
print(res)

python/tvm/relax/block_builder.py

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,13 @@ def __exit__(self, ptype, value, trace):
4040
block = _ffi_api.BlockBuilderEndBlock(self._ib)
4141
if len(block.bindings) > 0:
4242
self._ib._blocks.append(block)
43+
seqe = rx.SeqExpr(self._ib._blocks, self._ib._func_ret)
44+
func = rx.Function(
45+
self._ib._func_params, seqe, rx.DynTensorType(-1, "float32"), rx.GlobalVar(self._ib._func_name)
46+
)
47+
gvar = rx.GlobalVar(self._ib._func_name)
48+
self._ib._context_mod[gvar] = func
49+
return func
4350

4451

4552
class DataflowScope(object):
@@ -82,7 +89,7 @@ class BlockBuilder(Object):
8289
lv1 = ib.emit(rx.multiply(lv0, y))
8390
gv0 = ib.emit_output(lv1)
8491
ib.emit_func_output(gv0)
85-
func = ib.get()
92+
mod = ib.get()
8693
"""
8794

8895
def __init__(self):
@@ -356,27 +363,12 @@ def normalize(self, expr: Expr) -> Expr:
356363
"""
357364
return _ffi_api.BlockBuilderNormalize(self, expr)
358365

359-
def get(self) -> Function:
360-
"""Return the function being built.
361-
362-
Returns
363-
-------
364-
ret : tvm.relax.Function
365-
A Relax function node being built.
366-
"""
367-
# TODO(hyoercubestart, ziheng) get should return IRModule with relax + TIR functions
368-
seqe = rx.SeqExpr(self._blocks, self._func_ret)
369-
func = rx.Function(
370-
self._func_params, seqe, rx.DynTensorType(-1, "float32"), rx.GlobalVar(self._func_name)
371-
)
372-
return func
373-
374-
def context_mod(self):
375-
"""Return the context module that might contain tir functions.
366+
def get(self) -> tvm.IRModule:
367+
"""Return the IRModule being built.
376368
377369
Returns
378370
-------
379-
mod : tvm.IRModule
380-
The context module that contains tir functions during emit.
371+
ret : tvm.IRModule
372+
An IRModule with Relax and TIR functions being built.
381373
"""
382374
return self._context_mod

src/relax/ir/block_builder.cc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,11 @@ BindingBlock BlockBuilderNode::EndBlock() {
297297
return ret;
298298
}
299299

300-
Optional<RelayExpr> InferShape(const Call& call, DiagnosticContext diag_ctx) {
300+
Optional<Expr> InferShape(const Call& call, DiagnosticContext diag_ctx) {
301+
// if the call node's shape_ is filled, return the shape directly.
302+
if (call->shape_) {
303+
return Downcast<Expr>(call->shape_.value());
304+
}
301305
auto op_map = Op::GetAttrMap<FInferShape>("FInferShape");
302306
if (call->op.as<OpNode>()) {
303307
Op op = Downcast<Op>(call->op);
@@ -309,6 +313,10 @@ Optional<RelayExpr> InferShape(const Call& call, DiagnosticContext diag_ctx) {
309313
}
310314

311315
Type InferType(const Call& call, DiagnosticContext diag_ctx) {
316+
// if the call node's checked_type_ is filled, return the type directly.
317+
if (call->checked_type_.defined()) {
318+
return call->checked_type_;
319+
}
312320
auto op_map = Op::GetAttrMap<FInferType>("FInferType");
313321
if (call->op.as<OpNode>()) {
314322
Op op = Downcast<Op>(call->op);

src/relax/op/op.cc

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,10 @@ RELAY_REGISTER_OP("relax.call_dps")
5959

6060
Expr MakeCallDPS(Expr shape, Expr func, Tuple args) {
6161
static const Op& op = Op::Get("relax.call_dps");
62-
return Call(op, {shape, func, args}, {}, {});
62+
Call call = Call(op, {shape, func, args}, {}, {});
63+
call->shape_ = shape;
64+
call->checked_type_ = args->fields[0]->checked_type_;
65+
return call;
6366
}
6467

6568
TVM_REGISTER_GLOBAL("relax.op.call_dps")

tests/python/relax/test_analysis.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,13 +47,13 @@ def test_post_order_visit():
4747
x = rx.Var("x", [m, n], dtype0)
4848
y = rx.Var("y", [n], dtype1)
4949
ib = rx.BlockBuilder()
50-
with ib.function([x, y]):
50+
with ib.function([x, y], "func"):
5151
with ib.dataflow() as df:
5252
lv0 = ib.emit(rx.op.add(x, y))
5353
lv1 = ib.emit(rx.op.multiply(lv0, y))
5454
gv0 = ib.emit_output(lv1)
5555
ib.emit_func_output(gv0)
56-
expr = ib.get()
56+
expr = ib.get()["func"]
5757

5858
names = []
5959

tests/python/relax/test_blockbuilder.py

Lines changed: 85 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from tvm import tir, te
2121
from tvm import relay
2222
from tvm import relax as rx
23-
import numpy as np
2423

2524
from tvm.ir.base import assert_structural_equal
2625
from tvm.relax import op
@@ -61,7 +60,7 @@ def test_function_single_block():
6160
y = rx.Var("y", [n], dtype1)
6261
ib = rx.BlockBuilder()
6362

64-
with ib.function([x, y]):
63+
with ib.function([x, y], "func"):
6564
with ib.dataflow() as df:
6665
lv0 = ib.emit(rx.op.add(x, y))
6766
assert lv0.name_hint == "lv"
@@ -71,7 +70,7 @@ def test_function_single_block():
7170
assert gv0.name_hint == "gv"
7271
ib.emit_func_output(gv0)
7372

74-
func = ib.get()
73+
func = ib.get()["func"]
7574
assert func.params[0] == x
7675
assert func.params[1] == y
7776
assert func.body.body == gv0
@@ -106,7 +105,7 @@ def test_function_multi_blocks():
106105
gv2 = ib.emit_output(gv1)
107106
ib.emit_func_output(gv2)
108107

109-
func = ib.get()
108+
func = ib.get()["func"]
110109
assert gv2.shape[0] == m
111110
assert gv2.shape[1] == n
112111
assert gv2.checked_type.rank == 2
@@ -121,6 +120,40 @@ def test_function_multi_blocks():
121120
assert len(func.body.blocks[2].bindings) == 2
122121

123122

123+
def test_multi_functions():
124+
m = tir.Var("m", "int32")
125+
n = tir.Var("n", "int32")
126+
dtype0 = rx.DynTensorType(rank=2, dtype="float16")
127+
dtype1 = rx.DynTensorType(rank=1, dtype="float16")
128+
x = rx.Var("x", [m, n], dtype0)
129+
y = rx.Var("y", [n], dtype1)
130+
ib = rx.BlockBuilder()
131+
132+
with ib.function([x, y], "func1"):
133+
with ib.dataflow() as df:
134+
lv0 = ib.emit(rx.op.add(x, y))
135+
assert lv0.name_hint == "lv"
136+
gv0 = ib.emit_output(lv0)
137+
ib.emit_func_output(gv0)
138+
139+
with ib.function([x, y], "func2"):
140+
with ib.dataflow() as df:
141+
lv0 = ib.emit(rx.op.add(x, y))
142+
assert lv0.name_hint == "lv"
143+
gv0 = ib.emit_output(lv0)
144+
ib.emit_func_output(gv0)
145+
146+
mod = ib.get()
147+
func1 = mod["func1"]
148+
assert func1.params[0] == x
149+
assert func1.params[1] == y
150+
assert func1.name.name_hint == "func1"
151+
func2 = mod["func2"]
152+
assert func2.params[0] == x
153+
assert func2.params[1] == y
154+
assert func2.name.name_hint == "func2"
155+
156+
124157
def test_binary_shape_type_deduction():
125158
m = tir.Var("m", "int32")
126159
n = tir.Var("n", "int32")
@@ -177,7 +210,7 @@ def test_emit_match_shape():
177210
y = rx.Var("shape_value", type_annotation=rx.ShapeType(), shape_annotation=shape_anno)
178211
ib = rx.BlockBuilder()
179212

180-
with ib.function([x, y]):
213+
with ib.function([x, y], "func"):
181214
with ib.dataflow() as df:
182215
# lv0: Tensor[(m, n), "float32"] =
183216
# match_shape(x: Tensor[_, "float32"], [m, n])
@@ -194,7 +227,7 @@ def test_emit_match_shape():
194227
gv0 = ib.emit_output(lv1)
195228

196229
ib.emit_func_output(gv0)
197-
func = ib.get()
230+
func = ib.get()["func"]
198231
block = func.body.blocks[0]
199232
b0, b1 = block.bindings[:2]
200233
assert isinstance(b0, rx.MatchShape)
@@ -248,11 +281,8 @@ def te_func(args, args_dict, msg):
248281
out = bb.emit_te(te_func, [x, y], {"C": z}, msg="hello")
249282
bb.emit_func_output(out)
250283

251-
func = bb.get()
252-
mod = bb.context_mod()
253-
254-
gvar = tvm.relay.GlobalVar("rx_func")
255-
mod[gvar] = func
284+
mod = bb.get()
285+
rx_func = mod["rx_func"]
256286

257287
def get_tir_func():
258288
A = te.placeholder((n, m), dtype="float32", name="A")
@@ -265,20 +295,20 @@ def get_tir_func():
265295
assert_structural_equal(mod["te_func"].body, get_tir_func().body)
266296

267297
# check Relax function calls TIR function with call_dps call
268-
assert func.params[0] == x
269-
assert func.params[1] == y
270-
assert func.params[2] == z
271-
assert func.name.name_hint == "rx_func"
272-
assert func.body.body == out
273-
assert len(func.body.blocks) == 1
274-
assert len(func.body.blocks[0].bindings) == 1
275-
assert isinstance(func.body.blocks[0].bindings[0].value, rx.Call)
276-
assert func.body.blocks[0].bindings[0].value.op == relay.op.get("relax.call_dps")
277-
assert len(func.body.blocks[0].bindings[0].value.args) == 3
278-
assert func.body.blocks[0].bindings[0].value.args[1].name_hint == "te_func"
279-
assert func.body.blocks[0].bindings[0].value.args[2][0] == x
280-
assert func.body.blocks[0].bindings[0].value.args[2][1] == y
281-
assert func.body.blocks[0].bindings[0].value.args[2][2] == z
298+
assert rx_func.params[0] == x
299+
assert rx_func.params[1] == y
300+
assert rx_func.params[2] == z
301+
assert rx_func.name.name_hint == "rx_func"
302+
assert rx_func.body.body == out
303+
assert len(rx_func.body.blocks) == 1
304+
assert len(rx_func.body.blocks[0].bindings) == 1
305+
assert isinstance(rx_func.body.blocks[0].bindings[0].value, rx.Call)
306+
assert rx_func.body.blocks[0].bindings[0].value.op == relay.op.get("relax.call_dps")
307+
assert len(rx_func.body.blocks[0].bindings[0].value.args) == 3
308+
assert rx_func.body.blocks[0].bindings[0].value.args[1].name_hint == "te_func"
309+
assert rx_func.body.blocks[0].bindings[0].value.args[2][0] == x
310+
assert rx_func.body.blocks[0].bindings[0].value.args[2][1] == y
311+
assert rx_func.body.blocks[0].bindings[0].value.args[2][2] == z
282312

283313

284314
def test_emit_te_multiple():
@@ -297,16 +327,45 @@ def te_func(A):
297327
y1 = bb.emit_te(te_func, y)
298328
bb.emit_func_output(y1)
299329

300-
func = bb.get()
330+
func = bb.get()["rx_func"]
301331
assert func.body.blocks[0].bindings[0].value.args[1].name_hint == "te_func"
302332
assert func.body.blocks[0].bindings[1].value.args[1].name_hint == "te_func1"
303333

334+
335+
def test_emit_te_extern():
336+
bb = rx.BlockBuilder()
337+
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
338+
type_anno = rx.DynTensorType(2, "float32")
339+
x = rx.Var("x", [n, m], type_anno)
340+
y = rx.Var("y", [m, n], type_anno)
341+
342+
with bb.function([x, y], "rx_cblas_matmul"):
343+
out = bb.emit_te(tvm.contrib.cblas.matmul, x, y, transa=False, transb=False)
344+
bb.emit_func_output(out)
345+
346+
mod = bb.get()
347+
rx_func = mod["rx_cblas_matmul"]
348+
349+
# check Relax function calls TIR function with call_dps call
350+
assert rx_func.params[0] == x
351+
assert rx_func.params[1] == y
352+
assert len(rx_func.body.blocks) == 1
353+
assert isinstance(rx_func.body.blocks[0].bindings[0].value, rx.Call)
354+
assert rx_func.body.blocks[0].bindings[0].value.op == relay.op.get("relax.call_dps")
355+
assert len(rx_func.body.blocks[0].bindings[0].value.args) == 3
356+
assert rx_func.body.blocks[0].bindings[0].value.args[1].name_hint == "matmul"
357+
assert rx_func.body.blocks[0].bindings[0].value.args[2][0] == x
358+
assert rx_func.body.blocks[0].bindings[0].value.args[2][1] == y
359+
360+
304361
if __name__ == "__main__":
305362
test_block_builder()
306363
test_function_single_block()
307364
test_function_multi_blocks()
365+
test_multi_functions()
308366
test_binary_shape_type_deduction()
309367
test_emit_match_shape()
310368
test_normalize()
311369
test_emit_te()
312370
test_emit_te_multiple()
371+
test_emit_te_extern()

0 commit comments

Comments
 (0)