Skip to content

Commit bf0b70a

Browse files
tqchensergei-mironov
authored andcommitted
[COMPILER] Refactor compiler to enable configuration (apache#21)
1 parent ae44dd3 commit bf0b70a

19 files changed

+977
-729
lines changed

vta/examples/resnet18/pynq/imagenet_predict.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -39,8 +39,9 @@
3939
if verbose:
4040
logging.basicConfig(level=logging.INFO)
4141

42-
# Change to -device=tcpu to run cpu only inference.
42+
# Change to -device=vta-cpu to run cpu only inference.
4343
target = "llvm -device=vta"
44+
target_host = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+neon"
4445

4546
synset = eval(open(os.path.join(CATEG_FILE)).read())
4647
image = Image.open(os.path.join(TEST_FILE)).resize((224, 224))
@@ -117,15 +118,16 @@ def mark_nop(graph, conv_layer=-1, skip_conv_layer=()):
117118
sym = sym.apply("InferType")
118119

119120
with nnvm.compiler.build_config(opt_level=3):
120-
bdict = {}
121121
if "vta" not in target:
122-
bdict = {"add_lower_pass": []}
123-
else:
124-
bdict = {"add_lower_pass": vta.debug_mode(0)}
125-
with tvm.build_config(**bdict):
126122
graph, lib, params = nnvm.compiler.build(
127123
sym, target, shape_dict, dtype_dict,
128-
params=params)
124+
params=params, target_host=target_host)
125+
else:
126+
with vta.build_config():
127+
graph, lib, params = nnvm.compiler.build(
128+
sym, target, shape_dict, dtype_dict,
129+
params=params, target_host=target_host)
130+
129131

130132
temp = util.tempdir()
131133
lib.save(temp.relpath("graphlib.o"))

vta/python/vta/__init__.py

Lines changed: 4 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,13 @@
11
"""TVM-based VTA Compiler Toolchain"""
22
from __future__ import absolute_import as _abs
33

4-
from .hw_spec import *
4+
from .environment import get_env, Environment
55

66
try:
7-
from .runtime import SCOPE_INP, SCOPE_OUT, SCOPE_WGT, DMA_COPY, ALU
8-
from .intrin import GEVM, GEMM
9-
from .build import debug_mode
10-
from . import mock, ir_pass
7+
# allow optional import in config mode.
118
from . import arm_conv2d, vta_conv2d
12-
except AttributeError:
13-
pass
14-
15-
from .rpc_client import reconfig_runtime, program_fpga
16-
17-
try:
9+
from .build_module import build_config, lower, build
10+
from .rpc_client import reconfig_runtime, program_fpga
1811
from . import graph
1912
except ImportError:
2013
pass

vta/python/vta/build.py

Lines changed: 0 additions & 55 deletions
This file was deleted.

vta/python/vta/build_module.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
"""VTA specific buildin for runtime."""
2+
from __future__ import absolute_import as _abs
3+
4+
import tvm
5+
from . import ir_pass
6+
from .environment import get_env
7+
8+
9+
def lift_coproc_scope(x):
10+
"""Lift coprocessings cope to the """
11+
x = ir_pass.lift_alloc_to_scope_begin(x)
12+
x = tvm.ir_pass.LiftAttrScope(x, "coproc_scope", False)
13+
return x
14+
15+
def early_rewrite(stmt):
16+
"""Try to do storage rewrite in early pass."""
17+
try:
18+
return tvm.ir_pass.StorageRewrite(stmt)
19+
except tvm.TVMError:
20+
return stmt
21+
22+
23+
def build_config(debug_flag=0, **kwargs):
24+
"""Build a build config for VTA.
25+
26+
Parameters
27+
----------
28+
debug_flag : int
29+
The dbeug flag to be passed.
30+
31+
kwargs : dict
32+
Additional configurations.
33+
34+
Returns
35+
-------
36+
build_config: BuildConfig
37+
The build config that can be used in TVM.
38+
39+
Example
40+
--------
41+
.. code-block:: python
42+
43+
# build a vta module.
44+
with vta.build_config():
45+
vta_module = tvm.build(s, ...)
46+
"""
47+
env = get_env()
48+
def add_debug(stmt):
49+
debug = tvm.call_extern(
50+
"int32", "VTASetDebugMode",
51+
env.dev.command_handle,
52+
debug_flag)
53+
54+
return tvm.make.stmt_seq(debug, stmt)
55+
pass_list = [(1, ir_pass.inject_dma_intrin),
56+
(1, ir_pass.inject_skip_copy),
57+
(1, ir_pass.annotate_alu_coproc_scope),
58+
(1, lambda x: tvm.ir_pass.LiftAttrScope(x, "coproc_uop_scope", True)),
59+
(1, lift_coproc_scope),
60+
(1, ir_pass.inject_coproc_sync),
61+
(1, early_rewrite)]
62+
if debug_flag:
63+
pass_list.append((1, add_debug))
64+
pass_list.append((2, ir_pass.inject_alu_intrin))
65+
pass_list.append((3, ir_pass.fold_uop_loop))
66+
pass_list.append((3, ir_pass.cpu_access_rewrite))
67+
return tvm.build_config(add_lower_pass=pass_list, **kwargs)
68+
69+
70+
def lower(*args, **kwargs):
71+
"""Thin wrapper of tvm.lower
72+
73+
This wrapper automatically applies VTA's build_config
74+
if there is no user specified build_config in context.
75+
76+
See Also
77+
--------
78+
tvm.lower : The original TVM's lower function
79+
"""
80+
cfg = tvm.build_module.current_build_config()
81+
if not cfg.add_lower_pass:
82+
with build_config():
83+
return tvm.lower(*args, **kwargs)
84+
return tvm.lower(*args, **kwargs)
85+
86+
87+
def build(*args, **kwargs):
88+
"""Thin wrapper of tvm.build
89+
90+
This wrapper automatically applies VTA's build_config
91+
if there is no user specified build_config in context.
92+
93+
See Also
94+
--------
95+
tvm.build : The original TVM's build function
96+
"""
97+
cfg = tvm.build_module.current_build_config()
98+
if not cfg.add_lower_pass:
99+
with build_config():
100+
return tvm.build(*args, **kwargs)
101+
return tvm.build(*args, **kwargs)

0 commit comments

Comments
 (0)