Skip to content

Commit bea1dcb

Browse files
committed
[NNVM, TOPI] Bug fixes (apache#24)
* Bug fix.
* Pass the parameters when building the NNVM graph before extracting the tasks in AutoTVM.
* Bug fix for operator fusion.
* Fix the integration and tuning scripts.
1 parent c9e4def commit bea1dcb

File tree

6 files changed

+26
-22
lines changed

6 files changed

+26
-22
lines changed

nnvm/python/nnvm/top/nn.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -98,25 +98,25 @@ def compute_conv2d(attrs, inputs, _):
9898
if groups == 1 and layout == 'NCHW4c' and inputs[0].dtype == 'int8':
9999
# pylint: disable=assignment-from-no-return
100100
out = topi.nn.conv2d(inputs[0], inputs[1], strides, padding,
101-
dilation, layout, out_dtype=out_dtype)
101+
dilation, layout, out_dtype)
102102
# pylint: enable=assignment-from-no-return
103103
elif groups == 1:
104104
out = topi.nn.conv2d(
105-
inputs[0], inputs[1], strides, padding, dilation, layout, out_dtype=out_dtype)
105+
inputs[0], inputs[1], strides, padding, dilation, layout, out_dtype)
106106
elif layout == "NCHW" and \
107107
groups == get_const_int(inputs[0].shape[1]) and \
108108
groups == channels:
109109
out = topi.nn.depthwise_conv2d_nchw(
110-
inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype)
110+
inputs[0], inputs[1], strides, padding, dilation, out_dtype)
111111
elif layout in ["NCHW", "NCHW4c"]:
112112
out = topi.nn.group_conv2d_nchw(inputs[0], inputs[1], strides, padding, dilation, groups,
113-
out_dtype=out_dtype)
113+
out_dtype)
114114
elif layout == "NHWC" and \
115115
kernel_layout == "HWOI" and \
116116
groups == get_const_int(inputs[0].shape[3]) and \
117117
groups == channels:
118118
out = topi.nn.depthwise_conv2d_nhwc(
119-
inputs[0], inputs[1], strides, padding, dilation, out_dtype=out_dtype)
119+
inputs[0], inputs[1], strides, padding, dilation, out_dtype)
120120
else:
121121
raise ValueError("not support arbitrary group number for now")
122122

python/tvm/autotvm/task/nnvm_integration.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
logger = logging.getLogger('autotvm')
1717

1818

