
Commit 81f9d5b

ZihengJiang authored and tqchen committed
[TOP] Initial Schedule of MobileNet on Rasp (#496)
* [TOP] Initial Schedule of MobileNet on Rasp
* Fix
* Fix
1 parent 5b8a8d0 commit 81f9d5b

File tree: 4 files changed, +88 -1 lines changed


topi/python/topi/nn/conv2d.py

Lines changed: 12 additions & 1 deletion
@@ -21,8 +21,8 @@
 Im2ColPack = namedtuple('Im2ColPack',
                         ['vp', 'vq', 'ba', 'bc', 'unroll'])

-# workloads of resnet18 on imagenet
 _WORKLOADS = [
+    # workloads of resnet18 on imagenet
     Workload(224, 224, 3, 64, 7, 7, 3, 3, 2, 2),
     Workload(56, 56, 64, 64, 3, 3, 1, 1, 1, 1),
     Workload(56, 56, 64, 64, 1, 1, 0, 0, 1, 1),
@@ -35,6 +35,17 @@
     Workload(14, 14, 256, 512, 3, 3, 1, 1, 2, 2),
     Workload(14, 14, 256, 512, 1, 1, 0, 0, 2, 2),
     Workload(7, 7, 512, 512, 3, 3, 1, 1, 1, 1),
+    # workloads of mobile net on imagenet
+    Workload(224, 224, 3, 32, 3, 3, 1, 1, 2, 2),
+    Workload(112, 112, 32, 64, 1, 1, 0, 0, 1, 1),
+    Workload(56, 56, 64, 128, 1, 1, 0, 0, 1, 1),
+    Workload(56, 56, 128, 128, 1, 1, 0, 0, 1, 1),
+    Workload(28, 28, 128, 256, 1, 1, 0, 0, 1, 1),
+    Workload(28, 28, 256, 256, 1, 1, 0, 0, 1, 1),
+    Workload(14, 14, 256, 512, 1, 1, 0, 0, 1, 1),
+    Workload(14, 14, 512, 512, 1, 1, 0, 0, 1, 1),
+    Workload(7, 7, 512, 1024, 1, 1, 0, 0, 1, 1),
+    Workload(7, 7, 1024, 1024, 1, 1, 0, 0, 1, 1),
 ]

 # platform specific schedule
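The ten new entries cover the distinct conv2d shapes MobileNet uses on ImageNet: the initial 3x3 stride-2 stem plus the 1x1 pointwise convolutions at each feature-map size. As a point of reference, the sketch below shows how one such entry is built; the Workload field names are an assumption based on the argument order used here (height, width, in_filter, out_filter, hkernel, wkernel, hpad, wpad, hstride, wstride), not a quotation of the actual definition.

from collections import namedtuple

# Assumed field layout for the Workload namedtuple defined earlier in
# topi/python/topi/nn/conv2d.py; the positional order matches the diff.
Workload = namedtuple('Workload',
                      ['height', 'width', 'in_filter', 'out_filter',
                       'hkernel', 'wkernel', 'hpad', 'wpad',
                       'hstride', 'wstride'])

# MobileNet's first pointwise convolution: 112x112 feature map,
# 32 -> 64 channels, 1x1 kernel, no padding, stride 1.
wkl = Workload(112, 112, 32, 64, 1, 1, 0, 0, 1, 1)
assert (wkl.hkernel, wkl.wkernel) == (1, 1) and wkl.hstride == 1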

topi/python/topi/rasp/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -3,3 +3,4 @@
 from __future__ import absolute_import as _abs

 from .conv2d import *
+from .depthwise_conv2d import *
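With this one-line re-export, the new schedule becomes reachable from the topi.rasp namespace without importing the submodule directly; a minimal check, assuming topi is on the Python path:

from topi import rasp

# schedule_depthwise_conv2d is pulled in by the wildcard import above
assert hasattr(rasp, 'schedule_depthwise_conv2d')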

topi/python/topi/rasp/conv2d.py

Lines changed: 11 additions & 0 deletions
@@ -23,6 +23,17 @@
     Im2ColPack(7, 4, 1, 16, True),
     Im2ColPack(7, 4, 1, 8, False),
     Im2ColPack(7, 4, 1, 16, False),
+
+    SpatialPack(2, 2, 4, 28, 1, True),
+    SpatialPack(1, 4, 8, 14, 1, False),
+    SpatialPack(1, 2, 16, 8, 1, True),
+    SpatialPack(1, 4, 8, 8, 8, True),
+    SpatialPack(2, 2, 8, 1, 1, False),
+    SpatialPack(1, 4, 8, 4, 8, False),
+    SpatialPack(2, 2, 8, 1, 4, False),
+    SpatialPack(2, 2, 8, 1, 8, False),
+    SpatialPack(1, 1, 16, 1, 4, False),
+    SpatialPack(1, 1, 4, 1, 4, True),
 ]

 def _schedule_conv2d(wkl):
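These ten SpatialPack entries are appended in the same order as the ten MobileNet workloads added to _WORKLOADS above, which suggests _schedule_conv2d selects a tuned schedule by workload index. A minimal sketch of that lookup pattern follows; the SpatialPack field names (vh, vw, vc, ba, bc, unroll) and the helper are assumptions for illustration, not the file's actual _schedule_conv2d body.

from collections import namedtuple

# Assumed field layout for the spatial-pack schedule template; the six
# positional values match the entries added in this diff.
SpatialPack = namedtuple('SpatialPack', ['vh', 'vw', 'vc', 'ba', 'bc', 'unroll'])

def pick_schedule(wkl, workloads, schedules):
    """Hypothetical helper: map a conv2d workload to its tuned schedule."""
    if wkl not in workloads:
        raise ValueError("no tuned schedule for workload: " + str(wkl))
    return schedules[workloads.index(wkl)]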
topi/python/topi/rasp/depthwise_conv2d.py

Lines changed: 64 additions & 0 deletions
@@ -0,0 +1,64 @@
+# pylint: disable=invalid-name,unused-variable
+"""Schedule for depthwise_conv2d with auto fusion"""
+import tvm
+from .. import tag
+
+def _schedule(s, data, data_pad, kernel, output, last):
+    A, B, C = data, kernel, output
+    A0 = data_pad
+    C0 = last
+
+    _, c, h, w = s[C].op.axis
+    dh, dw = s[C].op.reduce_axis
+
+    oh, ow, ih, iw = s[C].tile(h, w, 2, 4)
+    s[C].reorder(oh, ow, dh, dw, ih, iw)
+    s[C].unroll(ih)
+    s[C].vectorize(iw)
+
+    s[C].parallel(c)
+    s[C].pragma(c, "parallel_launch_point")
+    s[C].pragma(c, "parallel_stride_pattern")
+    s[C].pragma(c, "parallel_barrier_when_finish")
+    return s
+
+
+
+def schedule_depthwise_conv2d(outs):
+    """Schedule for depthwise_conv2d nchw forward.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of depthwise_conv2d
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for depthwise_conv2d nchw.
+    """
+    outs = [outs] if isinstance(outs, tvm.tensor.Tensor) else outs
+    s = tvm.create_schedule([x.op for x in outs])
+
+    def traverse(op):
+        # inline all one-to-one-mapping operators except the last stage (output)
+        if tag.is_broadcast(op.tag):
+            if op not in s.outputs:
+                s[op].compute_inline()
+            for tensor in op.input_tensors:
+                if tensor.op.input_tensors:
+                    traverse(tensor.op)
+        # schedule depthwise_conv2d
+        if op.tag == 'depthwise_conv2d_nchw':
+            output = op.output(0)
+            kernel = op.input_tensors[1]
+            data = op.input_tensors[0]
+            data_pad = None
+            if isinstance(data.op, tvm.tensor.ComputeOp) and "pad" in data.op.tag:
+                data_pad = data
+                data = data_pad.op.input_tensors[0]
+            _schedule(s, data, data_pad, kernel, output, outs[0])
+
+    traverse(outs[0].op)
+    return s
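To illustrate how this schedule gets applied, the sketch below builds a depthwise convolution compute tagged 'depthwise_conv2d_nchw' directly with the TVM tensor DSL and runs it through schedule_depthwise_conv2d. The shapes, names, and plain 'llvm' target are illustrative assumptions; in practice the compute would come from topi.nn's depthwise_conv2d operator and the target would be an ARM cross-compilation string for the Raspberry Pi.

import tvm
from topi.rasp import schedule_depthwise_conv2d

# Illustrative shapes: 32 channels, 58x58 input, 3x3 depthwise kernel,
# stride 1 and no padding, giving a 56x56 output (so data_pad stays None).
data = tvm.placeholder((1, 32, 58, 58), name='data')
kernel = tvm.placeholder((32, 1, 3, 3), name='kernel')
dh = tvm.reduce_axis((0, 3), name='dh')
dw = tvm.reduce_axis((0, 3), name='dw')

# Tag the compute the same way TOPI tags its depthwise op so that
# traverse() recognizes and schedules it.
conv = tvm.compute(
    (1, 32, 56, 56),
    lambda n, c, h, w: tvm.sum(
        data[n, c, h + dh, w + dw] * kernel[c, 0, dh, dw],
        axis=[dh, dw]),
    name='depthwise_conv2d',
    tag='depthwise_conv2d_nchw')

s = schedule_depthwise_conv2d(conv)
# For the Raspberry Pi one would pass an ARM cross-compilation llvm target
# string here; plain 'llvm' keeps the sketch host-only.
lib = tvm.build(s, [data, kernel, conv], target='llvm')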
