
Commit 4eeb3c8

YuchenJin authored and yongwww committed
[TESTING] pytorch-like nn.Module API to build neural network (apache#54)
* nn module
* address comments.
* Add nn.init_params
* Remove nn.Builder and use BlockBuilder instead.
* Rebase.
* Refactor block builder and add tests.
* Address comments.
* Update.
1 parent f16f3a5 commit 4eeb3c8

File tree: 9 files changed, +517 −85 lines


apps/relax_examples/mlp.py

Lines changed: 3 additions & 4 deletions

```diff
@@ -27,7 +27,7 @@
 def build_mlp(data, weight):
     bb = relax.BlockBuilder()
 
-    with bb.function([data, weight], "mlp"):
+    with bb.function("mlp", [data, weight]):
         gv0 = bb.emit_te(tvm.contrib.cblas.matmul, data, weight, transa=False, transb=False)
         gv1 = bb.emit_te(topi.nn.relu, gv0)
         bb.emit_func_output(gv1)
@@ -47,9 +47,8 @@ def build_mlp(data, weight):
     mod = build_mlp(data, weight)
 
     # build and create vm executor
-    target = tvm.target.Target("llvm")
-    target_host = tvm.target.Target("llvm")
-    ex, lib = relax.vm.build(mod, target, target_host)
+    target = tvm.target.Target("llvm", host="llvm")
+    ex, lib = relax.vm.build(mod, target)
     vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)
 
     # run the mlp model on relax vm
```
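The signature of `BlockBuilder.function` flipped here: the function name now comes first and the parameter list second (see the block_builder.py diff below). A minimal before/after sketch; the shapes and dtypes of `data` and `weight` are assumptions, since they are defined outside this hunk:

```python
import tvm
from tvm import relax, tir, topi

# hypothetical stand-ins for the file's `data` and `weight`
n = tir.Var("n", "int64")
m = tir.Var("m", "int64")
dtype = relax.DynTensorType(rank=2, dtype="float32")
data = relax.Var("data", [n, m], dtype)
weight = relax.Var("weight", [m, n], dtype)

bb = relax.BlockBuilder()
# old (removed): with bb.function([data, weight], "mlp"):
# new:           name first, parameters second
with bb.function("mlp", [data, weight]):
    out = bb.emit_te(topi.nn.relu, data)  # placeholder body for the sketch
    bb.emit_func_output(out)
mod = bb.get()
```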

apps/relax_examples/nn_module.py

Lines changed: 69 additions & 0 deletions (new file)

```python
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Example code on creating, compiling, and running a neural network with a PyTorch-like API


import tvm
from tvm.relay import Call
from tvm import relax, tir
from tvm.relax.testing import nn
from tvm.script import relax as R
import numpy as np


if __name__ == "__main__":
    builder = relax.BlockBuilder()

    # a symbolic variable to represent the minibatch size
    n = tir.Var("n", "int64")
    input_size = 784
    hidden_sizes = [128, 32]
    output_size = 10

    # build a three-linear-layer neural network for a classification task
    with builder.function("main"):
        model = nn.Sequential(
            nn.Linear(input_size, hidden_sizes[0]),
            nn.ReLU(),
            nn.Linear(hidden_sizes[0], hidden_sizes[1]),
            nn.ReLU(),
            nn.Linear(hidden_sizes[1], output_size),
            nn.LogSoftmax(),
        )
        data = nn.Placeholder((n, input_size), name="data")
        output = model(data)
        params = [data] + model.parameters()
        builder.emit_func_output(output, params=params)

    # get and print the IRModule being built
    mod = builder.get()
    print(R.parser.astext(mod))

    # build the IRModule and create a relax vm
    target = tvm.target.Target("llvm", host="llvm")
    ex, lib = relax.vm.build(mod, target)
    vm = relax.VirtualMachine(ex, tvm.cpu(), mod=lib)

    # init parameters
    params = nn.init_params(mod)

    # run the model on the relax vm
    # the input data has a minibatch size of 3
    data = tvm.nd.array(np.random.rand(3, input_size).astype(np.float32))
    res = vm["main"](data, *params)
    print(res)
```
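Because each module is a plain callable, the same network can be assembled without `nn.Sequential`. A sketch continuing the script above; it assumes individual modules such as `nn.Linear` expose the same `parameters()` accessor that `nn.Sequential` does:

```python
    # manual composition without nn.Sequential (sketch)
    builder2 = relax.BlockBuilder()
    with builder2.function("main"):
        fc1 = nn.Linear(input_size, hidden_sizes[0])
        act = nn.ReLU()
        fc2 = nn.Linear(hidden_sizes[0], output_size)
        data = nn.Placeholder((n, input_size), name="data")
        out = fc2(act(fc1(data)))
        params = [data] + fc1.parameters() + fc2.parameters()
        builder2.emit_func_output(out, params=params)
    print(R.parser.astext(builder2.get()))
```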

python/tvm/relax/block_builder.py

Lines changed: 143 additions & 43 deletions

```diff
@@ -30,42 +30,37 @@
 class FunctionScope(object):
     """Auxiliary scope for function"""
 
-    def __init__(self, irbuilder):
-        self._ib = irbuilder
+    def __init__(self, block_builder, name, params):
+        self._bb = block_builder
+        self._name = name
+        self._params = params
 
     def __enter__(self):
-        _ffi_api.BlockBuilderBeginBindingBlock(self._ib)
+        self._bb._enter_function_scope(self._name, self._params)
 
-    def __exit__(self, ptype, value, trace):
-        block = _ffi_api.BlockBuilderEndBlock(self._ib)
-        if len(block.bindings) > 0:
-            self._ib._blocks.append(block)
-        seqe = rx.SeqExpr(self._ib._blocks, self._ib._func_ret)
-        func = rx.Function(
-            self._ib._func_params, seqe, rx.DynTensorType(-1, "float32"), rx.GlobalVar(self._ib._func_name)
-        )
-        gvar = rx.GlobalVar(self._ib._func_name)
-        self._ib._context_mod[gvar] = func
-        return func
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        # __exit__ must properly handle the case where the with block exits with an
+        # exception: before raising an error of its own, it checks whether an
+        # exception was already thrown inside the with block
+        self._bb._exit_function_scope(exc_type, exc_val, exc_tb)
 
 
 class DataflowScope(object):
     """Auxiliary scope for Dataflow block"""
 
-    def __init__(self, irbuilder):
-        self._ib = irbuilder
+    def __init__(self, block_builder):
+        self._bb = block_builder
 
     def __enter__(self):
-        block = _ffi_api.BlockBuilderEndBlock(self._ib)
+        block = self._bb._end_block()
         if len(block.bindings) > 0:
-            self._ib._blocks.append(block)
-        _ffi_api.BlockBuilderBeginDataflowBlock(self._ib)
+            self._bb._blocks.append(block)
+        self._bb._begin_dataflow_block()
 
     def __exit__(self, ptype, value, trace):
-        block = _ffi_api.BlockBuilderEndBlock(self._ib)
+        block = self._bb._end_block()
         if len(block.bindings) > 0:
-            self._ib._blocks.append(block)
-            _ffi_api.BlockBuilderBeginBindingBlock(self._ib)
+            self._bb._blocks.append(block)
+        self._bb._begin_binding_block()
 
 
 @tvm._ffi.register_object("relax.BlockBuilder")
```
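The rewritten `__exit__` delegates to the builder, which keeps a user exception raised inside the `with` block from being masked by the builder's own bookkeeping errors. A minimal, runnable sketch of that contract (variable setup is hypothetical):

```python
from tvm import relax, tir

n = tir.Var("n", "int64")
x = relax.Var("x", [n], relax.DynTensorType(rank=1, dtype="float32"))
bb = relax.BlockBuilder()

try:
    with bb.function("f", [x]):
        raise ValueError("user error inside the scope")
except ValueError:
    # _exit_function_scope sees exc_type == ValueError, so it skips its own
    # "emit_func_output must be called" check and lets the original error surface
    pass
```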
```diff
@@ -82,19 +77,55 @@ class BlockBuilder(Object):
         dtype1 = rx.DynTensorType(rank=1, dtype="float16")
         x = rx.Var("x", [m, n], dtype0)
         y = rx.Var("y", [n], dtype1)
-        ib = rx.BlockBuilder()
-        with ib.function([x, y], "func"):
-            with ib.dataflow() as df:
-                lv0 = ib.emit(rx.add(x, y))
-                lv1 = ib.emit(rx.multiply(lv0, y))
-                gv0 = ib.emit_output(lv1)
-            ib.emit_func_output(gv0)
-        mod = ib.get()
+        bb = rx.BlockBuilder()
+        with bb.function("func", [x, y]):
+            with bb.dataflow() as df:
+                lv0 = bb.emit(rx.add(x, y))
+                lv1 = bb.emit(rx.multiply(lv0, y))
+                gv0 = bb.emit_output(lv1)
+            bb.emit_func_output(gv0)
+        mod = bb.get()
+
+    BlockBuilder can also be used to construct neural networks with the nn.Module API
+
+    .. code-block:: python
+
+        from tvm.relax.testing import nn
+
+        n = tir.Var("n", "int64")
+        input_size = 784
+        hidden_sizes = [128, 32]
+        output_size = 10
+        bb = rx.BlockBuilder()
+
+        with bb.function("main"):
+            model = nn.Sequential(
+                nn.Linear(input_size, hidden_sizes[0]),
+                nn.ReLU(),
+                nn.Linear(hidden_sizes[0], hidden_sizes[1]),
+                nn.ReLU(),
+                nn.Linear(hidden_sizes[1], output_size),
+                nn.LogSoftmax(),
+            )
+            data = nn.Placeholder((n, input_size), name="data")
+            output = model(data)
+            params = [data] + model.parameters()
+            bb.emit_func_output(output, params=params)
+        mod = bb.get()
     """
 
+    _current = None
+
+    @staticmethod
+    def current():
+        """Returns the current BlockBuilder."""
+        return BlockBuilder._current
+
     def __init__(self):
         self._blocks = []
         self._context_mod = tvm.IRModule()
+        # a boolean flag that tracks whether emit_func_output has been called
+        self._is_emit_func_output_called = False
         self.__init_handle_by_constructor__(_ffi_api.BlockBuilderCreate)
 
     def _begin_dataflow_block(self) -> None:
```
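The class-level `_current` handle is how library code, in particular the `nn` modules, finds the builder that owns the enclosing function scope. A sketch of the invariant:

```python
from tvm import relax, tir

n = tir.Var("n", "int64")
x = relax.Var("x", [n], relax.DynTensorType(rank=1, dtype="float32"))
bb = relax.BlockBuilder()

assert relax.BlockBuilder.current() is None
with bb.function("identity", [x]):
    # inside the scope, modules can reach the active builder
    assert relax.BlockBuilder.current() is bb
    bb.emit_func_output(x)
assert relax.BlockBuilder.current() is None  # cleared by _exit_function_scope
```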
```diff
@@ -105,6 +136,22 @@ def _begin_binding_block(self) -> None:
 
     def _end_block(self) -> BindingBlock:
         return _ffi_api.BlockBuilderEndBlock(self)
+
+    def _enter_function_scope(self, name, params):
+        if BlockBuilder.current() is not None:
+            raise RuntimeError("BlockBuilder does not allow nested functions.")
+        BlockBuilder._current = self
+        self._func_name = name
+        self._func_params = params
+        self._begin_binding_block()
+
+    def _exit_function_scope(self, exc_type, exc_val, exc_tb):
+        if exc_type is None:
+            if not self._is_emit_func_output_called:
+                raise RuntimeError("emit_func_output must be called in a relax function.")
+
+        self._is_emit_func_output_called = False
+        BlockBuilder._current = None
 
     def _convert_te_arg(self,
                         te_args: Any
```
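Two error paths are enforced here: entering a function scope while another is active raises immediately, and leaving a scope without calling `emit_func_output` raises on exit. A runnable sketch of the second path:

```python
from tvm import relax, tir

n = tir.Var("n", "int64")
x = relax.Var("x", [n], relax.DynTensorType(rank=1, dtype="float32"))
bb = relax.BlockBuilder()

try:
    with bb.function("f", [x]):
        bb.emit(relax.add(x, x))  # emit a binding, but forget emit_func_output
except RuntimeError as err:
    print(err)  # emit_func_output must be called in a relax function.
```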
```diff
@@ -173,31 +220,36 @@ def _populate_used_vars(expr):
 
 
     def function(self,
-                 params: Optional[Union[Var, Tuple, List[Var]]] = None,
-                 name: Optional[str] = "") -> FunctionScope:
+                 name: str,
+                 params: Optional[Union[Var, Tuple, List[Var]]] = None) -> FunctionScope:
         """Annotate a Relax function.
 
         Parameters
         ----------
+        name : str
+            The name of the function.
+
         params : tvm.relax.Var | Tuple | List[tvm.relax.Var], optional
             The parameters of the function.
-
-        name : str, optional
-            The name of the function. If provided, the function is global, otherwise local.
+            If params is None, initialization of the function parameters is deferred
+            to emit_func_output.
 
         Returns
         -------
         ret: FunctionScope
             A FunctionScope for building a Relax function node.
         """
         if not params:
-            params = []
-        if not isinstance(params, (list, tuple)):
+            params = None
+        elif isinstance(params, rx.Var):
             params = [params]
+        elif isinstance(params, (list, tuple)):
+            for param in params:
+                if not isinstance(param, rx.Var):
+                    raise TypeError("each element of function parameters must be of type "
+                                    "tvm.relax.Var, but got: {}".format(type(param)))
 
-        self._func_params = params
-        self._func_name = name
-        return FunctionScope(self)
+        name = self.get_unique_name(name)
+        return FunctionScope(self, name, params)
 
     def dataflow(self) -> DataflowScope:
         """Annotate a Relax dataflow block.
```
```diff
@@ -304,12 +356,12 @@ def rx_func(x: Tensor[(n, m), "float32"], y: Tensor[(n, m), "float32"]) -> Tenso
 
         inputs = [*te_args, te_out]
         tir_func = tvm.te.create_prim_func(inputs)
-        func_name = _ffi_api.BlockBuilderGetUniqueName(self, func.__name__)
+        func_name = self.get_unique_name(func.__name__)
         tir_func = tir_func.with_attr("global_symbol", func_name)
         gvar = GlobalVar(func_name)
         self._context_mod[gvar] = tir_func
         call = call_dps(inputs[-1].shape, gvar, [x.op.value for x in inputs[:-1]])
-        return _ffi_api.BlockBuilderEmit(self, call)
+        return self.emit(call)
 
 
     def match_shape(self, value: Expr, pattern: List[PrimExpr]) -> Var:
```
```diff
@@ -347,22 +399,54 @@ def emit_output(self, output: Union[Expr, Tuple, List[Expr]]) -> None:
             output = Tuple(output)
         return _ffi_api.BlockBuilderEmitOutput(self, output)
 
-    def emit_func_output(self, output: Union[Expr, Tuple, List[Expr]]) -> None:
+    def emit_func_output(self,
+                         output: Union[Expr, Tuple, List[Expr]],
+                         params: Optional[Union[Var, Tuple, List[Var]]] = None) -> None:
         """Emit output for the function.
 
         Parameters
         ----------
         output : Expr | Tuple | List[Expr]
             The output of the current block/function.
+
+        params : tvm.relax.Var | Tuple | List[tvm.relax.Var], optional
+            The parameters of the function to be built.
+            If params is None, the parameters were initialized when entering the
+            function scope.
 
         Returns
         -------
         ret : tvm.relax.Var
             The return variable which gets bound to the output.
         """
+        if self._is_emit_func_output_called:
+            raise RuntimeError("emit_func_output must be called exactly once in a relax function.")
+        self._is_emit_func_output_called = True
+
+        if self._func_params is not None and params is not None:
+            raise RuntimeError("function parameters have already been initialized in the function scope.")
+
+        if self._func_params is None and params is None:
+            raise RuntimeError("a Relax function must have parameters.")
+
+        if self._func_params is None:
+            self._func_params = params
+
+        if BlockBuilder.current() is not self:
+            raise RuntimeError("BlockBuilder._current must be self.")
+
         if isinstance(output, (list, tuple)):
             output = Tuple(output)
         self._func_ret = output
+
+        block = self._end_block()
+        if len(block.bindings) > 0:
+            self._blocks.append(block)
+        seqe = rx.SeqExpr(self._blocks, self._func_ret)
+        func = rx.Function(
+            self._func_params, seqe, rx.DynTensorType(-1), rx.GlobalVar(self._func_name)
+        )
+        gvar = rx.GlobalVar(self._func_name)
+        self._context_mod[gvar] = func
 
     def normalize(self, expr: Expr) -> Expr:
         """Normalize an Expr to complete its shape and type.
```
```diff
@@ -388,3 +472,19 @@ def get(self) -> tvm.IRModule:
             An IRModule with Relax and TIR functions being built.
         """
         return self._context_mod
+
+
+    def get_unique_name(self, name_prefix: str) -> str:
+        """Generate a unique name with a specified prefix.
+
+        Parameters
+        ----------
+        name_prefix : str
+            The name prefix.
+
+        Returns
+        -------
+        ret : str
+            The generated name.
+        """
+        return _ffi_api.BlockBuilderGetUniqueName(self, name_prefix)
```
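`function()` now routes every function name through this helper (and `emit_te` uses it for generated PrimFuncs), so repeated prefixes cannot collide inside the module. A usage sketch; the exact deduplication suffix is produced on the C++ side and is an assumption here:

```python
from tvm import relax

bb = relax.BlockBuilder()
first = bb.get_unique_name("mlp")   # "mlp"
second = bb.get_unique_name("mlp")  # a deduplicated variant such as "mlp1" (suffix scheme assumed)
assert first != second
```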
