99import os
1010import ctypes
1111import tvm
12+ from tvm ._ffi .base import c_str
1213from tvm .contrib import rpc , cc
1314
15+ from ..environment import get_env
16+
1417
1518@tvm .register_func ("tvm.contrib.rpc.server.start" , override = True )
1619def server_start ():
17- """callback when server starts ."""
20+ """VTA RPC server extension ."""
1821 # pylint: disable=unused-variable
1922 curr_path = os .path .dirname (
2023 os .path .abspath (os .path .expanduser (__file__ )))
2124 dll_path = os .path .abspath (
22- os .path .join (curr_path , "../../../lib/libvta_runtime .so" ))
25+ os .path .join (curr_path , "../../../lib/libvta .so" ))
2326 runtime_dll = []
2427 _load_module = tvm .get_global_func ("tvm.contrib.rpc.server.load_module" )
2528
26- @ tvm . register_func ( "tvm.contrib.rpc.server.load_module" , override = True )
27- def load_module ( file_name ):
29+ def load_vta_dll ():
30+ """Try to load vta dll"""
2831 if not runtime_dll :
2932 runtime_dll .append (ctypes .CDLL (dll_path , ctypes .RTLD_GLOBAL ))
33+ logging .info ("Loading VTA library: %s" , dll_path )
34+ return runtime_dll [0 ]
35+
36+ @tvm .register_func ("tvm.contrib.rpc.server.load_module" , override = True )
37+ def load_module (file_name ):
38+ load_vta_dll ()
3039 return _load_module (file_name )
3140
41+ @tvm .register_func ("device_api.ext_dev" )
42+ def ext_dev_callback ():
43+ load_vta_dll ()
44+ return tvm .get_global_func ("device_api.ext_dev" )()
45+
46+ @tvm .register_func ("tvm.contrib.vta.init" , override = True )
47+ def program_fpga (file_name ):
48+ path = tvm .get_global_func ("tvm.contrib.rpc.server.workpath" )(file_name )
49+ load_vta_dll ().VTAProgram (c_str (path ))
50+ logging .info ("Program FPGA with %s" , file_name )
51+
3252 @tvm .register_func ("tvm.contrib.rpc.server.shutdown" , override = True )
3353 def server_shutdown ():
3454 if runtime_dll :
@@ -47,17 +67,15 @@ def reconfig_runtime(cflags):
4767 if runtime_dll :
4868 raise RuntimeError ("Can only reconfig in the beginning of session..." )
4969 cflags = cflags .split ()
70+ env = get_env ()
5071 cflags += ["-O2" , "-std=c++11" ]
72+ cflags += env .pkg_config .include_path
73+ ldflags = env .pkg_config .ldflags
5174 lib_name = dll_path
52- curr_path = os .path .dirname (os .path .abspath (os .path .expanduser (__file__ )))
53- proj_root = os .path .abspath (os .path .join (curr_path , "../../../" ))
54- runtime_source = os .path .join (proj_root , "src/runtime.cc" )
55- cflags += ["-I%s/include" % proj_root ]
56- cflags += ["-I%s/nnvm/tvm/include" % proj_root ]
57- cflags += ["-I%s/nnvm/tvm/dlpack/include" % proj_root ]
58- cflags += ["-I%s/nnvm/dmlc-core/include" % proj_root ]
59- logging .info ("Rebuild runtime dll with %s" , str (cflags ))
60- cc .create_shared (lib_name , [runtime_source ], cflags )
75+ source = env .pkg_config .lib_source
76+ logging .info ("Rebuild runtime: output=%s, cflags=%s, source=%s, ldflags=%s" ,
77+ dll_path , str (cflags ), str (source ), str (ldflags ))
78+ cc .create_shared (lib_name , source , cflags + ldflags )
6179
6280
6381def main ():
@@ -75,14 +93,6 @@ def main():
7593 help = "Report to RPC tracker" )
7694 args = parser .parse_args ()
7795 logging .basicConfig (level = logging .INFO )
78- curr_path = os .path .dirname (os .path .abspath (os .path .expanduser (__file__ )))
79- proj_root = os .path .abspath (os .path .join (curr_path , "../../../" ))
80- lib_path = os .path .abspath (os .path .join (proj_root , "lib/libvta.so" ))
81-
82- libs = []
83- for file_name in [lib_path ]:
84- libs .append (ctypes .CDLL (file_name , ctypes .RTLD_GLOBAL ))
85- logging .info ("Load additional library %s" , file_name )
8696
8797 if args .tracker :
8898 url , port = args .tracker .split (":" )
@@ -99,7 +109,6 @@ def main():
99109 args .port_end ,
100110 key = args .key ,
101111 tracker_addr = tracker_addr )
102- server .libs += libs
103112 server .proc .join ()
104113
105114if __name__ == "__main__" :
0 commit comments