Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 129 additions & 7 deletions topi/python/topi/rasp/depthwise_conv2d.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,147 @@
# pylint: disable=invalid-name,unused-variable
"""Schedule for depthwise_conv2d with auto fusion"""
from __future__ import absolute_import as _abs
from collections import namedtuple
import tvm
from .. import tag
from ..nn.util import infer_pad, infer_stride, get_pad_tuple


_Workload = namedtuple('Workload',
['height', 'width', 'channel', 'multiplier',
'hkernel', 'wkernel', 'hpad', 'wpad', 'hstride', 'wstride'])

_Schedule = namedtuple('Schedule', ['vh', 'vw', 'vc', 'bc', 'unroll'])

# workloads of depthwise conv mobile net on imagenet
_WORKLOADS = [
_Workload(112, 112, 32, 1, 3, 3, 1, 1, 1, 1),
_Workload(112, 112, 64, 1, 3, 3, 1, 1, 2, 2),
_Workload( 56, 56, 128, 1, 3, 3, 1, 1, 1, 1),
_Workload( 56, 56, 128, 1, 3, 3, 1, 1, 2, 2),
_Workload( 28, 28, 256, 1, 3, 3, 1, 1, 1, 1),
_Workload( 28, 28, 256, 1, 3, 3, 1, 1, 2, 2),
_Workload( 14, 14, 512, 1, 3, 3, 1, 1, 1, 1),
_Workload( 14, 14, 512, 1, 3, 3, 1, 1, 2, 2),
_Workload( 14, 14, 1024, 1, 3, 3, 1, 1, 1, 1),
]

_SCHEDULES = [
_Schedule(2, 1, 4, 1, True),
_Schedule(2, 4, 4, 2, True),
_Schedule(2, 1, 4, 2, False),
_Schedule(2, 4, 4, 1, True),
_Schedule(4, 1, 4, 8, True),
_Schedule(1, 1, 4, 2, True),
_Schedule(1, 1, 8, 8, True),
_Schedule(1, 1, 4, 1, False),
_Schedule(2, 1, 4, 16, False),
]

def _get_workload(data, kernel, stride, padding):
_, C, IH, IW = [x.value for x in data.shape]
_, MT, KH, KW = [x.value for x in kernel.shape]
HPAD, WPAD, _, _ = get_pad_tuple(padding, kernel)
if isinstance(stride, (tuple, list)):
HSTR, WSTR = stride
else:
HSTR, WSTR = stride, stride
return _Workload(IH, IW, C, MT, KH, KW, HPAD, WPAD, HSTR, WSTR)


def _schedule(s, data, data_pad, kernel, output, last):
padding = infer_pad(data, data_pad)
if data_pad is None:
stride = infer_stride(data, kernel, output)
else:
stride = infer_stride(data_pad, kernel, output)
wkl = _get_workload(data, kernel, stride, padding)

if wkl not in _WORKLOADS:
return s

# use specified schedule
sch = _SCHEDULES[_WORKLOADS.index(wkl)]

H, W = wkl.height, wkl.width
CN = wkl.channel
MT = wkl.multiplier

HK, WK = wkl.hkernel, wkl.wkernel
HPAD, WPAD = wkl.hpad, wkl.wpad
HSTR, WSTR = wkl.hstride, wkl.wstride

VH, VW = sch.vh, sch.vw
BC = sch.bc
VC = sch.vc

TH = H + 2*HPAD
TW = W + 2*WPAD
OH = (H + 2*HPAD - HK) / HSTR + 1
OW = (W + 2*WPAD - WK) / WSTR + 1


A, B, C = data, kernel, output
A0 = data_pad
C0 = last

A1 = s.cache_read(A0, "global", C)
_, c, h, w = s[A1].op.axis
c, vc = s[A1].split(c, VC)
s[A1].reorder(c, h, w, vc)

A2 = s.cache_write(A1, 'global')
s[A0].compute_inline()
s[A1].compute_inline()

B0 = s.cache_read(B, "global", C)
c, m, h, w = s[B0].op.axis
c, vc = s[B0].split(c, VC)
s[B0].reorder(c, m, h, w, vc)

B1 = s.cache_write(B0, 'global')
s[B0].compute_inline()

_, c, h, w = s[C].op.axis
dh, dw = s[C].op.reduce_axis
c, vc = s[C].split(c, VC)
s[C].reorder(c, h, w, vc)


C0 = s.cache_write(C, 'global')
_, c, h, w, vc = s[C0].op.axis
dh, dw = s[C0].op.reduce_axis
oh, ow, ih, iw = s[C0].tile(h, w, VH, VW)
s[C0].reorder(c, oh, ow, dh, dw, ih, iw, vc)
if sch.unroll:
s[C0].unroll(iw)
s[C0].vectorize(vc)

oh, ow, ih, iw = s[C].tile(h, w, 2, 4)
s[C].reorder(oh, ow, dh, dw, ih, iw)
s[C].unroll(ih)
s[C].vectorize(iw)

# # s[C0].compute_at(s[C0], ow)
launch, c, _, _ = s[C].op.axis
s[C].pragma(launch, "parallel_launch_point")

s[C].parallel(c)
s[C].pragma(c, "parallel_launch_point")
s[C].pragma(c, "parallel_stride_pattern")
s[C].pragma(c, "parallel_barrier_when_finish")


s[C0].compute_at(s[C], launch)
_, c, h, w, vc = s[C0].op.axis
s[C0].parallel(c)
s[C0].pragma(c, "parallel_stride_pattern")
s[C0].pragma(c, "parallel_barrier_when_finish")


s[A2].compute_at(s[C0], oh)
# parallel(s[A2], s[A2].op.axis[1], BC)

# # s[B0].compute_at(s[C0], ow)
s[B1].compute_at(s[C], launch)
c, m, h, w, vc = s[B1].op.axis
s[B1].parallel(c)
s[B1].pragma(c, "parallel_stride_pattern")
s[B1].pragma(c, "parallel_barrier_when_finish")

return s


Expand Down