19-
def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
19+
def extract_from_graph(graph, shape, dtype, target, symbols, params, target_host=None):
2020
""" Extract tuning tasks from a nnvm graph.
2121
2222
This function collects tuning tasks by building the graph and trace all the calls to topi.
@@ -33,6 +33,8 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
3333
The compilation target
3434
symbols : Array of nnvm.symbol
3535
Array of nnvm symbols want to be tuned
36+
params : dict of str to NDArray
37+
The parameter dictionary.
3638
target_host: tvm.target.Target
3739
The host compilation target
3840
@@ -66,7 +68,8 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
6668
# run compiler to collect all TOPI calls during compilation
6769
nnvm.compiler.engine.clear_cache()
6870
with ApplyHistoryBest([]):
69-
nnvm.compiler.build(graph, target=target, shape=shape, dtype=dtype)
71+
nnvm.compiler.build(graph, target=target, shape=shape, dtype=dtype,
72+
target_host=target_host, params=params)
7073
nnvm.compiler.engine.clear_cache()
7174

7275
logger.disabled = old_state
@@ -80,7 +83,7 @@ def extract_from_graph(graph, shape, dtype, target, symbols, target_host=None):
8083
template_key='direct')
8184
tasks.append(tsk)
8285
except topi.InvalidShapeError:
83-
print("shape error")
86+
print("[Warning] invalid shape")
8487

8588
return tasks
8689

vta/python/vta/top/vta_conv2d.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ def schedule_packed_conv2d(cfg, outs,
190190
ewise_inputs = []
191191
ewise_ops = []
192192
conv2d_res = []
193-
assert output.op.input_tensors[0].dtype == "int32"
193+
assert "int" in output.op.input_tensors[0].dtype
194194

195195
def _traverse(op):
196196
if topi.tag.is_broadcast(op.tag):

vta/scripts/tune_conv.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def my_clip(x, a_min, a_max):
2020
x = tvm.compute(x.shape, lambda *i: tvm.max(x(*i), const_min), name="clipB")
2121
return x
2222

23-
def conv2d(N, CI, H, W, CO, KH, KW, strides, padding, in_dtype, out_dtype):
23+
def conv2d(N, CI, H, W, CO, KH, KW, strides, padding, dilation, in_dtype, out_dtype):
2424
data_shape = (N//env.BATCH, CI//env.BLOCK_IN, H, W, env.BATCH, env.BLOCK_IN)
2525
kernel_shape = (CO//env.BLOCK_OUT, CI//env.BLOCK_IN, KH, KW, env.BLOCK_OUT, env.BLOCK_IN)
2626
bias_shape = (N//env.BATCH, CO//env.BLOCK_OUT, 1, 1, env.BATCH, env.BLOCK_OUT)
@@ -33,7 +33,7 @@ def conv2d(N, CI, H, W, CO, KH, KW, strides, padding, in_dtype, out_dtype):
3333
kernel = tvm.placeholder(kernel_shape, name="kernel", dtype=env.wgt_dtype)
3434

3535
with tvm.target.vta():
36-
res = topi.nn.conv2d(data, kernel, padding=padding, strides=strides,
36+
res = topi.nn.conv2d(data, kernel, padding=padding, strides=strides, dilation=dilation,
3737
layout='NCHW%dn%dc' % (env.BATCH, env.BLOCK_IN), out_dtype='int32')
3838
res = topi.add(res, bias)
3939
res = topi.right_shift(res, 8)
@@ -46,13 +46,13 @@ def conv2d(N, CI, H, W, CO, KH, KW, strides, padding, in_dtype, out_dtype):
4646
s = tvm.create_schedule([res.op])
4747

4848

49-
return s, [data, kernel, bias, res]
49+
return s, [data, kernel, bias, res]
5050

5151
if __name__ == '__main__':
52-
N, CI, H, W, CO, KH, KW, strides, padding, in_dtype, out_dtype = \
53-
1, 64, 56, 56, 64, 3, 3, (1, 1), (1, 1), 'int8', 'int32'
52+
N, CI, H, W, CO, KH, KW, strides, padding, dilation, in_dtype, out_dtype = \
53+
1, 64, 56, 56, 64, 3, 3, (1, 1), (1, 1), (1, 1), 'int8', 'int32'
5454

55-
task = autotvm.task.create(conv2d, args=(N, CI, H, W, CO, KH, KW, strides, padding, in_dtype, out_dtype),
55+
task = autotvm.task.create(conv2d, args=(N, CI, H, W, CO, KH, KW, strides, padding, dilation, in_dtype, out_dtype),
5656
target=tvm.target.vta(env.MODEL), target_host=env.target_host, template_key='direct')
5757
print(task.config_space)
5858

@@ -62,7 +62,7 @@ def conv2d(N, CI, H, W, CO, KH, KW, strides, padding, in_dtype, out_dtype):
6262

6363
measure_option = autotvm.measure_option(
6464
builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func),
65-
runner=autotvm.RPCRunner(env.TARGET, 'fleet', 9190, number=4, repeat=3, timeout=30,
65+
runner=autotvm.RPCRunner(env.TARGET, '10.77.1.109', 9190, number=4, repeat=3, timeout=30,
6666
check_correctness=True))
6767

6868
tuner = autotvm.tuner.RandomTuner(task)

vta/scripts/tune_resnet.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,9 @@ def extract_tasks(sym, params, target, target_host):
8989
sym = vta.graph.pack(sym, shape_dict, env.BATCH, env.BLOCK_OUT)
9090

9191
with vta.build_config():
92-
tasks = autotvm.task.extract_from_graph(sym, target=target, target_host=target_host,
93-
shape=shape_dict, dtype=dtype_dict, symbols=(nnvm.sym.conv2d,))
92+
tasks = autotvm.task.extract_from_graph(sym, shape=shape_dict, dtype=dtype_dict, target=target,
93+
params=params, symbols=(nnvm.sym.conv2d,), target_host=target_host,
94+
)
9495
return tasks
9596

9697

@@ -169,7 +170,7 @@ def tune_tasks(tasks,
169170

170171
'measure_option': autotvm.measure_option(
171172
builder=autotvm.LocalBuilder(build_func=vta.vta_autotvm_build_func),
172-
runner=autotvm.RPCRunner(env.TARGET, 'fleet', 9190,
173+
runner=autotvm.RPCRunner(env.TARGET, '10.77.1.109', 9190,
173174
number=4, repeat=3, timeout=60,
174175
check_correctness=True))
175176
}
@@ -202,7 +203,7 @@ def tune_tasks(tasks,
202203

203204
# upload module to device
204205
print("Upload...")
205-
remote = autotvm.measure.request_remote(env.TARGET, 'fleet', 9190, timeout=10000)
206+
remote = autotvm.measure.request_remote(env.TARGET, '10.77.1.109', 9190, timeout=10000)
206207
remote.upload(tmp.relpath(filename))
207208
rlib = remote.load_module(filename)
208209

vta/tests/python/integration/test_benchmark_topi_conv2d.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ def run_cpu_conv2d(env, remote, wl, target):
4747

4848
with target:
4949
res_conv = topi.nn.conv2d(
50-
data, kernel, (wl.hstride, wl.wstride), (wl.hpad, wl.wpad), "NCHW", "int32")
50+
data, kernel, (wl.hstride, wl.wstride), (wl.hpad, wl.wpad), (1, 1), "NCHW", "int32")
5151
res = topi.right_shift(res_conv, 8)
5252
res = my_clip(res, 0, 127)
5353
res = topi.cast(res, "int8")
@@ -202,7 +202,7 @@ def run_vta_conv2d(env, remote, wl, target, check_correctness=True, print_ir=Fal
202202

203203
with target:
204204
res_conv = topi.nn.conv2d(
205-
data, kernel, (wl.hstride, wl.wstride), (wl.hpad, wl.wpad),
205+
data, kernel, (wl.hstride, wl.wstride), (wl.hpad, wl.wpad), (1, 1),
206206
"NCHW%dn%dc" % (env.BATCH, env.BLOCK_IN), 'int32')
207207
res = topi.right_shift(res_conv, 8)
208208
res = topi.add(res, bias)

0 commit comments

Comments (0)