From 56c8a9b4b6813bc9642c2ef82d94c2bf0cb443cf Mon Sep 17 00:00:00 2001 From: ZHANG Hao Date: Wed, 21 Apr 2021 00:22:01 +0800 Subject: [PATCH] [VTA][OpenCL] intelfocl (#6126) * intelfocl support * disable tsim test * bugfix to vta autotvm * disable tsim test in task_python_vta_tsim.sh * fix integration test * update vta submodule and re-enable tsim tests * remove unnecessary comments --- cmake/modules/VTA.cmake | 8 + python/tvm/autotvm/task/topi_integration.py | 11 +- python/tvm/relay/op/strategy/generic.py | 9 + python/tvm/relay/testing/tf.py | 2 +- python/tvm/topi/x86/bitserial_dense.py | 2 +- src/relay/backend/compile_engine.cc | 4 +- src/runtime/workspace_pool.cc | 1 - src/tir/transforms/lower_tvm_builtin.cc | 10 - vta/python/vta/autotvm.py | 2 +- vta/python/vta/environment.py | 4 +- vta/python/vta/program_bitstream.py | 14 +- vta/python/vta/rpc_client.py | 14 +- vta/python/vta/testing/simulator.py | 8 +- vta/python/vta/testing/utils.py | 2 +- vta/python/vta/top/graphpack.py | 2 +- vta/python/vta/top/op.py | 138 +++++++- vta/python/vta/transform.py | 6 +- vta/runtime/runtime.cc | 134 +++++++- vta/runtime/runtime.h | 2 + .../integration/test_benchmark_topi_conv2d.py | 2 +- vta/tutorials/autotvm/tune_alu_vta.py | 320 ++++++++++++++++++ .../frontend/deploy_classification.py | 27 +- vta/tutorials/vta_get_started.py | 6 +- 23 files changed, 676 insertions(+), 52 deletions(-) create mode 100644 vta/tutorials/autotvm/tune_alu_vta.py diff --git a/cmake/modules/VTA.cmake b/cmake/modules/VTA.cmake index 58b58d231d83..64a8986f0de0 100644 --- a/cmake/modules/VTA.cmake +++ b/cmake/modules/VTA.cmake @@ -104,6 +104,10 @@ elseif(PYTHON) find_library(__cma_lib NAMES cma PATH /usr/lib) elseif(${VTA_TARGET} STREQUAL "de10nano") # DE10-Nano rules file(GLOB FPGA_RUNTIME_SRCS ${VTA_HW_PATH}/src/de10nano/*.cc ${VTA_HW_PATH}/src/*.cc) + elseif(${VTA_TARGET} STREQUAL "intelfocl") # Intel OpenCL for FPGA rules + file(GLOB FOCL_SRC ${VTA_HW_PATH}/src/oclfpga/*.cc) + list(APPEND FPGA_RUNTIME_SRCS ${FOCL_SRC}) + list(APPEND FPGA_RUNTIME_SRCS ${VTA_HW_PATH}/src/vmem/virtual_memory.cc ${VTA_HW_PATH}/src/vmem/virtual_memory.h) endif() # Target lib: vta add_library(vta SHARED ${FPGA_RUNTIME_SRCS}) @@ -123,6 +127,10 @@ elseif(PYTHON) target_include_directories(vta SYSTEM PUBLIC 3rdparty) target_include_directories(vta SYSTEM PUBLIC "/usr/local/intelFPGA_lite/18.1/embedded/ds-5/sw/gcc/arm-linux-gnueabihf/include") + elseif(${VTA_TARGET} STREQUAL "intelfocl") # Intel OpenCL for FPGA rules + target_include_directories(vta PUBLIC 3rdparty) + set (CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -std=c++17") + target_link_libraries(vta -lOpenCL) endif() endif() diff --git a/python/tvm/autotvm/task/topi_integration.py b/python/tvm/autotvm/task/topi_integration.py index f6ca3b179824..2558c7669ac9 100644 --- a/python/tvm/autotvm/task/topi_integration.py +++ b/python/tvm/autotvm/task/topi_integration.py @@ -227,7 +227,7 @@ def _decorate(topi_schedule): @_register_task_schedule(task_name) def wrapper(outs, *args, **kwargs): """wrapper function for topi schedule""" - workload = get_workload(outs) + workload = get_workload(outs, task_name) if workload is None: raise RuntimeError("Cannot find workload in attribute of this schedule") tgt = Target.current() @@ -241,18 +241,21 @@ def wrapper(outs, *args, **kwargs): return _decorate -def get_workload(outs): +def get_workload(outs, task_name=None): """Retrieve the workload from outputs""" def traverse(tensors): """traverse all ops to find attached workload""" for t in tensors: op = t.op - if 
"workload" in op.attrs: - return args_to_workload(op.attrs["workload"]) wkl = traverse(op.input_tensors) if wkl: return wkl + + if "workload" in op.attrs: + ret = args_to_workload(op.attrs["workload"]) + if task_name is None or ret[0] == task_name: + return ret return None outs = [outs] if isinstance(outs, tensor.Tensor) else outs diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py index 4c25255fd7b3..ee732e604499 100644 --- a/python/tvm/relay/op/strategy/generic.py +++ b/python/tvm/relay/op/strategy/generic.py @@ -53,6 +53,15 @@ def wrapper(attrs, outs, target): return wrapper +def wrap_topi_compute(topi_compute): + """Wrap TOPI compute which doesn't use attrs""" + + def wrapper(attrs, inputs, out_type): + return [topi_compute(*inputs)] + + return wrapper + + def get_conv2d_in_channels(data_shape, data_layout): """Get conv2d input channels""" data_shape = get_const_tuple(data_shape) diff --git a/python/tvm/relay/testing/tf.py b/python/tvm/relay/testing/tf.py index b0b15775ebda..d20c0e0ab9dd 100644 --- a/python/tvm/relay/testing/tf.py +++ b/python/tvm/relay/testing/tf.py @@ -32,7 +32,7 @@ try: tf_compat_v1 = tf.compat.v1 -except ImportError: +except (ImportError, AttributeError): tf_compat_v1 = tf ###################################################################### diff --git a/python/tvm/topi/x86/bitserial_dense.py b/python/tvm/topi/x86/bitserial_dense.py index 7af18f602234..4b2ee11fe2e1 100644 --- a/python/tvm/topi/x86/bitserial_dense.py +++ b/python/tvm/topi/x86/bitserial_dense.py @@ -122,7 +122,7 @@ def bitserial_dense( return matmul -@autotvm.register_topi_schedule("biserial_dense.x86") +@autotvm.register_topi_schedule("bitserial_dense.x86") def schedule_bitserial_dense(cfg, outs): """Schedule for bitserial_dense. diff --git a/src/relay/backend/compile_engine.cc b/src/relay/backend/compile_engine.cc index 0777b19ec557..5e3b66b3ae15 100644 --- a/src/relay/backend/compile_engine.cc +++ b/src/relay/backend/compile_engine.cc @@ -251,7 +251,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> << "Cannot apply TOPI schedule to a primitive function with two complicated ops" << " anchor=" << anchor_op_ << " current=" << op; } - if (op_pattern >= anchor_op_pattern_) { + if (op_pattern > anchor_op_pattern_) { anchor_op_ = op; anchor_attrs_ = call_node->attrs; anchor_op_pattern_ = op_pattern; @@ -309,7 +309,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator> tvm::Target target_; Op anchor_op_; Attrs anchor_attrs_; - int anchor_op_pattern_{0}; + int anchor_op_pattern_{-1}; OpImplementation anchor_implementation_; std::ostringstream readable_name_stream_; Array scalars_; diff --git a/src/runtime/workspace_pool.cc b/src/runtime/workspace_pool.cc index 40d488df700a..6ed5bf4daba6 100644 --- a/src/runtime/workspace_pool.cc +++ b/src/runtime/workspace_pool.cc @@ -115,7 +115,6 @@ class WorkspacePool::Pool { } // Release all resources void Release(Device dev, DeviceAPI* device) { - ICHECK_EQ(allocated_.size(), 1); for (size_t i = 1; i < free_list_.size(); ++i) { device->FreeDataSpace(dev, free_list_[i].data); } diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc index e39a52597f27..c40fd7edfdc2 100644 --- a/src/tir/transforms/lower_tvm_builtin.cc +++ b/src/tir/transforms/lower_tvm_builtin.cc @@ -109,16 +109,6 @@ class BuiltinLower : public StmtExprMutator { op = stmt.as(); // Get constant allocation bound. 
int64_t nbytes = GetVectorBytes(op->dtype); - if (device_type_.defined()) { - if (const auto* dev_type = device_type_.as()) { - if (dev_type->value == kDLCPU) { - int32_t constant_size = op->constant_allocation_size(); - if (constant_size > 0 && constant_size * nbytes < runtime::kMaxStackAlloca) { - return stmt; - } - } - } - } PrimExpr total_bytes = make_const(op->extents[0].dtype(), nbytes); for (size_t i = 0; i < op->extents.size(); ++i) { total_bytes = total_bytes * op->extents[i]; diff --git a/vta/python/vta/autotvm.py b/vta/python/vta/autotvm.py index 9aa7390f238f..285e30923b13 100644 --- a/vta/python/vta/autotvm.py +++ b/vta/python/vta/autotvm.py @@ -46,7 +46,7 @@ def reprogram_fpga(remote, _build_result): _build_result : tvm.autotvm.measure.measure_methods.BuildResult Artifact from the build phase, unused here. """ - rpc_client.program_bitstream(remote, bitstream) + rpc_client.program_fpga(remote, bitstream) rpc_client.reconfig_runtime(remote) return default_module_loader(reprogram_fpga) diff --git a/vta/python/vta/environment.py b/vta/python/vta/environment.py index b334e013e5cf..4b6e5bdeca78 100644 --- a/vta/python/vta/environment.py +++ b/vta/python/vta/environment.py @@ -66,11 +66,13 @@ class DevContext(object): MEM_ID_INP = 2 MEM_ID_ACC = 3 MEM_ID_OUT = 4 + MEM_ID_ACC_8BIT = 5 # VTA ALU Opcodes ALU_OPCODE_MIN = 0 ALU_OPCODE_MAX = 1 ALU_OPCODE_ADD = 2 ALU_OPCODE_SHR = 3 + ALU_OPCODE_MUL = 4 # Task queue id (pipeline stage) QID_LOAD_INP = 1 QID_LOAD_WGT = 1 @@ -232,7 +234,7 @@ def target_host(self): return "llvm -mtriple=armv7-none-linux-gnueabihf" if self.TARGET == "ultra96": return "llvm -mtriple=aarch64-linux-gnu" - if self.TARGET in ["sim", "tsim"]: + if self.TARGET in ["sim", "tsim", "intelfocl"]: return "llvm" raise ValueError("Unknown target %s" % self.TARGET) diff --git a/vta/python/vta/program_bitstream.py b/vta/python/vta/program_bitstream.py index 556933ac6e5a..a7da89d2f637 100644 --- a/vta/python/vta/program_bitstream.py +++ b/vta/python/vta/program_bitstream.py @@ -57,7 +57,17 @@ def de10nano_bitstream_program(bitstream_path): program(bitstream_path) -def bitstream_program(target, bitstream): +def intelfocl_bitstream_program(bitstream_path, mem_size=4 * 1024 * 1024 * 1024): + # pylint: disable=import-outside-toplevel + from tvm import get_global_func + + program = get_global_func("vta.oclfpga.program") + program(bitstream_path, mem_size) + + +def bitstream_program(target, bitstream, *args): + """program bitstream to devices""" + if target in ["pynq", "ultra96"]: pynq_bitstream_program(bitstream) elif target in ["de10nano"]: @@ -65,6 +75,8 @@ def bitstream_program(target, bitstream): elif target in ["sim", "tsim"]: # In simulation, bit stream programming is a no-op return + elif target in ["intelfocl"]: + intelfocl_bitstream_program(bitstream, *args) else: raise RuntimeError("Unknown target {}".format(target)) diff --git a/vta/python/vta/rpc_client.py b/vta/python/vta/rpc_client.py index 02ff8be00b81..90203983987a 100644 --- a/vta/python/vta/rpc_client.py +++ b/vta/python/vta/rpc_client.py @@ -17,6 +17,8 @@ """VTA RPC client function""" import os +from tvm import rpc +from vta import program_bitstream from .environment import get_env from .bitstream import download_bitstream, get_bitstream_path @@ -45,16 +47,20 @@ def program_fpga(remote, bitstream=None): bitstream : str, optional Path to a local bistream file. If unset, tries to download from cache server. 
""" + env = get_env() + if bitstream: assert os.path.isfile(bitstream) else: bitstream = get_bitstream_path() if not os.path.isfile(bitstream): - env = get_env() if env.TARGET == "de10nano": return download_bitstream() - fprogram = remote.get_function("tvm.contrib.vta.init") - remote.upload(bitstream) - fprogram(os.path.basename(bitstream)) + if isinstance(remote, rpc.LocalSession): + program_bitstream.bitstream_program(env.TARGET, bitstream) + else: + fprogram = remote.get_function("tvm.contrib.vta.init") + remote.upload(bitstream) + fprogram(os.path.basename(bitstream)) diff --git a/vta/python/vta/testing/simulator.py b/vta/python/vta/testing/simulator.py index 05f37c182815..2b662beb4672 100644 --- a/vta/python/vta/testing/simulator.py +++ b/vta/python/vta/testing/simulator.py @@ -27,7 +27,13 @@ def _load_sw(): """Load hardware library for simulator.""" env = get_env() - lib_driver_name = "libvta_tsim" if env.TARGET == "tsim" else "libvta_fsim" + lib_driver_name = ( + "libvta_tsim" + if env.TARGET == "tsim" + else "libvta" + if env.TARGET == "intelfocl" + else "libvta_fsim" + ) require_sim = env.TARGET in ("sim", "tsim") libs = [] diff --git a/vta/python/vta/testing/utils.py b/vta/python/vta/testing/utils.py index 99d8d40df99c..f163359667f1 100644 --- a/vta/python/vta/testing/utils.py +++ b/vta/python/vta/testing/utils.py @@ -32,7 +32,7 @@ def run(run_func): """ env = get_env() - if env.TARGET in ["sim", "tsim"]: + if env.TARGET in ["sim", "tsim", "intelfocl"]: # Talk to local RPC if necessary to debug RPC server. # Compile vta on your host with make at the root. # Make sure TARGET is set to "sim" in the config.json file. diff --git a/vta/python/vta/top/graphpack.py b/vta/python/vta/top/graphpack.py index 8998f5712381..5ec11677da70 100644 --- a/vta/python/vta/top/graphpack.py +++ b/vta/python/vta/top/graphpack.py @@ -423,7 +423,7 @@ def visit_call(self, call): self.start_pack and call.op == op.op.get("cast") and input_types[0].dtype == "int32" ): cast = relay.Call(op.op.get("cast"), [args[0]], call.attrs) - return relay.Call(op.op.get("copy"), [cast]) + return cast elif call.op == self.pad: pad_width = call.attrs.pad_width if len(pad_width) == 6: diff --git a/vta/python/vta/top/op.py b/vta/python/vta/top/op.py index a217104a9ae7..f243c3fc2c89 100644 --- a/vta/python/vta/top/op.py +++ b/vta/python/vta/top/op.py @@ -20,6 +20,7 @@ import tvm from tvm import te +from tvm import autotvm from tvm import topi from tvm.relay.op import op as reg @@ -33,6 +34,7 @@ from .vta_dense import dense_packed, schedule_dense_packed from ..environment import get_env +ENV = get_env() # override to force partition at copy reg.register_pattern("copy", OpPattern.INJECTIVE, level=15) @@ -64,6 +66,137 @@ def clip_strategy_vta(attrs, inputs, out_type, target): reg.get("clip").get_attr("FTVMStrategy").register(clip_strategy_vta, "vta") +@autotvm.register_topi_compute("add.vta") +def add_packed(cfg, lhs, rhs): + return topi.add(lhs, rhs) + + +@autotvm.register_topi_compute("multiply.vta") +def multiply_packed(cfg, lhs, rhs): + return topi.multiply(lhs, rhs) + + +def schedule_alu_packed(cfg, outs): + """alu packed schedule""" + assert len(outs) == 1 + + def is_cast_op(op): + return op.name == "T_cast" + + outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs + output = outs[0] + s = te.create_schedule([x.op for x in outs]) + te.schedule.AutoInlineInjective(s) + + # other target does not support alu-only ops + if not (ENV.TARGET in ["sim", "tsim", "intelfocl"]): + return s + + # only put the int-related ops to 
vta + if "int" in output.dtype and len(output.shape) == 6: + ewise_inputs = [] + ewise_ops = [] + const_ops = [] + + def _traverse(op): + if topi.tag.is_broadcast(op.tag): + if not op.same_as(output.op): + if not op.axis: + const_ops.append(op) + elif not is_cast_op(op): + ewise_ops.append(op) + + for tensor in op.input_tensors: + if isinstance(tensor.op, tvm.te.PlaceholderOp): + ewise_inputs.append((op, tensor)) + elif is_cast_op(tensor.op) and not op.same_as(output.op): + ewise_inputs.append((op, tensor)) + else: + _traverse(tensor.op) + else: + for tensor in op.input_tensors: + if (not isinstance(tensor.op, tvm.te.PlaceholderOp)) and ( + not is_cast_op(tensor.op) + ): + _traverse(tensor.op) + + op = output.op + _traverse(op) + for _, t in ewise_inputs: + if t.dtype == "float32": + return s + + x_bo, x_co, x_i, x_j, x_bi, x_ci = s[output].op.axis + + cfg.define_split("tile_co", x_co, num_outputs=2) + cfg.define_split("tile_h", x_i, num_outputs=2) + cfg.define_split("tile_w", x_j, num_outputs=2) + + x_co0, x_co1 = cfg["tile_co"].apply(s, output, x_co) + x_i0, x_i1 = cfg["tile_h"].apply(s, output, x_i) + x_j0, x_j1 = cfg["tile_w"].apply(s, output, x_j) + s[output].reorder(x_bo, x_i0, x_co0, x_j0, x_co1, x_i1, x_j1, x_bi, x_ci) + store_pt = x_j0 + + for e_o in ewise_ops: + s[e_o].set_scope(ENV.acc_scope) + s[e_o].pragma(s[e_o].op.axis[0], ENV.alu) + s[e_o].compute_at(s[output], store_pt) + + # cache read input + cache_read_ewise = [] + for consumer, tensor in ewise_inputs: + cache_read_ewise.append(s.cache_read(tensor, ENV.acc_scope, [consumer])) + + for tensor in cache_read_ewise: + if s[tensor].op.axis: + s[tensor].pragma(s[tensor].op.axis[0], ENV.dma_copy) + s[tensor].compute_at(s[output], store_pt) + + for op in const_ops: + s[op].compute_inline() + + s[output].pragma(x_co1, ENV.dma_copy) + + return s + + +@autotvm.register_topi_schedule("add.vta") +def schedule_add_packed(cfg, outs): + return schedule_alu_packed(cfg, outs) + + +@autotvm.register_topi_schedule("multiply.vta") +def schedule_multiply_packed(cfg, outs): + return schedule_alu_packed(cfg, outs) + + +def add_strategy_vta(attrs, inputs, out_type, target): + strategy = OpStrategy() + strategy.add_implementation( + _strategy.wrap_topi_compute(add_packed), + _strategy.wrap_topi_schedule(schedule_add_packed), + name="add.vta", + ) + return strategy + + +def multiply_strategy_vta(attrs, inputs, out_type, target): + strategy = OpStrategy() + strategy.add_implementation( + _strategy.wrap_topi_compute(multiply_packed), + _strategy.wrap_topi_schedule(schedule_multiply_packed), + name="multiply.vta", + ) + return strategy + + +# other target does not support alu-only ops +if ENV.TARGET in ["sim", "intelfocl"]: + reg.get("add").get_attr("FTVMStrategy").register(add_strategy_vta, "vta") + reg.get("multiply").get_attr("FTVMStrategy").register(multiply_strategy_vta, "vta") + + @_strategy.conv2d_strategy.register("vta") def conv2d_strategy_vta(attrs, inputs, out_type, target): """conv2d vta strategy""" @@ -76,9 +209,8 @@ def conv2d_strategy_vta(attrs, inputs, out_type, target): assert dilation == (1, 1), "support for dilation limited to (1, 1)" if is_packed_layout(layout): if groups == 1: - env = get_env() - assert env.LOG_INP_WIDTH == 3, "only support 8bit inp for now" - assert env.LOG_WGT_WIDTH == 3, "only support 8bit wgt for now" + assert ENV.LOG_INP_WIDTH == 3, "only support 8bit inp for now" + assert ENV.LOG_WGT_WIDTH == 3, "only support 8bit wgt for now" assert kernel.dtype == "int8" strategy.add_implementation( diff --git 
a/vta/python/vta/transform.py b/vta/python/vta/transform.py index 9770857fb0b9..f8b4f2d2c5c3 100644 --- a/vta/python/vta/transform.py +++ b/vta/python/vta/transform.py @@ -409,8 +409,6 @@ def _fold_buffer_dim(buf, scope, elem_block): def _get_2d_pattern(buf, elem_width, elem_bytes, dtype, scope, allow_fold): elem_block = elem_bytes * 8 // elem_width - if buf.dtype != dtype: - raise RuntimeError("Expect buffer type to be %s instead of %s" % (dtype, buf.dtype)) shape, strides = buf.shape, buf.strides if not utils.equal_const_int(idxm(buf.elem_offset, elem_block), 0): raise RuntimeError("scope %s need to have block=%d" % (scope, elem_block)) @@ -591,6 +589,10 @@ def _inject_copy(src, dst, pad_before, pad_after, pad_value): src, elem_width, elem_bytes, data_type, dst.scope, allow_fold=allow_fold ) + if data_type != src.dtype: + assert data_type == "int%d" % env.ACC_WIDTH and src.dtype == "int%d" % env.INP_WIDTH + mem_type = env.dev.MEM_ID_ACC_8BIT + irb = tvm.tir.ir_builder.create() irb.scope_attr(env.dev.vta_axis, "coproc_scope", env.dev.get_task_qid(task_qid)) diff --git a/vta/runtime/runtime.cc b/vta/runtime/runtime.cc index 49fe9c557336..c1215214cf51 100644 --- a/vta/runtime/runtime.cc +++ b/vta/runtime/runtime.cc @@ -27,6 +27,8 @@ #include "runtime.h" #include +#include +#include #include #include #include @@ -35,6 +37,9 @@ #include #include #include +#include +#include +#include #include namespace vta { @@ -47,10 +52,84 @@ static const bool kBufferCoherent = VTA_COHERENT_ACCESSES; /*! \brief Always cache buffers (otherwise, write back to DRAM from CPU) */ static const bool kAlwaysCache = true; +template +class AlignmentAllocator : public std::allocator { + public: + typedef T value_type; + typedef std::size_t size_type; + typedef std::ptrdiff_t difference_type; + + typedef T* pointer; + typedef const T* const_pointer; + + typedef T& reference; + typedef const T& const_reference; + + inline AlignmentAllocator() throw() {} + + template + inline AlignmentAllocator(const AlignmentAllocator&) throw() {} + + inline ~AlignmentAllocator() throw() {} + + inline pointer address(reference r) { return &r; } + + inline const_pointer address(const_reference r) const { return &r; } + + inline pointer allocate(size_type n) { return (pointer)memalign(N, n * sizeof(value_type)); } + + inline void deallocate(pointer p, size_type) { free(p); } + + inline void construct(pointer p, const value_type& wert) { new (p) value_type(wert); } + + inline void destroy(pointer p) { p->~value_type(); } + + inline size_type max_size() const throw() { return size_type(-1) / sizeof(value_type); } + + template + struct rebind { + typedef AlignmentAllocator other; + }; + + bool operator!=(const AlignmentAllocator& other) const { return !(*this == other); } + + // Returns true if and only if storage allocated from *this + // can be deallocated from other, and vice versa. + // Always returns true for stateless allocators. + bool operator==(const AlignmentAllocator& other) const { return true; } +}; + +class DeviceAllocStat { + public: + void AddAlloc(const void* ptr) { + std::lock_guard lock(mtx_); + allocated_.insert(ptr); + } + + bool CheckAlloc(const void* ptr) { + std::lock_guard lock(mtx_); + return allocated_.count(ptr); + } + + void DelAlloc(const void* ptr) { + std::lock_guard lock(mtx_); + allocated_.erase(ptr); + } + + private: + std::set allocated_; + std::mutex mtx_; +}; + +// here we use a global variable to memorize the allocation stats +static std::shared_ptr alloc_stat(new DeviceAllocStat()); + /*! 
* \brief Data buffer represents data on CMA. */ struct DataBuffer { + DataBuffer() { alloc_stat_ = alloc_stat; } + /*! \return Virtual address of the data. */ void* virt_addr() const { return data_; } /*! \return Physical address of the data. */ @@ -101,6 +180,8 @@ struct DataBuffer { DataBuffer* buffer = new DataBuffer(); buffer->data_ = data; buffer->phy_addr_ = VTAMemGetPhyAddr(data); + + alloc_stat->AddAlloc(buffer); return buffer; } /*! @@ -108,6 +189,7 @@ struct DataBuffer { * \param buffer The buffer to be freed. */ static void Free(DataBuffer* buffer) { + alloc_stat->DelAlloc(buffer); VTAMemFree(buffer->data_); delete buffer; } @@ -117,7 +199,11 @@ struct DataBuffer { * \return The corresponding data buffer header. */ static DataBuffer* FromHandle(const void* buffer) { - return const_cast(reinterpret_cast(buffer)); + if (alloc_stat->CheckAlloc(buffer)) { + return const_cast(reinterpret_cast(buffer)); + } else { + return nullptr; + } } private: @@ -125,6 +211,11 @@ struct DataBuffer { void* data_; /*! \brief The physical address of the buffer, excluding header. */ vta_phy_addr_t phy_addr_; + + // a copy of global shared_ptr instance + // to avoid the global instance is destructed before there are still some pending DataBuffers not + // destructed + std::shared_ptr alloc_stat_; }; /*! @@ -329,7 +420,7 @@ class BaseQueue { // End location of current SRAM write in FIFO mode uint32_t sram_end_{0}; // The buffer in DRAM - std::vector dram_buffer_; + std::vector> dram_buffer_; // FPGA accessible buffer void* fpga_buff_{NULL}; // Physical address of the FPGA buffer @@ -429,14 +520,24 @@ class UopQueue : public BaseQueue { buff_size += cache_[i]->size() * kElemBytes; } CHECK(buff_size <= kMaxBytes); - // Move kernel contents to FPGA readable buffer + + // merge all the cache entries and do CopyFromHost once + uint32_t total_size = 0; + for (uint32_t i = 0; i < cache_.size(); ++i) { + uint32_t ksize = cache_[i]->size() * kElemBytes; + total_size += ksize; + } + + char* lbuf = (char*)memalign(ALLOC_ALIGNMENT, total_size); uint32_t offset = 0; for (uint32_t i = 0; i < cache_.size(); ++i) { uint32_t ksize = cache_[i]->size() * kElemBytes; - VTAMemCopyFromHost(static_cast(fpga_buff_) + offset, cache_[i]->data(), ksize); - // Update offset + memcpy(lbuf + offset, cache_[i]->data(), ksize); offset += ksize; } + VTAMemCopyFromHost(static_cast(fpga_buff_), lbuf, total_size); + free(lbuf); + // Flush if we're using a shared memory system // and if interface is non-coherent if (!coherent_ && always_cache_) { @@ -631,6 +732,8 @@ class InsnQueue : public BaseQueue { } } else if (opcode == VTA_ALU_OPCODE_SHR) { return "shr"; + } else if (opcode == VTA_ALU_OPCODE_MUL) { + return "mul"; } return "unknown op"; @@ -830,7 +933,7 @@ class InsnQueue : public BaseQueue { } // Get stage of the memory static PipelineStage GetMemPipelineStage(int memory_type) { - if (memory_type == VTA_MEM_ID_ACC) return kComputeStage; + if (memory_type == VTA_MEM_ID_ACC || memory_type == VTA_MEM_ID_ACC_8BIT) return kComputeStage; if (memory_type == VTA_MEM_ID_UOP) return kComputeStage; return kLoadStage; } @@ -840,7 +943,8 @@ class InsnQueue : public BaseQueue { if (insn->opcode == VTA_OPCODE_ALU) return kComputeStage; if (insn->opcode == VTA_OPCODE_LOAD) { if (insn->x_size == 0) return kNoneStage; - if (insn->memory_type == VTA_MEM_ID_ACC) return kComputeStage; + if (insn->memory_type == VTA_MEM_ID_ACC || insn->memory_type == VTA_MEM_ID_ACC_8BIT) + return kComputeStage; if (insn->memory_type == VTA_MEM_ID_UOP) return 
kComputeStage; return kLoadStage; } @@ -923,6 +1027,9 @@ class CommandQueue { case VTA_MEM_ID_OUT: elem_bytes = VTA_OUT_ELEM_BYTES; break; + case VTA_MEM_ID_ACC_8BIT: + elem_bytes = VTA_ACC_ELEM_BYTES / 4; + break; default: LOG(FATAL) << "Memory id not recognized:" << memory_id; break; @@ -1022,7 +1129,7 @@ class CommandQueue { VTA_OPCODE_FINISH); // Make sure that we don't exceed contiguous physical memory limits - CHECK(insn_queue_.count() * sizeof(VTAGenericInsn) < VTA_MAX_XFER); + CHECK(insn_queue_.count() * sizeof(VTAGenericInsn) <= VTA_MAX_XFER); int timeout = VTADeviceRun(device_, insn_queue_.dram_phy_addr(), insn_queue_.count(), wait_cycles); CHECK_EQ(timeout, 0); @@ -1170,8 +1277,8 @@ class CommandQueue { void CheckInsnOverFlow() { // At each API call, we can at most commit: - // one pending store, one pending load, and one uop - if ((insn_queue_.count() + 4) * sizeof(VTAGenericInsn) >= VTA_MAX_XFER) { + // at most: 2 NOP-COMPUTE-STAGE -> 2 NOP-MEMORY-STAGE -> 1 NOP-COMPUTE-STAGE -> 1 FINISH + if ((insn_queue_.count() + 6) * sizeof(VTAGenericInsn) > VTA_MAX_XFER) { this->AutoSync(); } } @@ -1232,7 +1339,12 @@ void VTASetDebugMode(VTACommandHandle cmd, int debug_flag) { } void* VTABufferCPUPtr(VTACommandHandle cmd, void* buffer) { - return vta::DataBuffer::FromHandle(buffer)->virt_addr(); + auto data_buf = vta::DataBuffer::FromHandle(buffer); + if (data_buf) { + return data_buf->virt_addr(); + } else { // it is a raw ptr allocated by CPU + return buffer; + } } void VTAWriteBarrier(VTACommandHandle cmd, void* buffer, uint32_t elem_bits, uint32_t start, diff --git a/vta/runtime/runtime.h b/vta/runtime/runtime.h index 24ebb8e1247b..e6a6cb26528e 100644 --- a/vta/runtime/runtime.h +++ b/vta/runtime/runtime.h @@ -42,6 +42,8 @@ extern "C" { #define VTA_DEBUG_SKIP_WRITE_BARRIER (1 << 4) #define VTA_DEBUG_FORCE_SERIAL (1 << 5) +#define ALLOC_ALIGNMENT 64 + /*! * \brief Allocate data buffer. * \param size Buffer size. diff --git a/vta/tests/python/integration/test_benchmark_topi_conv2d.py b/vta/tests/python/integration/test_benchmark_topi_conv2d.py index b82c3a90c9d0..daf9b4a7f022 100644 --- a/vta/tests/python/integration/test_benchmark_topi_conv2d.py +++ b/vta/tests/python/integration/test_benchmark_topi_conv2d.py @@ -292,7 +292,7 @@ def test_conv2d(device): def _run(env, remote): if device == "vta": target = env.target - if env.TARGET not in ["sim", "tsim"]: + if env.TARGET not in ["sim", "tsim", "intelfocl"]: assert tvm.runtime.enabled("rpc") program_fpga(remote, bitstream=None) reconfig_runtime(remote) diff --git a/vta/tutorials/autotvm/tune_alu_vta.py b/vta/tutorials/autotvm/tune_alu_vta.py new file mode 100644 index 000000000000..f2bf15b9876f --- /dev/null +++ b/vta/tutorials/autotvm/tune_alu_vta.py @@ -0,0 +1,320 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Auto-tuning an ALU fused op on VTA
+"""
+
+import os
+from mxnet.gluon.model_zoo import vision
+import numpy as np
+from PIL import Image
+
+from tvm import topi
+import tvm
+from tvm import te
+from tvm import rpc, autotvm, relay
+from tvm.contrib import graph_runtime, download
+from tvm.autotvm.measure.measure_methods import request_remote
+from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
+from tvm.autotvm import record
+
+import vta
+from vta.testing import simulator
+from vta.top import graph_pack
+import copy
+
+
+#################################################################
+# Compile network
+# ---------------
+# Perform vta-specific compilation with Relay from a Gluon model
+def compile_network(env, target, model, start_pack, stop_pack):
+
+    # Populate the shape and data type dictionary
+    dtype_dict = {"data": "float32"}
+    shape_dict = {"data": (env.BATCH, 3, 224, 224)}
+
+    # Get off the shelf gluon model, and convert to relay
+    gluon_model = vision.get_model(model, pretrained=True)
+    mod, params = relay.frontend.from_mxnet(gluon_model, shape_dict)
+
+    # Update shape and type dictionary
+    shape_dict.update({k: v.shape for k, v in params.items()})
+    dtype_dict.update({k: str(v.dtype) for k, v in params.items()})
+
+    # Perform quantization in Relay
+    # Note: We set opt_level to 3 in order to fold batch norm
+    with relay.build_config(opt_level=3):
+        with relay.quantize.qconfig(global_scale=8.0, skip_conv_layers=[0]):
+            mod = relay.quantize.quantize(mod, params=params)
+
+    # Perform graph packing and constant folding for VTA target
+    if target.device_name == "vta":
+        assert env.BLOCK_IN == env.BLOCK_OUT
+        relay_prog = graph_pack(
+            mod["main"],
+            env.BATCH,
+            env.BLOCK_OUT,
+            env.WGT_WIDTH,
+            start_name=start_pack,
+            stop_name=stop_pack,
+        )
+
+    return relay_prog, params
+
+
+###########################################
+# Set Tuning Options
+# ------------------
+# Before tuning, we should apply some configurations.
+# Here we use a Pynq-Z1 board as an example.
+
+# Tracker host and port can be set by your environment
+tracker_host = os.environ.get("TVM_TRACKER_HOST", "0.0.0.0")
+tracker_port = int(os.environ.get("TVM_TRACKER_PORT", 9190))
+
+# Load VTA parameters from the vta/config/vta_config.json file
+env = vta.get_env()
+
+# This target is used for cross compilation. You can query it by :code:`gcc -v` on your device.
+# Set ``device=arm_cpu`` to run inference on the CPU
+# or ``device=vta`` to run inference on the FPGA.
+device = "vta"
+target = env.target if device == "vta" else env.target_vta_cpu
+
+# Name of Gluon model to compile
+# The ``start_pack`` and ``stop_pack`` labels indicate where
+# to start and end the graph packing relay pass: in other words
+# where to start and finish offloading to VTA.
+network = "resnet50_v2"
+start_pack = "nn.max_pool2d"
+stop_pack = "nn.global_avg_pool2d"
+
+# Tuning option
+log_file = "%s.alu.%s.log" % (device, network)
+tuning_option = {
+    "log_filename": log_file,
+    "tuner": "random",
+    "n_trial": 1000,
+    "early_stopping": None,
+    "measure_option": autotvm.measure_option(
+        builder=autotvm.LocalBuilder(n_parallel=1),
+        runner=autotvm.RPCRunner(
+            env.TARGET,
+            host=tracker_host,
+            port=tracker_port,
+            number=5,
+            timeout=60,
+            # check_correctness=True, # TODO: re-enable when check_correctness works again.
+ ), + ), +} + + +def log_to_file(file_out, protocol="json"): + """Log the tuning records into file. + The rows of the log are stored in the format of autotvm.record.encode. + for lhs == rhs, we add an extra rhs = [] record + + Parameters + ---------- + file_out : str + The file to log to. + protocol: str, optional + The log protocol. Can be 'json' or 'pickle' + + Returns + ------- + callback : callable + Callback function to do the logging. + """ + + def _callback(_, inputs, results): + with open(file_out, "a") as f: + for inp, result in zip(inputs, results): + f.write(record.encode(inp, result, protocol) + "\n") + + # we only consider task with same lhs and rhs + if inp.task.args[0] == inp.task.args[1]: + args = list(inp.task.args) + args[1] = (args[0][0], (), args[0][2]) + inp_copy = copy.deepcopy(inp) + inp_copy.task.args = tuple(args) + f.write(record.encode(inp_copy, result, protocol) + "\n") + + return _callback + + +def tune_tasks( + tasks, + measure_option, + tuner="xgb", + n_trial=10, + early_stopping=None, + log_filename="tuning.log", + use_transfer_learning=True, +): + + # create tmp log file + tmp_log_file = log_filename + ".tmp" + if os.path.exists(tmp_log_file): + os.remove(tmp_log_file) + + for i, tsk in enumerate(reversed(tasks)): + prefix = "[Task %2d/%2d] " % (i + 1, len(tasks)) + + # create tuner + if tuner == "xgb" or tuner == "xgb-rank": + tuner_obj = XGBTuner(tsk, loss_type="rank") + elif tuner == "xgb_knob": + tuner_obj = XGBTuner(tsk, loss_type="rank", feature_type="knob") + elif tuner == "ga": + tuner_obj = GATuner(tsk, pop_size=50) + elif tuner == "random": + tuner_obj = RandomTuner(tsk) + elif tuner == "gridsearch": + tuner_obj = GridSearchTuner(tsk) + else: + raise ValueError("Invalid tuner: " + tuner) + + if use_transfer_learning: + if os.path.isfile(tmp_log_file): + tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file)) + + # do tuning + tsk_trial = min(n_trial, len(tsk.config_space)) + tuner_obj.tune( + n_trial=tsk_trial, + early_stopping=early_stopping, + measure_option=measure_option, + callbacks=[ + autotvm.callback.progress_bar(tsk_trial, prefix=prefix), + log_to_file(tmp_log_file), + ], + ) + + # pick best records to a cache file + autotvm.record.pick_best(tmp_log_file, log_filename) + os.remove(tmp_log_file) + + +######################################################################## +# Register VTA-specific tuning tasks +def register_vta_tuning_tasks(): + from tvm.autotvm.task import TaskExtractEnv + + @tvm.te.tag_scope(tag=topi.tag.ELEMWISE) + def my_clip(x, a_min, a_max): + """Unlike topi's current clip, put min and max into two stages.""" + const_min = tvm.tir.const(a_min, x.dtype) + const_max = tvm.tir.const(a_max, x.dtype) + x = te.compute(x.shape, lambda *i: tvm.te.min(x(*i), const_max), name="clipA") + x = te.compute(x.shape, lambda *i: tvm.te.max(x(*i), const_min), name="clipB") + return x + + # init autotvm env to register VTA operator + TaskExtractEnv() + + @autotvm.template("add.vta") + def _topi_add(*args, **kwargs): + assert not kwargs, "Do not support kwargs in template function call" + A, B = args[:2] + + with tvm.target.vta(): + res = vta.top.op.add_packed(*args, **kwargs) + res = my_clip(res, 0, 127) + res = topi.cast(res, "int8") + + if tvm.target.Target.current().device_name == "vta": + s = vta.top.op.schedule_add_packed([res]) + else: + s = te.create_schedule([res.op]) + return s, [A, B, res] + + @autotvm.template("multiply.vta") + def _topi_multiply(*args, **kwargs): + assert not kwargs, "Do not support kwargs in 
template function call"
+        A, B = args[:2]
+
+        with tvm.target.vta():
+            res = vta.top.op.multiply_packed(*args, **kwargs)
+            res = my_clip(res, 0, 127)
+            res = topi.cast(res, "int8")
+
+        if tvm.target.Target.current().device_name == "vta":
+            s = vta.top.op.schedule_multiply_packed([res])
+        else:
+            s = te.create_schedule([res.op])
+        return s, [A, B, res]
+
+
+########################################################################
+# Finally, we launch tuning jobs and evaluate the end-to-end performance.
+def tune_and_evaluate(tuning_opt):
+
+    if env.TARGET != "intelfocl":
+        print("ALU-only ops are only available for the intelfocl target")
+        return
+
+    # Register VTA tuning tasks
+    register_vta_tuning_tasks()
+
+    # Perform task extraction on Relay program
+    print("Extract tasks...")
+    relay_prog, params = compile_network(env, target, network, start_pack, stop_pack)
+    mod = tvm.IRModule.from_expr(relay_prog)
+    tasks = autotvm.task.extract_from_program(
+        mod,
+        params=params,
+        ops=(
+            relay.op.get("add"),
+            relay.op.get("multiply"),
+        ),
+        target=target,
+        target_host=env.target_host,
+    )
+
+    # filter out non-packed alu tasks
+    tasks = list(filter(lambda t: len(t.args[0][1]) > 4, tasks))
+    # filter out float alu tasks
+    tasks = list(filter(lambda t: t.args[0][2] != "float32", tasks))
+
+    # Normalize and deduplicate the extracted ALU tasks
+    tasks_set = {}
+    print("Extracted {} alu tasks:".format(len(tasks)))
+    for tsk in tasks:
+        print("tsk = ", tsk)
+
+        if len(tsk.args[1][1]) == 0:
+            args = list(tsk.args)
+            args[1] = args[0]
+            tsk.args = tuple(args)
+
+        if (tsk.name, tsk.args) in tasks_set:
+            print("task {} already exists".format(tsk))
+        tasks_set[(tsk.name, tsk.args)] = tsk
+
+    tasks = list(tasks_set.values())
+    print("After merging, final #tasks={}, tasks = {}".format(len(tasks), tasks))
+
+    # run tuning tasks
+    print("Tuning...")
+    tune_tasks(tasks, **tuning_opt)
+
+
+# Run the tuning and evaluate the results
+tune_and_evaluate(tuning_option)
diff --git a/vta/tutorials/frontend/deploy_classification.py b/vta/tutorials/frontend/deploy_classification.py
index f9db824eafa3..b72301d60c0c 100644
--- a/vta/tutorials/frontend/deploy_classification.py
+++ b/vta/tutorials/frontend/deploy_classification.py
@@ -52,7 +52,7 @@
 import tvm
 from tvm import te
 from tvm import rpc, autotvm, relay
-from tvm.contrib import graph_executor, utils, download
+from tvm.contrib import graph_executor, utils, download, graph_runtime
 from tvm.contrib.debugger import debug_executor
 from tvm.relay import transform
 
@@ -60,6 +60,7 @@
 from vta.testing import simulator
 from vta.top import graph_pack
 
+
 # Make sure that TVM was compiled with RPC=1
 assert tvm.runtime.enabled("rpc")
 
@@ -99,7 +100,7 @@
 
 # When target is 'pynq', reconfigure FPGA and runtime.
 # Otherwise, if target is 'sim', execute locally.
-if env.TARGET not in ["sim", "tsim"]:
+if env.TARGET not in ["sim", "tsim", "intelfocl"]:
 
     # Get remote from tracker node if environment variable is set.
# To set up the tracker, you'll need to follow the "Auto-tuning @@ -131,6 +132,10 @@ else: remote = rpc.LocalSession() + if env.TARGET in ["intelfocl"]: + # program intelfocl aocx + vta.program_fpga(remote, bitstream="vta.bitstream") + # Get execution context from remote ctx = remote.ext_dev(0) if device == "vta" else remote.cpu(0) @@ -178,6 +183,7 @@ mod = relay.quantize.quantize(mod, params=params) # Perform graph packing and constant folding for VTA target assert env.BLOCK_IN == env.BLOCK_OUT + # do device annotation if target is intelfocl or sim relay_prog = graph_pack( mod["main"], env.BATCH, @@ -185,6 +191,7 @@ env.WGT_WIDTH, start_name=pack_dict[model][0], stop_name=pack_dict[model][1], + device_annot=(env.TARGET == "intelfocl" or env.TARGET == "sim"), ) else: relay_prog = mod["main"] @@ -196,8 +203,13 @@ relay_prog, target=target, params=params, target_host=env.target_host ) else: + if env.TARGET == "intelfocl" or env.TARGET == "sim": + # multiple targets to run both on cpu and vta + target = {"cpu": env.target_vta_cpu, "ext_dev": target} with vta.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): - lib = relay.build(relay_prog, target=target, params=params, target_host=env.target_host) + graph, lib, params = relay.build( + relay_prog, target=target, params=params, target_host=env.target_host + ) # Measure Relay build time build_time = time.time() - build_start @@ -209,8 +221,12 @@ remote.upload(temp.relpath("graphlib.tar")) lib = remote.load_module("graphlib.tar") - # Graph executor - m = graph_executor.GraphModule(lib["default"](ctx)) + if env.TARGET == "intelfocl" or env.TARGET == "sim": + ctxes = [remote.ext_dev(0), remote.cpu(0)] + m = graph_runtime.create(graph, lib, ctxes) + else: + # Graph runtime + m = graph_runtime.create(graph, lib, ctx) ###################################################################### # Perform image classification inference @@ -241,6 +257,7 @@ image = np.repeat(image, env.BATCH, axis=0) # Set the network parameters and inputs +m.set_input(**params) m.set_input("data", image) # Perform inference and gather execution statistics diff --git a/vta/tutorials/vta_get_started.py b/vta/tutorials/vta_get_started.py index 1a097b804a31..f64cae11cccc 100644 --- a/vta/tutorials/vta_get_started.py +++ b/vta/tutorials/vta_get_started.py @@ -91,9 +91,13 @@ vta.program_fpga(remote, bitstream=None) # In simulation mode, host the RPC server locally. -elif env.TARGET == "sim": +elif env.TARGET in ("sim", "tsim", "intelfocl"): remote = rpc.LocalSession() + if env.TARGET in ["intelfocl"]: + # program intelfocl aocx + vta.program_fpga(remote, bitstream="vta.bitstream") + ###################################################################### # Computation Declaration # -----------------------
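
A minimal end-to-end sketch of the intelfocl flow exercised by the tutorial changes above, assuming VTA was built with VTA_TARGET=intelfocl, TARGET is set to "intelfocl" in vta/config/vta_config.json, and a compiled aocx is available locally; the file name "vta.bitstream" mirrors the tutorials in this patch and is an assumption, not a fixed path:

    import vta
    from tvm import rpc

    # Load VTA parameters from vta/config/vta_config.json
    env = vta.get_env()
    assert env.TARGET == "intelfocl"

    # intelfocl, like sim/tsim, is driven through a local RPC session
    remote = rpc.LocalSession()

    # Program the aocx onto the FPGA via the new vta.oclfpga.program path
    vta.program_fpga(remote, bitstream="vta.bitstream")

    # Built kernels then target env.target and run on the ext_dev context
    ctx = remote.ext_dev(0)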