only keep opencl related code
zhanghaohit committed Jul 23, 2020
1 parent 335f99c commit 37692bf
Showing 6 changed files with 51 additions and 189 deletions.
2 changes: 1 addition & 1 deletion python/tvm/autotvm/task/space.py
@@ -779,7 +779,7 @@ def _add_new_transform(self, space_class, name, axes, policy, **kwargs):
return [Axis(None, i) for i in range(space_class.get_num_output(axes, policy, **kwargs))]

def __len__(self):
if self._length is None or self._length <= 1:
if self._length is None:
self._length = int(np.prod([len(x) for x in self.space_map.values()]))
return self._length

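For context, the `__len__` method touched here memoizes the size of the tuning space: the total number of configurations is the product of the sizes of the per-knob sub-spaces, computed once and cached in `_length`. A minimal runnable sketch of that caching behavior (the knob names and values below are made up for illustration, not part of the diff):

```python
import numpy as np

class ConfigSpaceSketch:
    """Minimal sketch of the cached __len__ behavior shown in the diff."""

    def __init__(self, space_map):
        # space_map: knob name -> list of candidate values for that knob.
        self.space_map = space_map
        self._length = None

    def __len__(self):
        # Compute the total number of configurations once and cache it.
        if self._length is None:
            self._length = int(np.prod([len(x) for x in self.space_map.values()]))
        return self._length

# A space with 4 tiling candidates and 2 unroll candidates has 4 * 2 = 8 configs.
space = ConfigSpaceSketch({"tile_x": [1, 2, 4, 8], "unroll": [0, 1]})
assert len(space) == 8
```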
79 changes: 42 additions & 37 deletions src/relay/transforms/device_annotation.cc
@@ -386,38 +386,22 @@ class DeviceInfo {
}

void VisitExpr_(const ConstantNode* cn) final {
device_tag_[cn] = dev_type_;
post_dfs_order_.push_back(std::make_pair(cn, has_copy_));
}

void VisitExpr_(const CallNode* call) final {
// Skip annotation nodes.
if (!IsOnDeviceNode(call)) {
if (const auto* node = GetDeviceCopyNode(call)) {
CHECK(node->IsInstance<CallNode>());
const auto* call_node = static_cast<const CallNode*>(node);
auto attrs = call_node->attrs.as<DeviceCopyAttrs>();

if (GetDeviceCopyNode(call)) {
num_device_copy_ops_++;
bool has_copy_prev = has_copy_;
has_copy_ = true;
dev_type_ = attrs->src_dev_type;
for (auto& arg : call->args) {
Visit(arg);
// restore the type for remaining arguments
dev_type_ = attrs->src_dev_type;
}
device_tag_[call] = attrs->dst_dev_type;
// update the out_dev_type_, which should be the dst_dev_type of last copy
out_dev_type_ = attrs->dst_dev_type;
ExprVisitor::VisitExpr_(call);
post_dfs_order_.push_back(std::make_pair(call, has_copy_));
has_copy_ = has_copy_prev;
} else {
for (auto& arg : call->args) {
int cur_dev_type = dev_type_;
Visit(arg);
// restore the type for remaining arguments
dev_type_ = cur_dev_type;
}
device_tag_[call] = dev_type_;
ExprVisitor::VisitExpr_(call);
post_dfs_order_.push_back(std::make_pair(call, has_copy_));
}
}
}
@@ -430,24 +414,22 @@ class DeviceInfo {
void VisitExpr_(const TupleGetItemNode* op) final { ExprVisitor::VisitExpr_(op); }

void VisitExpr_(const VarNode* vn) final {
device_tag_[vn] = dev_type_;
post_dfs_order_.push_back(std::make_pair(vn, has_copy_));
}

void VisitExpr_(const LetNode* ln) final {
ExprVisitor::VisitExpr_(ln);
device_tag_[ln] = dev_type_;
post_dfs_order_.push_back(std::make_pair(ln, has_copy_));
}

void VisitExpr_(const IfNode* in) final {
ExprVisitor::VisitExpr_(in);
device_tag_[in] = dev_type_;
post_dfs_order_.push_back(std::make_pair(in, has_copy_));
}

int num_device_copy_ops_{0};
bool has_copy_ = false;
int dev_type_ = -1;
int out_dev_type_ = -1;
std::unordered_map<const ExprNode*, int> device_tag_;
std::vector<std::pair<const ExprNode*, bool>> post_dfs_order_;
friend DeviceInfo;
};

@@ -473,14 +455,39 @@
}

void PropagateDeviceId() {
int out_dev_type = post_visitor_.out_dev_type_;
for (auto& it : post_visitor_.device_tag_) {
if (it.second != -1) {
device_map_.Set(GetRef<Expr>(it.first), it.second);
} else {
device_map_.Set(GetRef<Expr>(it.first), out_dev_type);
// Bottom-up propagation.
int out_dev_type = BottomUpPropagation();
// Propagation for remaining nodes.
FillPropagation(out_dev_type);
}

int BottomUpPropagation() {
const CallNode* last_copy_node = nullptr;
int cur_dev_type = -1;
int out_dev_type = -1;
for (auto it = post_visitor_.post_dfs_order_.crbegin();
it != post_visitor_.post_dfs_order_.crend(); ++it) {
if (const auto* node = GetDeviceCopyNode(it->first)) {
CHECK(node->IsInstance<CallNode>());
last_copy_node = static_cast<const CallNode*>(node);
const auto* attrs = last_copy_node->attrs.as<DeviceCopyAttrs>();
cur_dev_type = attrs->src_dev_type;
if (out_dev_type == -1) out_dev_type = attrs->dst_dev_type;
if (it->second) device_map_.Set(GetRef<Expr>(it->first), attrs->dst_dev_type);
} else if (last_copy_node) {
Expr expr = GetRef<Expr>(it->first);
CHECK_EQ(device_map_.count(expr), 0U);
if (it->second) device_map_.Set(expr, cur_dev_type);
}
}
return out_dev_type;
}

void FillPropagation(int out_dev_type) {
for (const auto& it : post_visitor_.post_dfs_order_) {
Expr expr = GetRef<Expr>(it.first);
if (!it.second) device_map_.Set(expr, out_dev_type);
}
}

PostDfsOrderVisitor post_visitor_;
@@ -534,9 +541,7 @@ Expr RewriteAnnotatedOps(const Expr& expr, int fallback_device) {
}
}

Map<Expr, Integer> CollectDeviceInfo(const Expr& expr) {
return DeviceInfo::GetDeviceMap(expr);
}
Map<Expr, Integer> CollectDeviceInfo(const Expr& expr) { return DeviceInfo::GetDeviceMap(expr); }

Map<Expr, Integer> CollectDeviceAnnotationOps(const Expr& expr) {
return AnnotatationVisitor::GetAnnotations(expr);
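The rewritten `PropagateDeviceId` above splits the work into two passes over the recorded post-DFS order: `BottomUpPropagation` walks the order in reverse, using each `device_copy` node to pin device types (the copy's arguments get its `src_dev_type`, the copy itself its `dst_dev_type`, and the `dst_dev_type` of the last copy becomes the output device), and `FillPropagation` then assigns that output device to every node not reached through a copy. A minimal Python sketch of the same two-pass idea, using simplified `(node, has_copy)` records and a hypothetical `get_copy_attrs` helper rather than the actual Relay structures:

```python
def propagate_device_ids(post_dfs_order, get_copy_attrs):
    """Sketch of the two-pass propagation in the rewritten PropagateDeviceId.

    post_dfs_order: list of (node, has_copy) pairs in post-DFS order.
    get_copy_attrs: returns (src_dev, dst_dev) if node is a device_copy, else None.
    """
    device_map = {}
    cur_dev = -1
    out_dev = -1
    seen_copy = False
    # Pass 1 (bottom-up): walk the order in reverse; copies anchor the device types.
    for node, has_copy in reversed(post_dfs_order):
        attrs = get_copy_attrs(node)
        if attrs is not None:
            src_dev, dst_dev = attrs
            cur_dev = src_dev
            if out_dev == -1:
                out_dev = dst_dev        # dst of the last copy is the output device
            if has_copy:
                device_map[node] = dst_dev
            seen_copy = True
        elif seen_copy and has_copy:
            device_map[node] = cur_dev   # nodes feeding a copy inherit its source device
    # Pass 2 (fill): nodes on copy-free paths fall back to the output device.
    for node, has_copy in post_dfs_order:
        if not has_copy:
            device_map[node] = out_dev
    return device_map

# Toy usage: "a" and "b" feed a copy "c" (device 1 -> 2); "d" sits on a copy-free path.
order = [("a", True), ("b", True), ("c", True), ("d", False)]
copies = {"c": (1, 2)}
assert propagate_device_ids(order, copies.get) == {"a": 1, "b": 1, "c": 2, "d": 2}
```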
13 changes: 0 additions & 13 deletions src/tir/transforms/lower_tvm_builtin.cc
@@ -87,19 +87,6 @@ class BuiltinLower : public StmtExprMutator {
op = stmt.as<AllocateNode>();
// Get constant allocation bound.
int64_t nbytes = GetVectorBytes(op->dtype);
// NOTE(zhanghao): remove special handling for kDLCPU;
// otherwise, it may cause an LLVM parameter-match error
// with heterogeneous targets
// if (device_type_.defined()) {
// if (arith::GetConst(device_type_, &dev_type)) {
// if (dev_type == kDLCPU) {
// int32_t constant_size = op->constant_allocation_size();
// if (constant_size > 0 && constant_size * nbytes < runtime::kMaxStackAlloca) {
// return stmt;
// }
// }
// }
// }
PrimExpr total_bytes = make_const(op->extents[0].dtype(), nbytes);
for (size_t i = 0; i < op->extents.size(); ++i) {
total_bytes = total_bytes * op->extents[i];
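With the kDLCPU shortcut removed, every allocation in this pass goes through the generic path right after the commented-out block: the total size is the element size in bytes multiplied by each extent in turn. A tiny sketch of that arithmetic (plain Python, not the TIR expressions used in the pass):

```python
def total_alloc_bytes(extents, elem_bytes):
    """Total allocation size: element size times the product of all extents."""
    total = elem_bytes
    for extent in extents:
        total *= extent
    return total

# A 16 x 16 float32 buffer: 16 * 16 * 4 = 1024 bytes.
assert total_alloc_bytes([16, 16], 4) == 1024
```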
140 changes: 5 additions & 135 deletions vta/python/vta/top/graphpack.py
@@ -93,7 +93,7 @@ def _weight_shape_match_transpose(data, dshape, channels, cfactor_out):
if pad_width != 0:
pad_width = cfactor_out - pad_width
data = op.nn.pad(data, [[0, 0], [0, pad_width], [0, 0], [0, 0]])
dshape = tuple([dshape[0]] + [dshape[1] + pad_width, dshape[2], dshape[3]])
dshape = tuple(dshape[0], [dshape[1] + pad_width, dshape[2], dshape[3]])

if channels_pad != 0:
channels = channels + (cfactor_out - channels_pad)
@@ -174,104 +174,6 @@ def _operator_idx_inc(expr, count_meta, operator_current_idx):
operator_current_idx = operator_current_idx + 1
return operator_current_idx


class ExprDeviceAnnot(ExprMutator):
"""Visitor to perform graph annotation on an AST.
Parameters
----------
start: int
the start location to mark run on vta (inclusive)
end: int
the end location to mark run on vta (exclusive)
Returns
---------
None
"""
def __init__(self, start=-1, end=-1):
self.ext_ctx = tvm.context("ext_dev")
self.cpu_ctx = tvm.context("cpu")
self.cast = op.op.get("cast")
self.counter = -1
self.start = start
self.end = end
super().__init__()

def visit_call(self, call):
""" Visit the children. """
# First visit the children.
oshape = _get_tensor_shape(call)
odtype = _get_tensor_type(call)
input_types = [arg.checked_type for arg in call.args]
args = [self.visit(arg) for arg in call.args]

self.counter += 1
if self.counter == self.start:
ret = relay.Call(call.op, args, call.attrs)
ret = relay.annotation.on_device(ret, self.ext_ctx)
return ret
elif self.counter == self.end:
ret = relay.Call(call.op, args, call.attrs)
ret = relay.annotation.on_device(ret, self.cpu_ctx)
return ret
elif self.counter > self.start and self.counter < self.end:
ret = relay.Call(call.op, args, call.attrs)

# skip the float op, i.e., float->int cast
if self.is_float_op(call):
return ret

return relay.annotation.on_device(ret, self.ext_ctx)

return relay.Call(self.visit(call.op), args, call.attrs)

def is_float_op(self, call):
"""check if this op belongs to a float op
in general, float op's odtype is float;
a special case is float->int cast, which follows this op sequence:
multiply(float) -> round(float) -> clip(float) -> cast(int);
"""
args = call.args
odtype = _get_tensor_type(call)
op = call.op

if odtype == "float32":
return True
elif op == self.cast:
idtype = _get_tensor_type(args[0])
if idtype == "float32":
return True

return False


class ExprLocater(ExprMutator):
"""Visitor to locate op on an AST.
"""
def __init__(self):
self.counter = -1
self.op2nodes = {}
super().__init__()

def visit_call(self, call):
""" Visit the children. """
# First visit the children.
args = [self.visit(arg) for arg in call.args]

odtype = _get_tensor_type(call)
self.counter += 1
if (call.op, odtype) in self.op2nodes:
self.op2nodes[(call.op, odtype)].append(self.counter)
else:
self.op2nodes[(call.op, odtype)] = [self.counter]

return relay.Call(
self.visit(call.op),
args,
call.attrs)


class ExprPack(ExprMutator):
"""Visitor to perform graph packing on an AST.
"""
@@ -415,7 +317,7 @@ def visit_call(self, call):
elif self.start_pack and call.op == op.op.get('cast') and \
input_types[0].dtype == 'int32':
cast = relay.Call(op.op.get('cast'), [args[0]], call.attrs)
return cast
return relay.Call(op.op.get('copy'), [cast])
elif call.op == self.pad:
pad_width = call.attrs.pad_width
if len(pad_width) == 6:
@@ -510,10 +412,7 @@ def graph_pack(expr,
stop_name="nn.global_avg_pool2d",
start_name_idx=None,
stop_name_idx=None,
count_meta=False,
device_annot=False,
annot_start_name="nn.conv2d",
annot_end_name="annotation.stop_fusion"):
count_meta=False):
"""Pack the graph into batch&channel packed format.
Parameters
@@ -550,47 +449,18 @@ def graph_pack(expr,
'expr.astext(show_meta_data=False)'. When count_meta is True, the operator increase
logic would count the meta.
device_annot: boolean, optional
if we want to annotate the device_type
annot_start_name: str, optional
device annotation start node, from which we mark the nodes as `ext_dev`
annot_end_name: str, optional
device annotation end node, after which we mark the nodes as 'cpu'
Returns
-------
expr : Expr
The transformed expression.
"""
assert isinstance(expr, relay.Function)
assert ((start_name != stop_name) or (start_name_idx is None != stop_name_idx is None) or \
(not (start_name_idx is None and stop_name_idx is None)) or (start_name_idx < stop_name_idx))
assert ((start_name != stop_name) or (start_name_idx < stop_name_idx))
expr = get_subgraph(expr, start_name, stop_name, start_name_idx, stop_name_idx, count_meta)
expr = run_opt_pass(expr, transform.InferType())
packer = ExprPack(
bfactor, cfactor,
weight_bits)
expr = packer.visit(expr)
assert not packer.start_pack
expr = run_opt_pass(expr, transform.InferType())

if device_annot:
expr_locator = ExprLocater()
expr_locator.visit(expr)

annot_start = op.op.get(annot_start_name)
start = expr_locator.op2nodes[(annot_start, "int32")][0]

annot_end = op.op.get(annot_end_name)
# we mark the next op to the last stop_fusion on cpu device
end = expr_locator.op2nodes[(annot_end, "int8")][-1] + 1

device_annot = ExprDeviceAnnot(start=start, end=end)
expr = device_annot.visit(expr)
ret = run_opt_pass(expr, transform.InferType())

return ret
else:
return expr
return run_opt_pass(expr, transform.InferType())
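Among the code deleted above, `ExprDeviceAnnot.is_float_op` documented the rule for keeping float operations on the CPU: an op counts as float when its output dtype is `float32`, or when it is the `cast` at the end of a float-to-int sequence (multiply -> round -> clip -> cast) and its input dtype is `float32`. A minimal sketch of that classification, using plain strings in place of the Relay type queries (the helper below is illustrative, not the deleted implementation):

```python
def is_float_op(op_name, output_dtype, input_dtypes):
    """Sketch of the float-op rule from the removed ExprDeviceAnnot.is_float_op."""
    # Ops producing float32 stay on the CPU.
    if output_dtype == "float32":
        return True
    # So does the final cast of a float->int chain: multiply -> round -> clip -> cast.
    if op_name == "cast" and input_dtypes and input_dtypes[0] == "float32":
        return True
    return False

# The int cast that closes a quantization chain is still treated as a float op.
assert is_float_op("cast", "int32", ["float32"])
assert not is_float_op("nn.conv2d", "int8", ["int8", "int8"])
```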
1 change: 1 addition & 0 deletions vta/runtime/runtime.cc
@@ -526,6 +526,7 @@ class UopQueue : public BaseQueue<VTAUop> {
kernel->sram_begin_ = 0;
kernel->sram_end_ = 0;
}

cache_.clear();
cache_idx_ = 0;
BaseQueue<VTAUop>::Reset();
5 changes: 2 additions & 3 deletions vta/tutorials/autotvm/tune_alu_vta.py
@@ -42,7 +42,7 @@
# Compile network
# ---------------
# Perform vta-specific compilation with Relay from a Gluon model
def compile_network(env, target, model, start_pack, stop_pack, device_annot=False):
def compile_network(env, target, model, start_pack, stop_pack):

# Populate the shape and data type dictionary
dtype_dict = {"data": 'float32'}
@@ -70,8 +70,7 @@ def compile_network(env, target, model, start_pack, stop_pack, device_annot=Fals
env.BLOCK_OUT,
env.WGT_WIDTH,
start_name=start_pack,
stop_name=stop_pack,
device_annot=device_annot)
stop_name=stop_pack)

return relay_prog, params

