Skip to content

Commit daffce5

Browse files
committed
[HARDWARE] [AUTOTVM] Allowing AutoTVM to tune for different VTA designs and adding ZCU102 support (apache#19)
* compilation support for ZCU102
* dll caching to save on dynamic reconfiguration time
* introducing model signature to differentiate between VTA variants
* disable tophub to avoid falling back on invalid schedules
* reconfig when targeting an FPGA to reset the hardware
* typo fix
* addressing comments
1 parent 7055803 commit daffce5

File tree

10 files changed

+98
-38
lines changed

10 files changed

+98
-38
lines changed

python/tvm/autotvm/measure/measure_methods.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -461,6 +461,10 @@ def run_through_rpc(measure_input, build_result,
461461
try:
462462
# upload built module
463463
remote = request_remote(*remote_args)
464+
if measure_input.target.device_name == 'vta':
465+
from vta import program_fpga, reconfig_runtime
466+
program_fpga(remote, None)
467+
reconfig_runtime(remote)
464468
remote.upload(build_result.filename)
465469
func = remote.load_module(os.path.split(build_result.filename)[1])
466470
ctx = remote.context(str(measure_input.target), 0)

python/tvm/autotvm/task/nnvm_integration.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
from ..util import get_const_tuple
1414
from .task import create, register
15+
from .dispatcher import ApplyHistoryBest
1516

1617
logger = logging.getLogger('autotvm')
1718

@@ -240,7 +241,8 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
240241

241242
# run compiler to collect all TOPI calls during compilation
242243
nnvm.compiler.engine.clear_cache()
243-
nnvm.compiler.build(graph, target=target, shape=shape, dtype=dtype)
244+
with ApplyHistoryBest([]):
245+
nnvm.compiler.build(graph, target=target, shape=shape, dtype=dtype)
244246
nnvm.compiler.engine.clear_cache()
245247

246248
logger.disabled = old_state

vta/hardware/xilinx/Makefile

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,8 @@ $(BIT_PATH): $(IP_PATH)
102102
mkdir -p $(HW_BUILD_PATH)
103103
cd $(HW_BUILD_PATH) && \
104104
$(VIVADO) -mode tcl -source $(SCRIPT_DIR)/ultra96.tcl \
105-
-tclargs $(BUILD_DIR)/hls/$(CONF) $(VTA_HW_COMP_THREADS) $(VTA_CLOCK_FREQ) $(VTA_GEMM_II) \
105+
-tclargs $(VTA_TARGET) $(BUILD_DIR)/hls/$(CONF) $(VTA_HW_COMP_THREADS) \
106+
$(VTA_CLOCK_FREQ) $(VTA_GEMM_II) \
106107
$(VTA_INP_WIDTH) $(VTA_WGT_WIDTH) $(VTA_OUT_WIDTH) \
107108
$(VTA_BATCH) $(VTA_IN_BLOCK) $(VTA_OUT_BLOCK) \
108109
$(VTA_INP_BUFF_SIZE) $(VTA_WGT_BUFF_SIZE) $(VTA_OUT_BUFF_SIZE)

vta/hardware/xilinx/scripts/hls.tcl

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -93,13 +93,21 @@ proc init_design {target per g_ii a_ii inp_width wgt_width out_width acc_width b
9393
set_part {xc7z020clg484-1}
9494
} elseif {$target=="ultra96"} {
9595
set_part {xczu3eg-sbva484-1-e}
96+
} elseif {$target=="zcu102"} {
97+
set_part {xczu9eg-ffvb1156-2-e}
9698
}
9799

98100
# Max bus width (supported by Vivado)
99101
set max_width 1024
100102

101103
# Set axi width (TODO derive from top level config)
102-
set axi_width 128
104+
if {$target=="pynq"} {
105+
set axi_width 64
106+
} elseif {$target=="ultra96"} {
107+
set axi_width 128
108+
} elseif {$target=="zcu102"} {
109+
set axi_width 128
110+
}
103111

104112
# Set the clock frequency
105113
create_clock -period $per -name default

vta/hardware/xilinx/scripts/ultra96.tcl

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -21,23 +21,24 @@ if { [string first $scripts_vivado_version $current_vivado_version] == -1 } {
2121
}
2222

2323
# Parse argument list, derive the clock to utilize
24-
if { [llength $argv] eq 13 } {
25-
set ip_path [lindex $argv 0]
26-
set num_threads [lindex $argv 1]
27-
set clock_freq [lindex $argv 2]
28-
set gemm_ii [lindex $argv 3]
29-
set inp_width [expr 1 << [lindex $argv 4]]
30-
set wgt_width [expr 1 << [lindex $argv 5]]
31-
set out_width [expr 1 << [lindex $argv 6]]
32-
set batch [expr 1 << [lindex $argv 7]]
33-
set out_block [expr 1 << [lindex $argv 8]]
34-
set in_block [expr 1 << [lindex $argv 9]]
35-
set inp_mem_size [expr 1 << [lindex $argv 10]]
36-
set wgt_mem_size [expr 1 << [lindex $argv 11]]
37-
set out_mem_size [expr 1 << [lindex $argv 12]]
24+
if { [llength $argv] eq 14 } {
25+
set target [lindex $argv 0]
26+
set ip_path [lindex $argv 1]
27+
set num_threads [lindex $argv 2]
28+
set clock_freq [lindex $argv 3]
29+
set gemm_ii [lindex $argv 4]
30+
set inp_width [expr 1 << [lindex $argv 5]]
31+
set wgt_width [expr 1 << [lindex $argv 6]]
32+
set out_width [expr 1 << [lindex $argv 7]]
33+
set batch [expr 1 << [lindex $argv 8]]
34+
set out_block [expr 1 << [lindex $argv 9]]
35+
set in_block [expr 1 << [lindex $argv 10]]
36+
set inp_mem_size [expr 1 << [lindex $argv 11]]
37+
set wgt_mem_size [expr 1 << [lindex $argv 12]]
38+
set out_mem_size [expr 1 << [lindex $argv 13]]
3839
} else {
39-
puts "Arg list incomplete: <path to ip dir> <num threads> <clock freq> <gemm ii> \
40-
<inp width> <wgt_width> <out_width> <batch> <batch> <out_block> <in_block
40+
puts "Arg list incomplete: <target> <path to ip dir> <num threads> <clock freq> \
41+
<gemm ii> <inp width> <wgt_width> <out_width> <batch> <batch> <out_block> <in_block> \
4142
<inp_mem_size> <wgt_mem_size> <out_mem_size>"
4243
return 1
4344
}
@@ -82,8 +83,13 @@ set compute_ip "${ip_path}/vta_compute/solution0/impl/ip/xilinx_com_hls_compute_
8283
set store_ip "${ip_path}/vta_store/solution0/impl/ip/xilinx_com_hls_store_1_0.zip"
8384

8485
# Create custom project
85-
create_project -force $proj_name $proj_path -part xczu3eg-sbva484-1-e
86-
set_property BOARD_PART em.avnet.com:ultra96:part0:1.0 [current_project]
86+
if { ${target} eq "ultra96" } {
87+
create_project -force $proj_name $proj_path -part xczu3eg-sbva484-1-e
88+
set_property BOARD_PART em.avnet.com:ultra96:part0:1.0 [current_project]
89+
} elseif { ${target} eq "zcu102" } {
90+
create_project -force $proj_name $proj_path -part xczu9eg-ffvb1156-2-e
91+
set_property BOARD_PART xilinx.com:zcu102:part0:3.2 [current_project]
92+
}
8793

8894
# Update IP repository with generated IP
8995
file mkdir $ip_lib

vta/python/vta/environment.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ def __init__(self, cfg):
149149
self._mock_env = None
150150
self._dev_ctx = None
151151
self._last_env = None
152-
# derive bitstream name
152+
# derive bitstream name
153153
self.BITSTREAM = "{}/{}/{}x{}x{}_a{}w{}o{}_{}_{}_{}_{}_{}MHz_{}ns_gii{}".format(
154154
self.HW_VER.replace('.', '_'),
155155
self.TARGET,
@@ -171,6 +171,21 @@ def __init__(self, cfg):
171171
if self.MUL_EN and self.ALU_EN:
172172
self.BITSTREAM += "_mul"
173173
self.BITSTREAM += ".bit"
174+
# model - autoTVM signature that identifies VTA configuration.
175+
# This is WIP: knobs that could influence the efficacy of the
176+
# schedule have been left out for now.
177+
self.MODEL = "{}-{}x{}x{}_a{}w{}o{}_{}_{}_{}_{}".format(
178+
self.TARGET,
179+
self.BATCH,
180+
self.BLOCK_IN,
181+
self.BLOCK_OUT,
182+
self.INP_WIDTH,
183+
self.WGT_WIDTH,
184+
self.OUT_WIDTH,
185+
self.LOG_UOP_BUFF_SIZE,
186+
self.LOG_INP_BUFF_SIZE,
187+
self.LOG_WGT_BUFF_SIZE,
188+
self.LOG_ACC_BUFF_SIZE)
174189

175190
def __enter__(self):
176191
self._last_env = Environment.current

vta/python/vta/exec/rpc_server.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import ctypes
1111
import json
1212
import tvm
13+
from shutil import copyfile
1314
from tvm._ffi.base import c_str
1415
from tvm import rpc
1516
from tvm.contrib import cc
@@ -87,14 +88,24 @@ def reconfig_runtime(cfg_json):
8788
if pkg.same_config(old_cfg):
8889
logging.info("Skip reconfig_runtime due to same config.")
8990
return
90-
cflags = ["-O2", "-std=c++11"]
91-
cflags += pkg.cflags
92-
ldflags = pkg.ldflags
93-
lib_name = dll_path
94-
source = pkg.lib_source
95-
logging.info("Rebuild runtime:\n output=%s,\n cflags=%s,\n source=%s,\n ldflags=%s",
96-
dll_path, '\n\t'.join(cflags), '\n\t'.join(source), '\n\t'.join(ldflags))
97-
cc.create_shared(lib_name, source, cflags + ldflags)
91+
# check if a dll matching the configuration has been cached
92+
dll_root, dll_ext = os.path.splitext(dll_path)
93+
cached_dll_path = dll_root + '-' + pkg.signature + dll_ext
94+
if os.path.isfile(cached_dll_path):
95+
copyfile(cached_dll_path, dll_path)
96+
logging.info("Swapping in cached dll: source=%s, destination=%s",
97+
cached_dll_path, dll_path)
98+
else:
99+
cflags = ["-O2", "-std=c++11"]
100+
cflags += pkg.cflags
101+
ldflags = pkg.ldflags
102+
lib_name = dll_path
103+
source = pkg.lib_source
104+
logging.info("Rebuild runtime:\n output=%s,\n cflags=%s,\n source=%s,\n ldflags=%s",
105+
dll_path, '\n\t'.join(cflags), '\n\t'.join(source), '\n\t'.join(ldflags))
106+
cc.create_shared(lib_name, source, cflags + ldflags)
107+
copyfile(dll_path, cached_dll_path)
108+
logging.info("Caching dll to: %s", cached_dll_path)
98109
with open(cfg_path, "w") as outputfile:
99110
outputfile.write(pkg.cfg_json)
100111

vta/python/vta/pkg_config.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,21 @@ def __init__(self, cfg, proj_root):
8080
def cflags(self):
8181
return self.include_path + self.macro_defs
8282

83+
@property
84+
def signature(self):
85+
return "{}-{}_{}_{}_{}_{}_{}_{}_{}_{}_{}".format(
86+
self.cfg_dict["TARGET"],
87+
self.cfg_dict["LOG_BATCH"],
88+
self.cfg_dict["LOG_BLOCK_IN"],
89+
self.cfg_dict["LOG_BLOCK_OUT"],
90+
self.cfg_dict["LOG_INP_WIDTH"],
91+
self.cfg_dict["LOG_WGT_WIDTH"],
92+
self.cfg_dict["LOG_OUT_WIDTH"],
93+
self.cfg_dict["LOG_UOP_BUFF_SIZE"],
94+
self.cfg_dict["LOG_INP_BUFF_SIZE"],
95+
self.cfg_dict["LOG_WGT_BUFF_SIZE"],
96+
self.cfg_dict["LOG_ACC_BUFF_SIZE"])
97+
8398
@property
8499
def cfg_json(self):
85100
return json.dumps(self.cfg_dict, indent=2)

vta/scripts/tune_conv.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,12 +49,11 @@ def conv2d(N, CI, H, W, CO, KH, KW, strides, padding, in_dtype, out_dtype):
4949
return s, [data, kernel, bias, res]
5050

5151
if __name__ == '__main__':
52-
model = env.TARGET
5352
N, CI, H, W, CO, KH, KW, strides, padding, in_dtype, out_dtype = \
5453
1, 64, 56, 56, 64, 3, 3, (1, 1), (1, 1), 'int8', 'int32'
5554

5655
task = autotvm.task.create(conv2d, args=(N, CI, H, W, CO, KH, KW, strides, padding, in_dtype, out_dtype),
57-
target=tvm.target.vta(model), target_host=env.target_host, template_key='direct')
56+
target=tvm.target.vta(env.MODEL), target_host=env.target_host, template_key='direct')
5857
print(task.config_space)
5958

6059
# logging config (for printing tuning log to the screen)
@@ -63,7 +62,7 @@ def conv2d(N, CI, H, W, CO, KH, KW, strides, padding, in_dtype, out_dtype):
6362

6463
measure_option = autotvm.measure_option(
6564
builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func),
66-
runner=autotvm.RPCRunner(model, 'fleet', 9190, number=4, repeat=3, timeout=30,
65+
runner=autotvm.RPCRunner(env.TARGET, 'fleet', 9190, number=4, repeat=3, timeout=30,
6766
check_correctness=True))
6867

6968
tuner = autotvm.tuner.RandomTuner(task)

vta/scripts/tune_resnet.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -159,18 +159,17 @@ def tune_tasks(tasks,
159159
os.remove(tmp_log_file)
160160

161161
if __name__ == '__main__':
162-
device_key = env.TARGET
163162

164163
tuning_opt = {
165-
'log_filename': 'resnet-18.log',
164+
'log_filename': 'resnet-18-{}.log'.format(env.MODEL),
166165

167166
'tuner': 'random',
168167
'n_trial': 1e9,
169168
'early_stopping': None,
170169

171170
'measure_option': autotvm.measure_option(
172171
builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func),
173-
runner=autotvm.RPCRunner(device_key, 'fleet', 9190,
172+
runner=autotvm.RPCRunner(env.TARGET, 'fleet', 9190,
174173
number=4, repeat=3, timeout=60,
175174
check_correctness=True))
176175
}
@@ -182,7 +181,7 @@ def tune_tasks(tasks,
182181
register_vta_tuning_tasks()
183182

184183
print("Extract tasks...")
185-
target = tvm.target.vta(device_key)
184+
target = tvm.target.vta(env.MODEL)
186185
target_host = env.target_host
187186
tasks = extract_tasks(sym, params, target, target_host)
188187

@@ -203,7 +202,7 @@ def tune_tasks(tasks,
203202

204203
# upload module to device
205204
print("Upload...")
206-
remote = autotvm.measure.request_remote(device_key, 'fleet', 9190, timeout=10000)
205+
remote = autotvm.measure.request_remote(env.TARGET, 'fleet', 9190, timeout=10000)
207206
remote.upload(tmp.relpath(filename))
208207
rlib = remote.load_module(filename)
209208

0 commit comments

Comments
 (0)