From 37692bfdcbcd9373ddbfc6403d122011379ccbae Mon Sep 17 00:00:00 2001
From: Zhang Hao
Date: Thu, 23 Jul 2020 14:38:52 +0800
Subject: [PATCH] only keep opencl related code

---
 python/tvm/autotvm/task/space.py          |   2 +-
 src/relay/transforms/device_annotation.cc |  79 ++++++------
 src/tir/transforms/lower_tvm_builtin.cc   |  13 --
 vta/python/vta/top/graphpack.py           | 140 +---------------------
 vta/runtime/runtime.cc                    |   1 +
 vta/tutorials/autotvm/tune_alu_vta.py     |   5 +-
 6 files changed, 51 insertions(+), 189 deletions(-)

diff --git a/python/tvm/autotvm/task/space.py b/python/tvm/autotvm/task/space.py
index 53ed78a7570d..fbf474fc4df7 100644
--- a/python/tvm/autotvm/task/space.py
+++ b/python/tvm/autotvm/task/space.py
@@ -779,7 +779,7 @@ def _add_new_transform(self, space_class, name, axes, policy, **kwargs):
         return [Axis(None, i) for i in range(space_class.get_num_output(axes, policy, **kwargs))]
 
     def __len__(self):
-        if self._length is None or self._length <= 1:
+        if self._length is None:
             self._length = int(np.prod([len(x) for x in self.space_map.values()]))
         return self._length
 
diff --git a/src/relay/transforms/device_annotation.cc b/src/relay/transforms/device_annotation.cc
index fe3cfebf7fe3..39cf563f730a 100644
--- a/src/relay/transforms/device_annotation.cc
+++ b/src/relay/transforms/device_annotation.cc
@@ -386,38 +386,22 @@ class DeviceInfo {
     }
 
     void VisitExpr_(const ConstantNode* cn) final {
-      device_tag_[cn] = dev_type_;
+      post_dfs_order_.push_back(std::make_pair(cn, has_copy_));
     }
 
     void VisitExpr_(const CallNode* call) final {
       // Skip annotation nodes.
       if (!IsOnDeviceNode(call)) {
-        if (const auto* node = GetDeviceCopyNode(call)) {
-          CHECK(node->IsInstance<CallNode>());
-          const auto* call_node = static_cast<const CallNode*>(node);
-          auto attrs = call_node->attrs.as<DeviceCopyAttrs>();
-
+        if (GetDeviceCopyNode(call)) {
           num_device_copy_ops_++;
           bool has_copy_prev = has_copy_;
           has_copy_ = true;
-          dev_type_ = attrs->src_dev_type;
-          for (auto& arg : call->args) {
-            Visit(arg);
-            // restore the type for remaining arguments
-            dev_type_ = attrs->src_dev_type;
-          }
-          device_tag_[call] = attrs->dst_dev_type;
-          // update the out_dev_type_, which should be the dst_dev_type of last copy
-          out_dev_type_ = attrs->dst_dev_type;
+          ExprVisitor::VisitExpr_(call);
+          post_dfs_order_.push_back(std::make_pair(call, has_copy_));
           has_copy_ = has_copy_prev;
         } else {
-          for (auto& arg : call->args) {
-            int cur_dev_type = dev_type_;
-            Visit(arg);
-            // restore the type for remaining arguments
-            dev_type_ = cur_dev_type;
-          }
-          device_tag_[call] = dev_type_;
+          ExprVisitor::VisitExpr_(call);
+          post_dfs_order_.push_back(std::make_pair(call, has_copy_));
         }
       }
     }
@@ -430,24 +414,22 @@ class DeviceInfo {
     void VisitExpr_(const TupleGetItemNode* op) final { ExprVisitor::VisitExpr_(op); }
 
     void VisitExpr_(const VarNode* vn) final {
-      device_tag_[vn] = dev_type_;
+      post_dfs_order_.push_back(std::make_pair(vn, has_copy_));
     }
 
     void VisitExpr_(const LetNode* ln) final {
       ExprVisitor::VisitExpr_(ln);
-      device_tag_[ln] = dev_type_;
+      post_dfs_order_.push_back(std::make_pair(ln, has_copy_));
     }
 
     void VisitExpr_(const IfNode* in) final {
       ExprVisitor::VisitExpr_(in);
-      device_tag_[in] = dev_type_;
+      post_dfs_order_.push_back(std::make_pair(in, has_copy_));
     }
 
     int num_device_copy_ops_{0};
     bool has_copy_ = false;
-    int dev_type_ = -1;
-    int out_dev_type_ = -1;
-    std::unordered_map<const ExprNode*, int> device_tag_;
+    std::vector<std::pair<const ExprNode*, bool>> post_dfs_order_;
     friend DeviceInfo;
   };
 
@@ -473,14 +455,39 @@ class DeviceInfo {
   }
 
   void PropagateDeviceId() {
-    int out_dev_type = post_visitor_.out_dev_type_;
-    for (auto& it : post_visitor_.device_tag_) {
-      if (it.second != -1) {
-        device_map_.Set(GetRef<Expr>(it.first), it.second);
-      } else {
-        device_map_.Set(GetRef<Expr>(it.first), out_dev_type);
+    // Bottom-up propagation.
+    int out_dev_type = BottomUpPropagation();
+    // propagation for remained nodes.
+    FillPropagation(out_dev_type);
+  }
+
+  int BottomUpPropagation() {
+    const CallNode* last_copy_node = nullptr;
+    int cur_dev_type = -1;
+    int out_dev_type = -1;
+    for (auto it = post_visitor_.post_dfs_order_.crbegin();
+         it != post_visitor_.post_dfs_order_.crend(); ++it) {
+      if (const auto* node = GetDeviceCopyNode(it->first)) {
+        CHECK(node->IsInstance<CallNode>());
+        last_copy_node = static_cast<const CallNode*>(node);
+        const auto* attrs = last_copy_node->attrs.as<DeviceCopyAttrs>();
+        cur_dev_type = attrs->src_dev_type;
+        if (out_dev_type == -1) out_dev_type = attrs->dst_dev_type;
+        if (it->second) device_map_.Set(GetRef<Expr>(it->first), attrs->dst_dev_type);
+      } else if (last_copy_node) {
+        Expr expr = GetRef<Expr>(it->first);
+        CHECK_EQ(device_map_.count(expr), 0U);
+        if (it->second) device_map_.Set(expr, cur_dev_type);
       }
     }
+    return out_dev_type;
+  }
+
+  void FillPropagation(int out_dev_type) {
+    for (const auto& it : post_visitor_.post_dfs_order_) {
+      Expr expr = GetRef<Expr>(it.first);
+      if (!it.second) device_map_.Set(expr, out_dev_type);
+    }
   }
 
   PostDfsOrderVisitor post_visitor_;
@@ -534,9 +541,7 @@ Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device) {
   }
 }
 
-Map<Expr, Integer> CollectDeviceInfo(const Expr& expr) {
-  return DeviceInfo::GetDeviceMap(expr);
-}
+Map<Expr, Integer> CollectDeviceInfo(const Expr& expr) { return DeviceInfo::GetDeviceMap(expr); }
 
 Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr) {
   return AnnotatationVisitor::GetAnnotations(expr);
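
Note: the rewritten visitor above records a post-DFS list of (node, has_copy) pairs and then assigns device types in two passes over that list. Below is a minimal Python sketch of the same scheme, under stated assumptions: `Copy`, `is_device_copy`, and the string nodes are illustrative stand-ins, not TVM API.

    from collections import namedtuple

    # Hypothetical stand-in for a relay device_copy call (not TVM API).
    Copy = namedtuple("Copy", ["name", "src_dev_type", "dst_dev_type"])

    def is_device_copy(node):
        return isinstance(node, Copy)

    def propagate_device_ids(post_dfs_order):
        """Mirror of BottomUpPropagation + FillPropagation over (node, has_copy) pairs."""
        device_map = {}
        cur_dev = out_dev = -1
        seen_copy = False
        # Pass 1 (bottom-up): nodes visited under a device_copy take its source
        # device; the copy node itself is tagged with its destination device.
        for node, has_copy in reversed(post_dfs_order):
            if is_device_copy(node):
                seen_copy = True
                cur_dev = node.src_dev_type
                if out_dev == -1:
                    out_dev = node.dst_dev_type  # dst of the last copy in the graph
                if has_copy:
                    device_map[node] = node.dst_dev_type
            elif seen_copy and has_copy:
                device_map[node] = cur_dev
        # Pass 2: nodes on no copy path fall back to the output device type.
        for node, has_copy in post_dfs_order:
            if not has_copy:
                device_map[node] = out_dev
        return device_map

    # Example: x feeds a 1->2 copy; y consumes the copy outside the copy path.
    order = [("x", True), (Copy("c", 1, 2), True), ("y", False)]
    print(propagate_device_ids(order))  # x -> 1, the copy -> 2, y -> 2
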
diff --git a/src/tir/transforms/lower_tvm_builtin.cc b/src/tir/transforms/lower_tvm_builtin.cc
index 2991a0842b7a..8aab9f877c29 100644
--- a/src/tir/transforms/lower_tvm_builtin.cc
+++ b/src/tir/transforms/lower_tvm_builtin.cc
@@ -87,19 +87,6 @@ class BuiltinLower : public StmtExprMutator {
     op = stmt.as<AllocateNode>();
     // Get constant allocation bound.
     int64_t nbytes = GetVectorBytes(op->dtype);
-    // NOTE(zhanghao): remove special handling for kDLCPU
-    // otherwise, may cause LLVM parameters match error
-    // if in heterogenous targets
-    // if (device_type_.defined()) {
-    //   if (arith::GetConst(device_type_, &dev_type)) {
-    //     if (dev_type == kDLCPU) {
-    //       int32_t constant_size = op->constant_allocation_size();
-    //       if (constant_size > 0 && constant_size * nbytes < runtime::kMaxStackAlloca) {
-    //         return stmt;
-    //       }
-    //     }
-    //   }
-    // }
     PrimExpr total_bytes = make_const(op->extents[0].dtype(), nbytes);
     for (size_t i = 0; i < op->extents.size(); ++i) {
       total_bytes = total_bytes * op->extents[i];
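
Note: with the commented-out kDLCPU fast path deleted above, every allocation that reaches this pass is sized the same way: element size in bytes times the product of all extents. An illustrative Python equivalent of the remaining loop (not TVM API):

    import math

    def total_alloc_bytes(elem_bytes, extents):
        # Mirrors: total_bytes = nbytes * extents[0] * ... * extents[n-1]
        return elem_bytes * math.prod(extents)

    assert total_alloc_bytes(4, [2, 3, 8]) == 192  # float32 buffer of shape (2, 3, 8)
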
diff --git a/vta/python/vta/top/graphpack.py b/vta/python/vta/top/graphpack.py
index 255724afc809..633ef3f60c9b 100644
--- a/vta/python/vta/top/graphpack.py
+++ b/vta/python/vta/top/graphpack.py
@@ -93,7 +93,7 @@ def _weight_shape_match_transpose(data, dshape, channels, cfactor_out):
     if pad_width != 0:
         pad_width = cfactor_out - pad_width
         data = op.nn.pad(data, [[0, 0], [0, pad_width], [0, 0], [0, 0]])
-        dshape = tuple([dshape[0]] + [dshape[1] + pad_width, dshape[2], dshape[3]])
+        dshape = tuple(dshape[0], [dshape[1] + pad_width, dshape[2], dshape[3]])
 
     if channels_pad != 0:
         channels = channels + (cfactor_out - channels_pad)
@@ -174,104 +174,6 @@ def _operator_idx_inc(expr, count_meta, operator_current_idx):
         operator_current_idx = operator_current_idx + 1
     return operator_current_idx
 
-
-class ExprDeviceAnnot(ExprMutator):
-    """Visitor to perform graph annotation on an AST.
-
-    Parameters
-    ----------
-    start: int
-        the start location to mark run on vta (inclusive)
-    end: int
-        the end location to mark run on vta (exclusive)
-
-    Returns
-    ---------
-    None
-    """
-    def __init__(self, start=-1, end=-1):
-        self.ext_ctx = tvm.context("ext_dev")
-        self.cpu_ctx = tvm.context("cpu")
-        self.cast = op.op.get("cast")
-        self.counter = -1
-        self.start = start
-        self.end = end
-        super().__init__()
-
-    def visit_call(self, call):
-        """ Visit the children. """
-        # First visit the children.
-        oshape = _get_tensor_shape(call)
-        odtype = _get_tensor_type(call)
-        input_types = [arg.checked_type for arg in call.args]
-        args = [self.visit(arg) for arg in call.args]
-
-        self.counter += 1
-        if self.counter == self.start:
-            ret = relay.Call(call.op, args, call.attrs)
-            ret = relay.annotation.on_device(ret, self.ext_ctx)
-            return ret
-        elif self.counter == self.end:
-            ret = relay.Call(call.op, args, call.attrs)
-            ret = relay.annotation.on_device(ret, self.cpu_ctx)
-            return ret
-        elif self.counter > self.start and self.counter < self.end:
-            ret = relay.Call(call.op, args, call.attrs)
-
-            # skip the float op, i.e., float->int cast
-            if self.is_float_op(call):
-                return ret
-
-            return relay.annotation.on_device(ret, self.ext_ctx)
-
-        return relay.Call(self.visit(call.op), args, call.attrs)
-
-    def is_float_op(self, call):
-        """check if this op belongs to a float op
-        in general, float op's odtype is float;
-        a special case is float->int cast, which follow this op sequence:
-        multiply(float) -> round(float) -> clip(float) -> cast(int);
-        """
-        args = call.args
-        odtype = _get_tensor_type(call)
-        op = call.op
-
-        if odtype == "float32":
-            return True
-        elif op == self.cast:
-            idtype = _get_tensor_type(args[0])
-            if idtype == "float32":
-                return True
-
-        return False
-
-
-class ExprLocater(ExprMutator):
-    """Visitor to locate op on an AST.
-    """
-    def __init__(self):
-        self.counter = -1
-        self.op2nodes = {}
-        super().__init__()
-
-    def visit_call(self, call):
-        """ Visit the children. """
-        # First visit the children.
-        args = [self.visit(arg) for arg in call.args]
-
-        odtype = _get_tensor_type(call)
-        self.counter += 1
-        if (call.op, odtype) in self.op2nodes:
-            self.op2nodes[(call.op, odtype)].append(self.counter)
-        else:
-            self.op2nodes[(call.op, odtype)] = [self.counter]
-
-        return relay.Call(
-            self.visit(call.op),
-            args,
-            call.attrs)
-
-
 class ExprPack(ExprMutator):
     """Visitor to perform graph packing on an AST.
     """
@@ -415,7 +317,7 @@ def visit_call(self, call):
             elif self.start_pack and call.op == op.op.get('cast') and \
                     input_types[0].dtype == 'int32':
                 cast = relay.Call(op.op.get('cast'), [args[0]], call.attrs)
-                return cast
+                return relay.Call(op.op.get('copy'), [cast])
             elif call.op == self.pad:
                 pad_width = call.attrs.pad_width
                 if len(pad_width) == 6:
@@ -510,10 +412,7 @@ def graph_pack(expr,
                stop_name="nn.global_avg_pool2d",
                start_name_idx=None,
                stop_name_idx=None,
-               count_meta=False,
-               device_annot=False,
-               annot_start_name="nn.conv2d",
-               annot_end_name="annotation.stop_fusion"):
+               count_meta=False):
     """Pack the graph into batch&channel packed format.
 
     Parameters
@@ -550,23 +449,13 @@ def graph_pack(expr,
         'expr.astext(show_meta_data=False)'. When count_meta is True, the operator increase
         logic would count the meta.
 
-    device_annot: boolean, optional
-        if we want to annoate the device_type
-
-    annot_start_name: str, optional
-        device annotation start node, from which we mark the nodes as `ext_dev`
-
-    annot_end_name: str, optional
-        device annotation end node, after which we mark the nodes as 'cpu'
-
     Returns
     -------
     expr : Expr
         The transformed expression.
     """
     assert isinstance(expr, relay.Function)
-    assert ((start_name != stop_name) or (start_name_idx is None != stop_name_idx is None) or \
-        (not (start_name_idx is None and stop_name_idx is None)) or (start_name_idx < stop_name_idx))
+    assert ((start_name != stop_name) or (start_name_idx < stop_name_idx))
     expr = get_subgraph(expr, start_name, stop_name, start_name_idx, stop_name_idx, count_meta)
     expr = run_opt_pass(expr, transform.InferType())
     packer = ExprPack(
         bfactor, cfactor,
         weight_bits)
     expr = packer.visit(expr)
     assert not packer.start_pack
-    expr = run_opt_pass(expr, transform.InferType())
-
-    if device_annot:
-        expr_locator = ExprLocater()
-        expr_locator.visit(expr)
-
-        annot_start = op.op.get(annot_start_name)
-        start = expr_locator.op2nodes[(annot_start, "int32")][0]
-
-        annot_end = op.op.get(annot_end_name)
-        # we mark the next op to the last stop_fusion on cpu device
-        end = expr_locator.op2nodes[(annot_end, "int8")][-1] + 1
-
-        device_annot = ExprDeviceAnnot(start=start, end=end)
-        expr = device_annot.visit(expr)
-        ret = run_opt_pass(expr, transform.InferType())
-
-        return ret
-    else:
-        return expr
+    return run_opt_pass(expr, transform.InferType())
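
Note: with the device_annot path removed, graph_pack always returns the packed, type-inferred function. A hedged usage sketch in the style of the tutorial below; `relay_prog` and the pack-point names are assumed from a typical VTA setup and are not defined by this patch:

    import vta
    from vta.top import graph_pack

    env = vta.get_env()
    # relay_prog: a type-inferred Relay module, e.g. from relay.frontend.from_mxnet.
    packed = graph_pack(
        relay_prog["main"],
        env.BATCH,
        env.BLOCK_OUT,
        env.WGT_WIDTH,
        start_name="nn.max_pool2d",
        stop_name="nn.global_avg_pool2d")
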
- """ - def __init__(self): - self.counter = -1 - self.op2nodes = {} - super().__init__() - - def visit_call(self, call): - """ Visit the children. """ - # First visit the children. - args = [self.visit(arg) for arg in call.args] - - odtype = _get_tensor_type(call) - self.counter += 1 - if (call.op, odtype) in self.op2nodes: - self.op2nodes[(call.op, odtype)].append(self.counter) - else: - self.op2nodes[(call.op, odtype)] = [self.counter] - - return relay.Call( - self.visit(call.op), - args, - call.attrs) - - class ExprPack(ExprMutator): """Visitor to perform graph packing on an AST. """ @@ -415,7 +317,7 @@ def visit_call(self, call): elif self.start_pack and call.op == op.op.get('cast') and \ input_types[0].dtype == 'int32': cast = relay.Call(op.op.get('cast'), [args[0]], call.attrs) - return cast + return relay.Call(op.op.get('copy'), [cast]) elif call.op == self.pad: pad_width = call.attrs.pad_width if len(pad_width) == 6: @@ -510,10 +412,7 @@ def graph_pack(expr, stop_name="nn.global_avg_pool2d", start_name_idx=None, stop_name_idx=None, - count_meta=False, - device_annot=False, - annot_start_name="nn.conv2d", - annot_end_name="annotation.stop_fusion"): + count_meta=False): """Pack the graph into batch&channel packed format. Parameters @@ -550,23 +449,13 @@ def graph_pack(expr, 'expr.astext(show_meta_data=False)'. When count_meta is True, the operator increase logic would count the meta. - device_annot: boolean, optional - if we want to annoate the device_type - - annot_start_name: str, optional - device annotation start node, from which we mark the nodes as `ext_dev` - - annot_end_name: str, optional - device annotation end node, after which we mark the nodes as 'cpu' - Returns ------- expr : Expr The transformed expression. """ assert isinstance(expr, relay.Function) - assert ((start_name != stop_name) or (start_name_idx is None != stop_name_idx is None) or \ - (not (start_name_idx is None and stop_name_idx is None)) or (start_name_idx < stop_name_idx)) + assert ((start_name != stop_name) or (start_name_idx < stop_name_idx)) expr = get_subgraph(expr, start_name, stop_name, start_name_idx, stop_name_idx, count_meta) expr = run_opt_pass(expr, transform.InferType()) packer = ExprPack( @@ -574,23 +463,4 @@ def graph_pack(expr, weight_bits) expr = packer.visit(expr) assert not packer.start_pack - expr = run_opt_pass(expr, transform.InferType()) - - if device_annot: - expr_locator = ExprLocater() - expr_locator.visit(expr) - - annot_start = op.op.get(annot_start_name) - start = expr_locator.op2nodes[(annot_start, "int32")][0] - - annot_end = op.op.get(annot_end_name) - # we mark the next op to the last stop_fusion on cpu device - end = expr_locator.op2nodes[(annot_end, "int8")][-1] + 1 - - device_annot = ExprDeviceAnnot(start=start, end=end) - expr = device_annot.visit(expr) - ret = run_opt_pass(expr, transform.InferType()) - - return ret - else: - return expr + return run_opt_pass(expr, transform.InferType()) diff --git a/vta/runtime/runtime.cc b/vta/runtime/runtime.cc index 8e9cdc678b59..23024571d625 100644 --- a/vta/runtime/runtime.cc +++ b/vta/runtime/runtime.cc @@ -526,6 +526,7 @@ class UopQueue : public BaseQueue { kernel->sram_begin_ = 0; kernel->sram_end_ = 0; } + cache_.clear(); cache_idx_ = 0; BaseQueue::Reset(); diff --git a/vta/tutorials/autotvm/tune_alu_vta.py b/vta/tutorials/autotvm/tune_alu_vta.py index cf4922450ce5..661a3309a5c7 100644 --- a/vta/tutorials/autotvm/tune_alu_vta.py +++ b/vta/tutorials/autotvm/tune_alu_vta.py @@ -42,7 +42,7 @@ # Compile network # --------------- 
diff --git a/vta/tutorials/autotvm/tune_alu_vta.py b/vta/tutorials/autotvm/tune_alu_vta.py
index cf4922450ce5..661a3309a5c7 100644
--- a/vta/tutorials/autotvm/tune_alu_vta.py
+++ b/vta/tutorials/autotvm/tune_alu_vta.py
@@ -42,7 +42,7 @@
 # Compile network
 # ---------------
 # Perform vta-specific compilation with Relay from a Gluon model
-def compile_network(env, target, model, start_pack, stop_pack, device_annot=False):
+def compile_network(env, target, model, start_pack, stop_pack):
     # Populate the shape and data type dictionary
     dtype_dict = {"data": 'float32'}
 
@@ -70,8 +70,7 @@ def compile_network(env, target, model, start_pack, stop_pack, device_annot=False):
             env.BLOCK_OUT,
             env.WGT_WIDTH,
             start_name=start_pack,
-            stop_name=stop_pack,
-            device_annot=device_annot)
+            stop_name=stop_pack)
 
     return relay_prog, params
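
Note: call sites only need to drop the removed keyword. A hedged sketch of the updated call, with env/target taken from vta.get_env() and example model/pack names in the tutorial's style:

    import vta

    env = vta.get_env()
    target = env.target
    # The pre-patch signature also accepted device_annot=...; that keyword is gone.
    relay_prog, params = compile_network(
        env, target, "resnet18_v1", "nn.max_pool2d", "nn.global_avg_pool2d")
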