From c58850284a52ca3bb9e59fdb29fccb373e0307a3 Mon Sep 17 00:00:00 2001
From: Tianqi Chen
Date: Sat, 14 Apr 2018 19:25:26 -0700
Subject: [PATCH] Refactor code structure, fix pynq rpc (#29)

---
 .../resnet18/pynq/imagenet_predict.py         |  16 +-
 vta/python/vta/__init__.py                    |   5 +-
 vta/python/vta/exec/rpc_server.py             |  10 +-
 vta/python/vta/top/__init__.py                |   5 +
 vta/python/vta/{ => top}/arm_conv2d.py        |  10 +-
 vta/python/vta/{ => top}/vta_conv2d.py        |   7 +-
 .../integration/test_benchmark_topi_conv2d.py | 155 ++++++++++++++++++
 vta/tests/python/pynq/test_benchmark_topi.py  | 146 -----------------
 8 files changed, 185 insertions(+), 169 deletions(-)
 create mode 100644 vta/python/vta/top/__init__.py
 rename vta/python/vta/{ => top}/arm_conv2d.py (97%)
 rename vta/python/vta/{ => top}/vta_conv2d.py (98%)
 create mode 100644 vta/tests/python/integration/test_benchmark_topi_conv2d.py
 delete mode 100644 vta/tests/python/pynq/test_benchmark_topi.py

diff --git a/vta/examples/resnet18/pynq/imagenet_predict.py b/vta/examples/resnet18/pynq/imagenet_predict.py
index eb660ea12a71..ae8d49017255 100644
--- a/vta/examples/resnet18/pynq/imagenet_predict.py
+++ b/vta/examples/resnet18/pynq/imagenet_predict.py
@@ -37,10 +37,10 @@
     vta.program_fpga(remote, BITSTREAM_FILE)
 
 if verbose:
-    logging.basicConfig(level=logging.INFO)
+    logging.basicConfig(level=logging.DEBUG)
 
-# Change to -device=vta-cpu to run cpu only inference.
-target = "llvm -device=vta"
+# Change to -device=vtacpu to run cpu only inference.
+target = tvm.target.create("llvm -device=vta")
 target_host = "llvm -mtriple=armv7-none-linux-gnueabihf -mcpu=cortex-a9 -mattr=+neon"
 
 synset = eval(open(os.path.join(CATEG_FILE)).read())
@@ -109,7 +109,7 @@ def mark_nop(graph, conv_layer=-1, skip_conv_layer=()):
 sym = vta.graph.remove_stochastic(sym)
 sym = vta.graph.clean_cast(sym)
 sym = vta.graph.clean_conv_fuse(sym)
-if "vta" in target:
+if target.device_name == "vta":
     sym = vta.graph.pack(sym, shape_dict, factor)
 
 graph_attr.set_shape_inputs(sym, shape_dict)
@@ -118,10 +118,10 @@ def mark_nop(graph, conv_layer=-1, skip_conv_layer=()):
 sym = sym.apply("InferType")
 
 with nnvm.compiler.build_config(opt_level=3):
-    if "vta" not in target:
+    if target.device_name != "vta":
         graph, lib, params = nnvm.compiler.build(
-            sym, target, shape_dict, dtype_dict,
-            params=params, target_host=target_host)
+            sym, target_host, shape_dict, dtype_dict,
+            params=params)
     else:
         with vta.build_config():
             graph, lib, params = nnvm.compiler.build(
@@ -133,7 +133,7 @@ def mark_nop(graph, conv_layer=-1, skip_conv_layer=()):
 lib.save(temp.relpath("graphlib.o"))
 remote.upload(temp.relpath("graphlib.o"))
 lib = remote.load_module("graphlib.o")
-ctx = remote.ext_dev(0) if "vta" in target else remote.cpu(0)
+ctx = remote.ext_dev(0) if target.device_name == "vta" else remote.cpu(0)
 
 print("Build complete...")
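The move above from substring tests on the raw target string to a parsed target's device_name avoids a real false positive: "vta" is a substring of "vtacpu", so the old check would route CPU-only builds down the VTA path. A minimal, self-contained sketch of the pitfall in plain Python; device_name here is a hypothetical stand-in for TVM's target parser, not the patch's code:

    # Hypothetical mini-parser for "-device=<name>" flags, for illustration only.
    def device_name(target_str):
        for tok in target_str.split():
            if tok.startswith("-device="):
                return tok[len("-device="):]
        return "cpu"

    assert "vta" in "llvm -device=vtacpu"                # old check: false positive
    assert device_name("llvm -device=vtacpu") != "vta"   # parsed check routes to CPU
    assert device_name("llvm -device=vta") == "vta"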
diff --git a/vta/python/vta/__init__.py b/vta/python/vta/__init__.py
index d0a3279f1915..4be16ccfb5a6 100644
--- a/vta/python/vta/__init__.py
+++ b/vta/python/vta/__init__.py
@@ -3,11 +3,12 @@
 
 from .environment import get_env, Environment
-from . import arm_conv2d, vta_conv2d
-from .build_module import build_config, lower, build
 from .rpc_client import reconfig_runtime, program_fpga
+
 try:
+    from . import top
+    from .build_module import build_config, lower, build
     from . import graph
 except (ImportError, RuntimeError):
     pass
diff --git a/vta/python/vta/exec/rpc_server.py b/vta/python/vta/exec/rpc_server.py
index 014b40564d4d..f3db38c929bd 100644
--- a/vta/python/vta/exec/rpc_server.py
+++ b/vta/python/vta/exec/rpc_server.py
@@ -75,20 +75,20 @@ def reconfig_runtime(cfg_json):
         pkg = PkgConfig(cfg, proj_root)
         # check if the configuration is already the same
         if os.path.isfile(cfg_path):
-            old_cfg = json.load(open(cfg_path))
+            old_cfg = json.loads(open(cfg_path, "r").read())
             if pkg.same_config(old_cfg):
-                logging.info("Skip reconfiguration because runtime config is the same")
+                logging.info("Skip reconfig_runtime due to same config.")
                 return
-        cflags += ["-O2", "-std=c++11"]
+        cflags = ["-O2", "-std=c++11"]
         cflags += pkg.cflags
         ldflags = pkg.ldflags
         lib_name = dll_path
-        source = env.pkg_config.lib_source
+        source = pkg.lib_source
         logging.info("Rebuild runtime: output=%s, cflags=%s, source=%s, ldflags=%s",
                      dll_path, str(cflags), str(source), str(ldflags))
         cc.create_shared(lib_name, source, cflags + ldflags)
         with open(cfg_path, "w") as outputfile:
-            json.dump(pkg.cfg_json, outputfile)
+            outputfile.write(pkg.cfg_json)
 
 
 def main():
diff --git a/vta/python/vta/top/__init__.py b/vta/python/vta/top/__init__.py
new file mode 100644
index 000000000000..614ed2347181
--- /dev/null
+++ b/vta/python/vta/top/__init__.py
@@ -0,0 +1,5 @@
+"""TVM TOPI connector; eventually most of these should go to the TVM repo"""
+
+from .vta_conv2d import packed_conv2d, schedule_packed_conv2d
+from . import vta_conv2d
+from . import arm_conv2d
diff --git a/vta/python/vta/arm_conv2d.py b/vta/python/vta/top/arm_conv2d.py
similarity index 97%
rename from vta/python/vta/arm_conv2d.py
rename to vta/python/vta/top/arm_conv2d.py
index 9e46ee7f8c61..c959f1ee9967 100644
--- a/vta/python/vta/arm_conv2d.py
+++ b/vta/python/vta/top/arm_conv2d.py
@@ -44,7 +44,7 @@
     Im2ColPack(7, 4, 1, 16, False),
 ]
 
-@_get_schedule.register(["tcpu", "vta"])
+@_get_schedule.register(["vtacpu", "vta"])
 def _schedule_conv2d(wkl):
     if wkl not in _WORKLOADS:
         raise ValueError("no schedule for such workload: {}".format(wkl))
@@ -53,10 +53,10 @@ def _schedule_conv2d(wkl):
     return sch
 
 
-@conv2d.register(["tcpu", "vta"])
+@conv2d.register(["vtacpu", "vta"])
 def _declaration_conv2d(data, kernel, stride, padding, layout, out_dtype):
-    assert layout == 'NCHW', "only support NCHW convolution on tcpu"
-    assert data.shape[0].value == 1, "only support batch size=1 convolution on tcpu"
+    assert layout == 'NCHW', "only support NCHW convolution on vtacpu"
+    assert data.shape[0].value == 1, "only support batch size=1 convolution on vtacpu"
     wkl = _get_workload(data, kernel, stride, padding, out_dtype)
     sch = _get_schedule(wkl)
     return _SCH_TO_DECL_FUNC[type(sch)](data, kernel, stride, padding, out_dtype)
@@ -284,7 +284,7 @@ def _schedule_im2col_conv2d(s, data, data_pad, data_col, data_vec,
     return s
 
 
-@generic.schedule_conv2d_nchw.register(["tcpu", "vta"])
+@generic.schedule_conv2d_nchw.register(["vtacpu", "vta"])
 def schedule_conv2d(outs):
     """Create schedule for tensors"""
     s = tvm.create_schedule([x.op for x in outs])
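The reconfig_runtime change above keeps one useful behavior worth calling out: the server compares the incoming config against the one persisted on disk and skips the expensive runtime rebuild when nothing changed. A minimal sketch of that skip-if-unchanged pattern, with load/compare/rebuild stand-ins (the real code uses PkgConfig.same_config and cc.create_shared; reconfig and vta_config.json here are illustrative names):

    import json
    import os

    def reconfig(cfg_json, cfg_path="vta_config.json"):
        new_cfg = json.loads(cfg_json)
        if os.path.isfile(cfg_path):
            with open(cfg_path, "r") as f:
                old_cfg = json.load(f)
            if old_cfg == new_cfg:  # stand-in for pkg.same_config(old_cfg)
                print("Skip reconfig_runtime due to same config.")
                return
        # ... rebuild the runtime shared library here (cc.create_shared in the patch) ...
        with open(cfg_path, "w") as f:
            f.write(cfg_json)  # persist the raw JSON string, as the patch does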
diff --git a/vta/python/vta/vta_conv2d.py b/vta/python/vta/top/vta_conv2d.py
similarity index 98%
rename from vta/python/vta/vta_conv2d.py
rename to vta/python/vta/top/vta_conv2d.py
index 0baca7ba5b5e..577eac8e143f 100644
--- a/vta/python/vta/vta_conv2d.py
+++ b/vta/python/vta/top/vta_conv2d.py
@@ -1,4 +1,5 @@
 """Namespace for supporting packed_conv2d + ewise variant of nnvm."""
+from __future__ import absolute_import as _abs
 
 from collections import namedtuple
@@ -7,7 +8,7 @@
 import topi
 
 from nnvm.top import registry as reg, OpPattern
-from . import environment as vta
+from ..environment import get_env
 
 Workload = namedtuple("Conv2DWorkload",
@@ -219,7 +220,7 @@ def _traverse(op):
     wrkld = _get_workload(data, pad_data, kernel, output)
     plan = _WL2PLAN[wrkld]
-    env = vta.get_env()
+    env = get_env()
 
     load_inp = load_wgt = load_out = store_out = env.dma_copy
     alu = env.alu
@@ -251,7 +252,7 @@ def _traverse(op):
     # tile
     oc_factor = (plan.oc_factor if plan.oc_factor
-                 else wrkld.out_filter // vta.BLOCK_OUT)
+                 else wrkld.out_filter // env.BLOCK_OUT)
     h_factor = (plan.h_factor if plan.h_factor else oshape[2])
     w_factor = (plan.w_factor if plan.w_factor else oshape[3])
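vta_conv2d picks its tiling by looking up the current convolution's Workload namedtuple in a hand-written _WL2PLAN table, falling back to a full extent when a plan field is zero. A runnable sketch of that lookup and fallback; the Workload field order follows the Workload(224, 224, 16, 64, 7, 7, 3, 3, 2, 2) entries used in the tests, while the Plan fields and BLOCK_OUT value are illustrative stand-ins:

    from collections import namedtuple

    Workload = namedtuple("Conv2DWorkload",
                          ["height", "width", "in_filter", "out_filter",
                           "hkernel", "wkernel", "hpad", "wpad",
                           "hstride", "wstride"])
    Plan = namedtuple("Plan", ["oc_factor", "h_factor", "w_factor"])
    BLOCK_OUT = 16  # stand-in for env.BLOCK_OUT

    _WL2PLAN = {Workload(56, 56, 64, 64, 3, 3, 1, 1, 1, 1): Plan(0, 14, 0)}

    wkl = Workload(56, 56, 64, 64, 3, 3, 1, 1, 1, 1)
    plan = _WL2PLAN[wkl]  # namedtuples hash by value, so lookup by workload works
    # A zero field means "no tiling requested": fall back to the full extent,
    # mirroring the oc_factor fallback in the hunk above.
    oc_factor = plan.oc_factor if plan.oc_factor else wkl.out_filter // BLOCK_OUT
    assert oc_factor == 4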
diff --git a/vta/tests/python/integration/test_benchmark_topi_conv2d.py b/vta/tests/python/integration/test_benchmark_topi_conv2d.py
new file mode 100644
index 000000000000..0a5edfdc7518
--- /dev/null
+++ b/vta/tests/python/integration/test_benchmark_topi_conv2d.py
@@ -0,0 +1,155 @@
+"""Testing if we can generate code in topi style"""
+
+import tvm
+from tvm.contrib import util
+from tvm.contrib.pickle_memoize import memoize
+import topi
+import topi.testing
+import vta
+import vta.testing
+import numpy as np
+
+Workload = vta.top.vta_conv2d.Workload
+
+@tvm.tag_scope(tag=topi.tag.ELEMWISE)
+def my_clip(x, a_min, a_max):
+    """Unlike topi's current clip, put min and max into two stages."""
+    const_min = tvm.const(a_min, x.dtype)
+    const_max = tvm.const(a_max, x.dtype)
+    x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA")
+    x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
+    return x
+
+
+def test_vta_conv2d():
+    def run_vta_conv2d(env, remote, key, batch_size, wl, profile=True):
+        data_shape = (batch_size, wl.in_filter // env.BLOCK_IN,
+                      wl.height, wl.width, env.BLOCK_IN)
+        kernel_shape = (wl.out_filter // env.BLOCK_OUT,
+                        wl.in_filter // env.BLOCK_IN,
+                        wl.hkernel, wl.wkernel,
+                        env.BLOCK_OUT, env.BLOCK_IN)
+        bias_shape = (wl.out_filter // env.BLOCK_OUT, 1, 1, env.BLOCK_OUT)
+
+        fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1
+        fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1
+        data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
+        kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
+        bias = tvm.placeholder(bias_shape, name="bias", dtype=env.acc_dtype)
+
+        res_conv = vta.top.packed_conv2d(
+            data, kernel, padding=(wl.hpad, wl.wpad), strides=(wl.hstride, wl.wstride))
+        res = topi.right_shift(res_conv, 8)
+        res = topi.broadcast_add(res, bias)
+        res = my_clip(res, 0, 127)
+        res = topi.cast(res, "int8")
+
+        num_ops = fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter
+
+        a_shape = (batch_size, wl.in_filter, wl.height, wl.width)
+        w_shape = (wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel)
+        stride = (wl.hstride, wl.wstride)
+        data_dtype = data.dtype
+        acc_dtype = env.acc_dtype
+        assert wl.hpad == wl.wpad
+        padding = wl.hpad
+
+        @memoize("vta.tests.test_benchmark_topi.conv2d,verify_nhwc")
+        def get_ref_data():
+            a_np = (np.random.uniform(size=a_shape) * 4).astype(data_dtype)
+            w_np = (np.random.uniform(size=w_shape) * 4).astype(data_dtype)
+            a_np = np.abs(a_np)
+            w_np = np.abs(w_np)
+            b_np = topi.testing.conv2d_nchw_python(
+                a_np.astype(acc_dtype), w_np.astype(acc_dtype), stride, padding).astype(acc_dtype)
+            return a_np, w_np, b_np
+
+        def verify(s, check_correctness):
+            mod = vta.build(s, [data, kernel, bias, res], "ext_dev",
+                            env.target_host, name="conv2d")
+            temp = util.tempdir()
+            mod.save(temp.relpath("conv2d.o"))
+            remote.upload(temp.relpath("conv2d.o"))
+            f = remote.load_module("conv2d.o")
+            # verify
+            ctx = remote.ext_dev(0)
+            # Data in original format
+            data_orig, kernel_orig, res_ref = get_ref_data()
+            bias_orig = (np.random.uniform(size=(wl.out_filter,)) * 4).astype("int32")
+            bias_orig = np.abs(bias_orig)
+
+            data_packed = data_orig.reshape(
+                batch_size, wl.in_filter // env.BLOCK_IN, env.BLOCK_IN,
+                wl.height, wl.width).transpose((0, 1, 3, 4, 2))
+            kernel_packed = kernel_orig.reshape(
+                wl.out_filter // env.BLOCK_OUT, env.BLOCK_OUT,
+                wl.in_filter // env.BLOCK_IN, env.BLOCK_IN,
+                wl.hkernel, wl.wkernel).transpose((0, 2, 4, 5, 1, 3))
+            bias_packed = bias_orig.reshape(
+                wl.out_filter // env.BLOCK_OUT, 1, 1, env.BLOCK_OUT)
+            res_shape = topi.util.get_const_tuple(res.shape)
+
+            res_np = np.zeros(res_shape).astype(res.dtype)
+            data_arr = tvm.nd.array(data_packed, ctx)
+            kernel_arr = tvm.nd.array(kernel_packed, ctx)
+            bias_arr = tvm.nd.array(bias_packed, ctx)
+            res_arr = tvm.nd.array(res_np, ctx)
+            time_f = f.time_evaluator("conv2d", ctx, number=5)
+            cost = time_f(data_arr, kernel_arr, bias_arr, res_arr)
+            res_unpack = res_arr.asnumpy().transpose(
+                (0, 1, 4, 2, 3)).reshape(batch_size, wl.out_filter, fout_height, fout_width)
+            if check_correctness:
+                assert wl.hpad == wl.wpad
+                stride = (wl.hstride, wl.wstride)
+                padding = wl.hpad
+                res_ref = res_ref >> 8
+                res_ref += bias_orig.reshape(wl.out_filter, 1, 1)
+                res_ref = np.clip(res_ref, 0, 127).astype("int8")
+                np.testing.assert_allclose(res_unpack, res_ref)
+            return cost
+
+        def conv_normal(print_ir):
+            print("----- CONV2D End-to-End Test-------")
+            with vta.build_config():
+                s = vta.top.schedule_packed_conv2d([res])
+                if print_ir:
+                    print(vta.lower(s, [data, kernel, bias, res], simple_mode=True))
+            cost = verify(s, True)
+            gops = (num_ops / cost.mean) / float(10 ** 9)
+            print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
+
+        conv_normal(False)
+
+    def _run(env, remote):
+        # ResNet18 workloads
+        resnet = {
+            # Workloads of resnet18 on imagenet
+            0: Workload(224, 224, 16, 64, 7, 7, 3, 3, 2, 2),
+            1: Workload(56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
+            2: Workload(56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
+            3: Workload(56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
+            4: Workload(56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
+            5: Workload(28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
+            6: Workload(28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
+            7: Workload(28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
+            8: Workload(14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
+            9: Workload(14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
+            10: Workload(14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
+            11: Workload(7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
+        }
+
+        batch_size = 1
+        for i in range(0, len(resnet)):
+            wl = resnet[i]
+            key = "resnet-cfg[%d]" % i
+            print("key=%s" % key)
+            print(wl)
+            run_vta_conv2d(env, remote, key, batch_size, wl)
+
+    vta.testing.run(_run)


+if __name__ == "__main__":
+    test_vta_conv2d()
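The test's data_packed transform above rearranges NCHW activations into the blocked NCHW[x]c layout VTA consumes, and res_unpack inverts it on the way out. A small numpy round-trip check of exactly that reshape/transpose pair, with BLOCK_IN=16 mirroring env.BLOCK_IN and the tensor sizes chosen arbitrarily:

    import numpy as np

    BLOCK_IN = 16
    batch, in_filter, height, width = 1, 64, 14, 14

    data = np.random.randint(0, 4, size=(batch, in_filter, height, width)).astype("int8")
    # Pack: split channels into blocks of BLOCK_IN and move the block to the last axis.
    packed = data.reshape(batch, in_filter // BLOCK_IN, BLOCK_IN,
                          height, width).transpose((0, 1, 3, 4, 2))
    # Unpack: the inverse transpose/reshape used on res_arr in the test.
    unpacked = packed.transpose((0, 1, 4, 2, 3)).reshape(batch, in_filter, height, width)
    np.testing.assert_array_equal(unpacked, data)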
diff --git a/vta/tests/python/pynq/test_benchmark_topi.py b/vta/tests/python/pynq/test_benchmark_topi.py
deleted file mode 100644
index 3e2d19a67b2f..000000000000
--- a/vta/tests/python/pynq/test_benchmark_topi.py
+++ /dev/null
@@ -1,146 +0,0 @@
-"""Testing if we can generate code in topi style"""
-
-import topi
-import tvm
-from tvm.contrib import util, rpc
-import vta
-from vta import vta_conv2d
-import numpy as np
-import mxnet as mx
-
-Workload = vta_conv2d.Workload
-
-@tvm.tag_scope(tag=topi.tag.ELEMWISE)
-def my_clip(x, a_min, a_max):
-    """Unlike topi's current clip, put min and max into two stages."""
-    const_min = tvm.const(a_min, x.dtype)
-    const_max = tvm.const(a_max, x.dtype)
-    x = tvm.compute(x.shape, lambda *i: tvm.min(x(*i), const_max), name="clipA")
-    x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
-    return x
-
-host = "pynq"
-port = 9091
-target = "llvm -target=armv7-none-linux-gnueabihf -mattr=+neon"
-print_ir = False
-
-
-def test_vta_conv2d(key, batch_size, wl, profile=True):
-    env = vta.get_env()
-    data_shape = (batch_size, wl.in_filter // env.BLOCK_IN,
-                  wl.height, wl.width, env.BLOCK_IN)
-    kernel_shape = (wl.out_filter // env.BLOCK_OUT,
-                    wl.in_filter // env.BLOCK_IN,
-                    wl.hkernel, wl.wkernel,
-                    env.BLOCK_OUT, env.BLOCK_IN)
-    bias_shape = (wl.out_filter // env.BLOCK_OUT, 1, 1, env.BLOCK_OUT)
-
-    fout_height = (wl.height + 2 * wl.hpad - wl.hkernel) // wl.hstride + 1
-    fout_width = (wl.width + 2 * wl.wpad - wl.wkernel) // wl.wstride + 1
-    data = tvm.placeholder(data_shape, name="data", dtype=env.inp_dtype)
-    kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
-    bias = tvm.placeholder(bias_shape, name="kernel", dtype=env.acc_dtype)
-
-    res_conv = vta_conv2d.packed_conv2d(
-        data, kernel, padding=(wl.hpad, wl.wpad), strides=(wl.hstride, wl.wstride))
-    res = topi.right_shift(res_conv, 8)
-    res = topi.broadcast_add(res, bias)
-    res = my_clip(res, 0, 127)
-    res = topi.cast(res, "int8")
-
-    num_ops = fout_height * fout_width * wl.hkernel * wl.wkernel * wl.out_filter * wl.in_filter
-
-    def verify(s, check_correctness):
-        mod = tvm.build(s, [data, kernel, bias, res], "ext_dev", target, name="conv2d")
-        temp = util.tempdir()
-        remote = rpc.connect(host, port)
-        mod.save(temp.relpath("conv2d.o"))
-        remote.upload(temp.relpath("conv2d.o"))
-        f = remote.load_module("conv2d.o")
-        # verify
-        ctx = remote.ext_dev(0)
-        # Data in original format
-        data_orig = (np.random.uniform(
-            size=(batch_size, wl.in_filter, wl.height, wl.width)) * 4).astype(data.dtype)
-        kernel_orig = (np.random.uniform(
-            size=(wl.out_filter, wl.in_filter, wl.hkernel, wl.wkernel)) * 4).astype(kernel.dtype)
-        bias_orig = (np.random.uniform(size=(wl.out_filter,)) * 4).astype("int32")
-        data_orig = np.abs(data_orig)
-        kernel_orig = np.abs(kernel_orig)
-        bias_orig = np.abs(bias_orig)
-
-        data_packed = data_orig.reshape(
-            batch_size, wl.in_filter // env.BLOCK_IN, env.BLOCK_IN,
-            wl.height, wl.width).transpose((0, 1, 3, 4, 2))
-        kernel_packed = kernel_orig.reshape(
-            wl.out_filter // env.BLOCK_OUT, env.BLOCK_OUT,
-            wl.in_filter // env.BLOCK_IN, env.BLOCK_IN,
-            wl.hkernel, wl.wkernel).transpose((0, 2, 4, 5, 1, 3))
-        bias_packed = bias_orig.reshape(
-            wl.out_filter // env.BLOCK_OUT, 1, 1, env.BLOCK_OUT)
-        res_shape = topi.util.get_const_tuple(res.shape)
-
-        res_np = np.zeros(res_shape).astype(res.dtype)
-        data_arr = tvm.nd.array(data_packed, ctx)
-        kernel_arr = tvm.nd.array(kernel_packed, ctx)
-        bias_arr = tvm.nd.array(bias_packed, ctx)
-        res_arr = tvm.nd.array(res_np, ctx)
-        time_f = f.time_evaluator("conv2d", ctx, number=10)
-        cost = time_f(data_arr, kernel_arr, bias_arr, res_arr)
-        res_unpack = res_arr.asnumpy().transpose(
-            (0, 1, 4, 2, 3)).reshape(batch_size, wl.out_filter, fout_height, fout_width)
-        if check_correctness:
-            res_ref = mx.nd.Convolution(
-                mx.nd.array(data_orig.astype(env.acc_dtype), mx.cpu(0)),
-                mx.nd.array(kernel_orig.astype(env.acc_dtype), mx.cpu(0)),
-                stride=(wl.hstride, wl.wstride),
-                kernel=(wl.hkernel, wl.wkernel),
-                num_filter=wl.out_filter,
-                no_bias=True,
-                pad=(wl.hpad, wl.wpad)).asnumpy().astype(env.acc_dtype)
-            res_ref = res_ref >> 8
-            res_ref += bias_orig.reshape(wl.out_filter, 1, 1)
-            res_ref = np.clip(res_ref, 0, 127).astype("int8")
-            np.testing.assert_allclose(res_unpack, res_ref)
-            print("Correctness check pass...")
-        return cost
-
-    def conv_normal(print_ir):
-        print("----- CONV2D End-to-End Test-------")
-        with vta.build_config():
-            s = vta_conv2d.schedule_packed_conv2d([res])
-            if print_ir:
-                print(vta.lower(s, [data, kernel, bias, res], simple_mode=True))
-        cost = verify(s, True)
-        gops = (num_ops / cost.mean) / float(10 ** 9)
-        print("\tTime cost = %g sec/op, %g GFLOPS" % (cost.mean, gops))
-
-    conv_normal(print_ir)
-
-# ResNet18 workloads
-resnet = {
-    # Workloads of resnet18 on imagenet
-    0: Workload(224, 224, 16, 64, 7, 7, 3, 3, 2, 2),
-    1: Workload(56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
-    2: Workload(56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
-    3: Workload(56, 56, 64, 128, 3, 3, 1, 1, 2, 2),
-    4: Workload(56, 56, 64, 128, 1, 1, 0, 0, 2, 2),
-    5: Workload(28, 28, 128, 128, 3, 3, 1, 1, 1, 1),
-    6: Workload(28, 28, 128, 256, 3, 3, 1, 1, 2, 2),
-    7: Workload(28, 28, 128, 256, 1, 1, 0, 0, 2, 2),
-    8: Workload(14, 14, 256, 256, 3, 3, 1, 1, 1, 1),
-    9: Workload(14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
-    10: Workload(14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
-    11: Workload(7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
-}
-
-batch_size = 1
-for i in range(0, len(resnet)):
-    wl = resnet[i]
-    key = "resnet-cfg[%d]" % i
-    print "key=%s" % key
-    print wl
-    test_vta_conv2d(key, batch_size, wl)
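For reference, a worked example of the num_ops accounting used by conv_normal in both the old and new test (one multiply-accumulate counted per op, which the test prints under the GFLOPS label), evaluated on resnet workload 1 above; the cost_mean value is a hypothetical measurement:

    height = width = 56
    in_filter = out_filter = 64
    hkernel = wkernel = 3
    hpad = wpad = 1
    hstride = wstride = 1

    fout_height = (height + 2 * hpad - hkernel) // hstride + 1  # 56
    fout_width = (width + 2 * wpad - wkernel) // wstride + 1    # 56
    num_ops = fout_height * fout_width * hkernel * wkernel * out_filter * in_filter
    print(num_ops)  # 115605504 multiply-accumulates per image

    cost_mean = 0.01  # hypothetical measured seconds per run
    print("%g GFLOPS" % (num_ops / cost_mean / float(10 ** 9)))  # ~11.6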