
Commit 91d25c1
fix bugs
1 parent da924be

18 files changed: +68 -91 lines changed

python/tvm/autotvm/graph_tuner/utils/traverse_graph.py

Lines changed: 1 addition & 1 deletion
@@ -83,7 +83,7 @@ def _traverse_expr(node):
             return
         node_index = len(node_list)
         node_entry = {"node": node, "inputs": [], "types": [],
-                      "op": "null", "name": None}
+                      "op": None, "name": None}

         if isinstance(node, Call):
             op = node.op

python/tvm/autotvm/graph_tuner/utils/utils.py

Lines changed: 6 additions & 7 deletions
@@ -46,7 +46,7 @@ def has_multiple_inputs(node_list, node_idx, input_names):
         in_idx = in_idx[0]
         in_node = node_list[in_idx]
         # Exclude parameter nodes
-        if in_node["op"] != "null" or \
+        if in_node["op"] is not None or \
                 ("name" in in_node and in_node["name"] in input_names):
             num_inputs += 1
     return num_inputs > 1

@@ -71,9 +71,10 @@ def is_boundary_node(node_entry, input_names):
         whether node is a boundary node.
     """
     # Operators dependent on original layouts.
-    _LAYOUT_FIXED_OP = ["batch_flatten", "transpose", "reshape",
-                        "multibox_prior", "multibox_transform_loc", "where",
-                        "non_max_suppression", "strided_slice"]
+    _LAYOUT_FIXED_OP = [relay.op.get(name) for name in (
+        "nn.batch_flatten", "transpose", "reshape", "vision.multibox_prior",
+        "vision.multibox_transform_loc", "where", "vision.non_max_suppression",
+        "strided_slice")]

     out = node_entry["op"] in _LAYOUT_FIXED_OP or \
           ("name" in node_entry and node_entry["name"] in input_names)

@@ -94,9 +95,7 @@ def is_skipped_node(node_entry):
         whether node is skipped.
     """
     # Operators not counted in graph tuner.
-    _SKIPPED_OP = ["Tuple"]
-
-    return node_entry["op"] in _SKIPPED_OP
+    return isinstance(node_entry["node"], relay.Tuple)


 def bind_inputs(expr, input_shapes=None, input_dtypes="float32"):
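Taken together, the two hunks above change the graph-tuner node convention: node_entry["op"] now stores the relay Op object itself (or None for non-call nodes) instead of the string "null", so membership tests compare Op objects looked up with relay.op.get(). A minimal sketch of the new convention (assumes a local TVM build with relay importable; the node_entry below is hand-made for illustration):

    from tvm import relay

    # Layout-fixed ops are now Op objects, not strings.
    _LAYOUT_FIXED_OP = [relay.op.get(name) for name in (
        "nn.batch_flatten", "transpose", "reshape")]

    # A hand-made node entry standing in for what traverse_graph.py builds.
    node_entry = {"node": None, "op": relay.op.get("transpose"), "name": None}

    is_param = node_entry["op"] is None            # parameter/var nodes keep op=None
    is_boundary = node_entry["op"] in _LAYOUT_FIXED_OP
    print(is_param, is_boundary)                   # False True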

python/tvm/autotvm/record.py

Lines changed: 2 additions & 2 deletions
@@ -161,12 +161,12 @@ def clean_json_to_python(x):
         tgt = _target.create(items[0])
         task_tuple = pickle.loads(base64.b64decode(items[1].encode()))
         config = pickle.loads(base64.b64decode(items[2].encode()))
-        result = pickle.loads(base64.b64decode(items[3].encode()))
+        result = MeasureResult(*pickle.loads(base64.b64decode(items[3].encode())))
         config.cost = np.mean(result.costs)

         tsk = task.Task(task_tuple[0], task_tuple[1])
         tsk.workload = task_tuple[3]
-        return MeasureInput(tgt, tsk, config), MeasureResult(*result)
+        return MeasureInput(tgt, tsk, config), result

     raise RuntimeError("Invalid log protocol: " + protocol)
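The ordering matters here because config.cost reads result.costs before the old code had wrapped the unpickled tuple. A small stand-alone sketch of the failure mode and the fix; the MeasureResult field layout follows tvm.autotvm.measure and should be treated as an assumption:

    from collections import namedtuple
    import numpy as np

    # Assumed field layout of autotvm's MeasureResult.
    MeasureResult = namedtuple("MeasureResult", ["costs", "error_no", "all_cost", "timestamp"])

    raw = ((0.012, 0.011), 0, 0.5, 1580000000.0)   # what pickle.loads() hands back: a plain tuple
    # raw.costs would raise AttributeError, which appears to be what the old line tripped over.
    result = MeasureResult(*raw)                    # wrap first, then read .costs
    print(np.mean(result.costs))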

python/tvm/autotvm/task/dispatcher.py

Lines changed: 0 additions & 3 deletions
@@ -33,9 +33,6 @@
 import logging

 import numpy as np
-from decorator import decorate
-
-from tvm import target as _target

 from .space import FallbackConfigEntity

python/tvm/autotvm/task/task.py

Lines changed: 2 additions & 1 deletion
@@ -418,7 +418,8 @@ def get_config():
     cfg: ConfigSpace or ConfigEntity
         The current config
     """
-    return DispatchContext.current.query(None, None)
+    tgt = _target.current_target(allow_none=True)
+    return DispatchContext.current.query(tgt, None)

 class FlopCalculationError(RuntimeError):
     """Error happens when estimating FLOP for a compute op"""

python/tvm/relay/backend/compile_engine.py

Lines changed: 6 additions & 5 deletions
@@ -256,7 +256,7 @@ def create_tensors(typ, tensors):
         self.func_name = "fused"
         outputs = self.visit(prim_func.body)
         if len(self.func_name) > ScheduleGetter.MAX_FUNC_NAME_LENGTH:
-            hash_digest = int(hashlib.sha1(self.func_name).hexdigest(), 16)
+            hash_digest = int(hashlib.sha1(self.func_name.encode("utf-8")).hexdigest(), 16)
             self.func_name = "%s_%s" % (
                 self.func_name[:ScheduleGetter.MAX_FUNC_NAME_LENGTH], hash_digest)

@@ -270,7 +270,8 @@ def create_tensors(typ, tensors):
         # print('master op:', self.master_op.name)
         sch = self.master_implement.schedule(self.master_attrs, tensor_outs, self.target)
         for scalar in self.scalars:
-            sch[scalar].compute_inline()
+            if scalar in sch.stage_map:
+                sch[scalar].compute_inline()
         return CachedFunc(self.target, self.func_name, inputs, outputs, sch)

     def visit_var(self, var):

@@ -381,10 +382,10 @@ def visit_tuple(self, tup):
         return fields

     def visit_tuple_getitem(self, t):
-        tup = self.visit(t.tuple)
-        assert len(tup) == len(t.tuple.checked_type.fields)
+        tup = self.visit(t.tuple_value)
+        assert len(tup) == len(t.tuple_value.checked_type.fields)
         assert t.index >= 0
-        assert t.index < tup.size()
+        assert t.index < len(tup)
         return [tup[t.index]]
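The first hunk is a Python 3 issue: hashlib.sha1 accepts bytes, not str, so the long fused function name has to be encoded before hashing. A small self-contained sketch; the 80-character cap stands in for ScheduleGetter.MAX_FUNC_NAME_LENGTH and is an assumption:

    import hashlib

    name = "fused_nn_conv2d_add_nn_relu" * 5          # pretend this exceeded the cap
    MAX_FUNC_NAME_LENGTH = 80                          # assumed value of the class constant
    if len(name) > MAX_FUNC_NAME_LENGTH:
        digest = int(hashlib.sha1(name.encode("utf-8")).hexdigest(), 16)
        name = "%s_%s" % (name[:MAX_FUNC_NAME_LENGTH], digest)
    print(name)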

python/tvm/relay/frontend/tensorflow.py

Lines changed: 3 additions & 1 deletion
@@ -310,6 +310,7 @@ def _impl(inputs, attr, params):
             flip_layout = True

         if attr['data_format'] == 'NHWC':
+            in_channels = input_shape[3]
             kernel_h, kernel_w, _, depth_mult = weights_shape
             attr['kernel_shape'] = (weights_shape[0], weights_shape[1])
             if opname == 'conv':

@@ -323,6 +324,7 @@ def _impl(inputs, attr, params):
             attr['dilations'] = (attr['dilations'][1], attr['dilations'][2])
             attr['strides'] = (attr['strides'][1], attr['strides'][2])
         elif attr['data_format'] == 'NCHW':
+            in_channels = input_shape[1]
             _, depth_mult, kernel_h, kernel_w = weights_shape
             attr['kernel_shape'] = (weights_shape[2], weights_shape[3])
             if opname == 'conv':

@@ -343,7 +345,7 @@ def _impl(inputs, attr, params):
             raise tvm.error.OpAttributeInvalid(msg.format(attr['data_format']))

         if opname == 'depthwise':
-            attr['groups'] = attr['channels']
+            attr['groups'] = in_channels

         # Fix padding
         attr['padding'] = attr['padding'].decode("utf-8")
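The depthwise fix boils down to simple shape arithmetic: TensorFlow depthwise kernels are laid out as (kernel_h, kernel_w, in_channels, depth_mult), so the converter's channel count ends up as in_channels * depth_mult, while relay's grouped conv2d needs groups equal to in_channels. A hedged sketch with made-up shapes:

    input_shape = (1, 224, 224, 32)        # NHWC input
    weights_shape = (3, 3, 32, 2)          # kernel_h, kernel_w, in_channels, depth_mult

    in_channels = input_shape[3]
    kernel_h, kernel_w, _, depth_mult = weights_shape
    channels = in_channels * depth_mult    # 64: what attr['channels'] held (the old groups value)
    groups = in_channels                   # 32: what the fix assigns to attr['groups']
    print(channels, groups)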

python/tvm/relay/op/nn/_nn.py

Lines changed: 1 addition & 1 deletion
@@ -307,7 +307,7 @@ def compute_upsampling3d(attrs, inputs, out_dtype):

 # mirror_pad
 @reg.register_compute("nn.mirror_pad")
-def compute_mirror_pad(attrs, inputs, out_dtype, target):
+def compute_mirror_pad(attrs, inputs, out_dtype):
     pad_before, pad_after = list(zip(*attrs.pad_width))
     mode = attrs.mode
     out = topi.nn.mirror_pad(inputs[0], pad_before=pad_before, pad_after=pad_after, mode=mode)

python/tvm/relay/op/strategy/x86.py

Lines changed: 2 additions & 1 deletion
@@ -116,7 +116,8 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target):
                 wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nchw))
         elif layout == "NHWC":
             assert kernel_layout == "HWOI"
-            logger.warning("For x86 target, NCHW layout is recommended for depthwise_conv2d.")
+            logger.warning("For x86 target, depthwise_conv2d with NCHW layout is "
+                           "not optimized.")
             strategy.add_implement(
                 wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
                 wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nhwc))

src/relay/op/nn/convolution.h

Lines changed: 12 additions & 2 deletions
@@ -153,6 +153,16 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
       << " But got " << out_layout;

   Array<IndexExpr> dshape_nchw = trans_in_layout.ForwardShape(data->shape);
+  bool is_depthwise = false;
+  if (param->groups > 1) {
+    CHECK(weight->shape.defined()) << "Weight shape must be specified " <<
+        "when groups is greater than 1.";
+    Array<IndexExpr> wshape_oihw = trans_kernel_layout.ForwardShape(weight->shape);
+    if (tvm::tir::Equal(param->groups, dshape_nchw[1]) &&
+        tvm::tir::Equal(param->groups, wshape_oihw[0])) {
+      is_depthwise = true;
+    }
+  }

   IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
   // infer weight if the kernel_size and channels are defined

@@ -161,9 +171,9 @@ bool Conv2DRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     CHECK_EQ(param->dilation.size(), 2);
     Array<IndexExpr> wshape;

-    if (tvm::tir::Equal(param->channels, param->groups) && !tvm::tir::Equal(param->channels, 1)) {
+    if (is_depthwise) {
       // infer weight's shape for depthwise convolution
-      wshape = {{dshape_nchw[1], indexdiv(param->groups, dshape_nchw[1]), param->kernel_size[0],
+      wshape = {{dshape_nchw[1], indexdiv(param->channels, dshape_nchw[1]), param->kernel_size[0],
                  param->kernel_size[1]}};
     } else {
       wshape = {{param->channels, indexdiv(dshape_nchw[1], param->groups), param->kernel_size[0],
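For the second hunk, the fix replaces param->groups with param->channels when inferring the depthwise weight shape. In plain Python the inferred OIHW shape works out as below; this is a sketch of the relation only, with a hypothetical helper name, not the actual C++:

    # Hypothetical helper mirroring the fixed depthwise branch of Conv2DRel.
    def infer_depthwise_wshape(in_channels, channels, kernel_size):
        kh, kw = kernel_size
        # channels // in_channels is the depth multiplier; the old code divided groups instead.
        return (in_channels, channels // in_channels, kh, kw)

    print(infer_depthwise_wshape(32, 64, (3, 3)))   # (32, 2, 3, 3)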
