Skip to content

Commit 42f068d

Browse files
tqchensergei-mironov
authored andcommitted
[RUNTIME] Simplify dynamic library and code path. (apache#27)
* [RUNTIME] Simplify dynamic library and code path. * reword the readme
1 parent 8d072d7 commit 42f068d

File tree

13 files changed

+260
-213
lines changed

13 files changed

+260
-213
lines changed

vta/Makefile

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -53,37 +53,33 @@ else
5353
NO_WHOLE_ARCH= --no-whole-archive
5454
endif
5555

56-
57-
all: lib/libvta.so lib/libvta_runtime.so
58-
5956
VTA_LIB_SRC = $(wildcard src/*.cc src/tvm/*.cc)
6057

61-
ifeq ($(TARGET), VTA_PYNQ_TARGET)
58+
ifeq ($(VTA_TARGET), pynq)
6259
VTA_LIB_SRC += $(wildcard src/pynq/*.cc)
6360
LDFLAGS += -L/usr/lib -lsds_lib
6461
LDFLAGS += -L/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/
6562
LDFLAGS += -L/opt/python3.6/lib/python3.6/site-packages/pynq/lib/
6663
LDFLAGS += -l:libdma.so
6764
endif
6865

69-
ifeq ($(TARGET), sim)
66+
ifeq ($(VTA_TARGET), sim)
7067
VTA_LIB_SRC += $(wildcard src/sim/*.cc)
7168
endif
7269

7370
VTA_LIB_OBJ = $(patsubst src/%.cc, build/%.o, $(VTA_LIB_SRC))
7471

72+
all: lib/libvta.so
73+
7574
build/%.o: src/%.cc
7675
@mkdir -p $(@D)
77-
$(CXX) $(CFLAGS) -MM -MT build/src/$*.o $< >build/$*.d
76+
$(CXX) $(CFLAGS) -MM -MT build/$*.o $< >build/$*.d
7877
$(CXX) -c $(CFLAGS) -c $< -o $@
7978

80-
lib/libvta.so: $(filter-out build/runtime.o, $(VTA_LIB_OBJ))
79+
lib/libvta.so: $(VTA_LIB_OBJ)
8180
@mkdir -p $(@D)
8281
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.o, $^) $(LDFLAGS)
8382

84-
lib/libvta_runtime.so: build/runtime.o
85-
@mkdir -p $(@D)
86-
$(CXX) $(CFLAGS) -shared -o $@ $(filter %.o, $^) $(LDFLAGS)
8783

8884
lint: pylint cpplint
8985

vta/README.md

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,25 @@
1-
Open Hardware/Software Stack for Vertical Deep Learning System Optimization
2-
==============================================
1+
VTA: Open, Modular, Deep Learning Accelerator Stack
2+
===================================================
33

44
[![GitHub license](http://dmlc.github.io/img/apache2.svg)](./LICENSE)
55

6-
VTA is an open hardware/software co-design stack for deep learning systems systems.
7-
It provides a customizable hardware accelerator template for deep learning inference workloads,
8-
combined with a fully functional compiler stack built with TVM.
6+
VTA(versatile tensor accelerator) is an open-source deep learning accelerator stack.
7+
It is not just an open-source hardware, but is an end to end solution that includes
8+
the entire software stack on top of VTA open-source hardware.
9+
10+
11+
The key features include:
12+
13+
- Generic, modular open-source hardware
14+
- Streamlined workflow to deploy to FPGAs.
15+
- Simulator support
16+
- Driver and JIT runtime for both simulated backend and FPGA.
17+
- End to end TVM stack integration
18+
- Direct optimization and deploy models from deep learning frameworks via TVM stack.
19+
- Customized and extendible TVM compiler backend
20+
- Flexible RPC support to ease the deployment, you can program it with python :)
21+
22+
VTA is part of our effort on [TVM Stack](http://www.tvmlang.org/).
923

1024
License
1125
-------

vta/make/config.mk

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ ADD_LDFLAGS=
2626
# the additional compile flags you want to add
2727
ADD_CFLAGS=
2828

29-
# the hardware target
30-
TARGET = pynq
29+
# the hardware target, can be [sim, pynq]
30+
VTA_TARGET = pynq
3131

3232
#---------------------
3333
# VTA hardware parameters
@@ -88,7 +88,8 @@ $(shell echo "$$(( $(VTA_LOG_ACC_BUFF_SIZE) + $(VTA_LOG_OUT_WIDTH) - $(VTA_LOG_A
8888
VTA_OUT_BUFF_SIZE = $(shell echo "$$(( 1 << $(VTA_LOG_OUT_BUFF_SIZE) ))" )
8989

9090
# Update ADD_CFLAGS
91-
ADD_CFLAGS += \
91+
ADD_CFLAGS +=
92+
-DVTA_TARGET=$(VTA_TARGET)\
9293
-DVTA_LOG_WGT_WIDTH=$(VTA_LOG_WGT_WIDTH) -DVTA_LOG_INP_WIDTH=$(VTA_LOG_INP_WIDTH) \
9394
-DVTA_LOG_ACC_WIDTH=$(VTA_LOG_ACC_WIDTH) -DVTA_LOG_OUT_WIDTH=$(VTA_LOG_OUT_WIDTH) \
9495
-DVTA_LOG_BATCH=$(VTA_LOG_BATCH) \

vta/python/vta/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""TVM-based VTA Compiler Toolchain"""
22
from __future__ import absolute_import as _abs
3+
import sys
34

45
from .environment import get_env, Environment
56

@@ -10,5 +11,5 @@
1011
from .rpc_client import reconfig_runtime, program_fpga
1112

1213
from . import graph
13-
except ImportError:
14+
except (ImportError, RuntimeError):
1415
pass

vta/python/vta/environment.py

Lines changed: 55 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,10 @@
33
from __future__ import absolute_import as _abs
44

55
import os
6+
import glob
67
import copy
7-
8-
try:
9-
# Allow missing import in config mode.
10-
import tvm
11-
from . import intrin
12-
except ImportError:
13-
pass
8+
import tvm
9+
from . import intrin
1410

1511

1612
class DevContext(object):
@@ -65,6 +61,45 @@ def get_task_qid(self, qid):
6561
return 1 if self.DEBUG_NO_SYNC else qid
6662

6763

64+
class PkgConfig(object):
65+
"""Simple package config tool for VTA.
66+
67+
This is used to provide runtime specific configurations.
68+
"""
69+
def __init__(self, env):
70+
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
71+
proj_root = os.path.abspath(os.path.join(curr_path, "../../"))
72+
# include path
73+
self.include_path = [
74+
"-I%s/include" % proj_root,
75+
"-I%s/nnvm/tvm/include" % proj_root,
76+
"-I%s/nnvm/tvm/dlpack/include" % proj_root,
77+
"-I%s/nnvm/dmlc-core/include" % proj_root
78+
]
79+
# List of source files that can be used to build standalone library.
80+
self.lib_source = []
81+
self.lib_source += glob.glob("%s/src/*.cc" % proj_root)
82+
self.lib_source += glob.glob("%s/src/%s/*.cc" % (proj_root, env.TARGET))
83+
# macro keys
84+
self.macro_defs = []
85+
for key in env.cfg_keys:
86+
self.macro_defs.append("-DVTA_%s=%s" % (key, str(getattr(env, key))))
87+
88+
if env.TARGET == "pynq":
89+
self.ldflags = [
90+
"-L/usr/lib",
91+
"-lsds_lib",
92+
"-L/opt/python3.6/lib/python3.6/site-packages/pynq/drivers/",
93+
"-L/opt/python3.6/lib/python3.6/site-packages/pynq/lib/",
94+
"-l:libdma.so"]
95+
else:
96+
self.ldflags = []
97+
98+
@property
99+
def cflags(self):
100+
return self.include_path + self.macro_defs
101+
102+
68103
class Environment(object):
69104
"""Hareware configuration object.
70105
@@ -160,6 +195,7 @@ def __init__(self, cfg):
160195
self.mock_mode = False
161196
self._mock_env = None
162197
self._dev_ctx = None
198+
self._pkg_config = None
163199

164200
@property
165201
def dev(self):
@@ -168,6 +204,13 @@ def dev(self):
168204
self._dev_ctx = DevContext(self)
169205
return self._dev_ctx
170206

207+
@property
208+
def pkg_config(self):
209+
"""PkgConfig instance"""
210+
if self._pkg_config is None:
211+
self._pkg_config = PkgConfig(self)
212+
return self._pkg_config
213+
171214
@property
172215
def mock(self):
173216
"""A mock version of the Environment
@@ -249,7 +292,7 @@ def mem_info_wgt_buffer():
249292
head_address=None)
250293

251294
@tvm.register_func("tvm.info.mem.%s" % Environment.acc_scope)
252-
def mem_info_out_buffer():
295+
def mem_info_acc_buffer():
253296
spec = get_env()
254297
return tvm.make.node("MemoryInfo",
255298
unit_bits=spec.ACC_ELEM_BITS,
@@ -265,13 +308,15 @@ def coproc_sync(op):
265308
"int32", "VTASynchronize",
266309
get_env().dev.command_handle, 1<<31)
267310

311+
268312
@tvm.register_func("tvm.intrin.rule.default.vta.coproc_dep_push")
269313
def coproc_dep_push(op):
270314
return tvm.call_extern(
271315
"int32", "VTADepPush",
272316
get_env().dev.command_handle,
273317
op.args[0], op.args[1])
274318

319+
275320
@tvm.register_func("tvm.intrin.rule.default.vta.coproc_dep_pop")
276321
def coproc_dep_pop(op):
277322
return tvm.call_extern(
@@ -288,7 +333,6 @@ def _init_env():
288333

289334
for k in Environment.cfg_keys:
290335
keys.add("VTA_" + k)
291-
keys.add("TARGET")
292336

293337
if not os.path.isfile(filename):
294338
raise RuntimeError(
@@ -303,8 +347,9 @@ def _init_env():
303347
val = line.split("=")[1].strip()
304348
if k.startswith("VTA_"):
305349
k = k[4:]
350+
try:
306351
cfg[k] = int(val)
307-
else:
352+
except ValueError:
308353
cfg[k] = val
309354
return Environment(cfg)
310355

vta/python/vta/exec/rpc_server.py

Lines changed: 31 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -9,26 +9,46 @@
99
import os
1010
import ctypes
1111
import tvm
12+
from tvm._ffi.base import c_str
1213
from tvm.contrib import rpc, cc
1314

15+
from ..environment import get_env
16+
1417

1518
@tvm.register_func("tvm.contrib.rpc.server.start", override=True)
1619
def server_start():
17-
"""callback when server starts."""
20+
"""VTA RPC server extension."""
1821
# pylint: disable=unused-variable
1922
curr_path = os.path.dirname(
2023
os.path.abspath(os.path.expanduser(__file__)))
2124
dll_path = os.path.abspath(
22-
os.path.join(curr_path, "../../../lib/libvta_runtime.so"))
25+
os.path.join(curr_path, "../../../lib/libvta.so"))
2326
runtime_dll = []
2427
_load_module = tvm.get_global_func("tvm.contrib.rpc.server.load_module")
2528

26-
@tvm.register_func("tvm.contrib.rpc.server.load_module", override=True)
27-
def load_module(file_name):
29+
def load_vta_dll():
30+
"""Try to load vta dll"""
2831
if not runtime_dll:
2932
runtime_dll.append(ctypes.CDLL(dll_path, ctypes.RTLD_GLOBAL))
33+
logging.info("Loading VTA library: %s", dll_path)
34+
return runtime_dll[0]
35+
36+
@tvm.register_func("tvm.contrib.rpc.server.load_module", override=True)
37+
def load_module(file_name):
38+
load_vta_dll()
3039
return _load_module(file_name)
3140

41+
@tvm.register_func("device_api.ext_dev")
42+
def ext_dev_callback():
43+
load_vta_dll()
44+
return tvm.get_global_func("device_api.ext_dev")()
45+
46+
@tvm.register_func("tvm.contrib.vta.init", override=True)
47+
def program_fpga(file_name):
48+
path = tvm.get_global_func("tvm.contrib.rpc.server.workpath")(file_name)
49+
load_vta_dll().VTAProgram(c_str(path))
50+
logging.info("Program FPGA with %s", file_name)
51+
3252
@tvm.register_func("tvm.contrib.rpc.server.shutdown", override=True)
3353
def server_shutdown():
3454
if runtime_dll:
@@ -47,17 +67,15 @@ def reconfig_runtime(cflags):
4767
if runtime_dll:
4868
raise RuntimeError("Can only reconfig in the beginning of session...")
4969
cflags = cflags.split()
70+
env = get_env()
5071
cflags += ["-O2", "-std=c++11"]
72+
cflags += env.pkg_config.include_path
73+
ldflags = env.pkg_config.ldflags
5174
lib_name = dll_path
52-
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
53-
proj_root = os.path.abspath(os.path.join(curr_path, "../../../"))
54-
runtime_source = os.path.join(proj_root, "src/runtime.cc")
55-
cflags += ["-I%s/include" % proj_root]
56-
cflags += ["-I%s/nnvm/tvm/include" % proj_root]
57-
cflags += ["-I%s/nnvm/tvm/dlpack/include" % proj_root]
58-
cflags += ["-I%s/nnvm/dmlc-core/include" % proj_root]
59-
logging.info("Rebuild runtime dll with %s", str(cflags))
60-
cc.create_shared(lib_name, [runtime_source], cflags)
75+
source = env.pkg_config.lib_source
76+
logging.info("Rebuild runtime: output=%s, cflags=%s, source=%s, ldflags=%s",
77+
dll_path, str(cflags), str(source), str(ldflags))
78+
cc.create_shared(lib_name, source, cflags + ldflags)
6179

6280

6381
def main():
@@ -75,14 +93,6 @@ def main():
7593
help="Report to RPC tracker")
7694
args = parser.parse_args()
7795
logging.basicConfig(level=logging.INFO)
78-
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
79-
proj_root = os.path.abspath(os.path.join(curr_path, "../../../"))
80-
lib_path = os.path.abspath(os.path.join(proj_root, "lib/libvta.so"))
81-
82-
libs = []
83-
for file_name in [lib_path]:
84-
libs.append(ctypes.CDLL(file_name, ctypes.RTLD_GLOBAL))
85-
logging.info("Load additional library %s", file_name)
8696

8797
if args.tracker:
8898
url, port = args.tracker.split(":")
@@ -99,7 +109,6 @@ def main():
99109
args.port_end,
100110
key=args.key,
101111
tracker_addr=tracker_addr)
102-
server.libs += libs
103112
server.proc.join()
104113

105114
if __name__ == "__main__":

vta/python/vta/testing/simulator.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,6 @@ def _load_lib():
1010
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
1111
dll_path = [
1212
os.path.abspath(os.path.join(curr_path, "../../../lib/libvta.so")),
13-
os.path.abspath(os.path.join(curr_path, "../../../lib/libvta_runtime.so"))
1413
]
1514
runtime_dll = []
1615
if not all(os.path.exists(f) for f in dll_path):

vta/python/vta/testing/util.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,18 @@ def run(run_func):
1515
run_func : function(env, remote)
1616
"""
1717
env = get_env()
18-
# run on simulator
19-
if simulator.enabled():
18+
19+
# Run on local sim rpc if necessary
20+
local_rpc = int(os.environ.get("VTA_LOCAL_SIM_RPC", "0"))
21+
if local_rpc:
2022
env.TARGET = "sim"
21-
run_func(env, rpc.LocalSession())
23+
remote = rpc.connect("localhost", local_rpc)
24+
run_func(env, remote)
25+
else:
26+
# run on simulator
27+
if simulator.enabled():
28+
env.TARGET = "sim"
29+
run_func(env, rpc.LocalSession())
2230

2331
# Run on PYNQ if env variable exists
2432
pynq_host = os.environ.get("VTA_PYNQ_RPC_HOST", None)

0 commit comments

Comments
 (0)