[TOPI] Example for convolution in GPU #212

Merged · 13 commits · Jul 19, 2017
update conv impl
icemelon committed Jul 18, 2017
commit 57f13282f5364ce7dcee6fe1f34bd24b1a3382bb
1 change: 1 addition & 0 deletions topi/python/topi/cuda/__init__.py
@@ -2,4 +2,5 @@
"""CUDA specific declaration and schedules."""
from __future__ import absolute_import as _abs

from .conv2d_hwcn_map import schedule_conv2d_hwcn_map
from .depthwise_conv2d_map import schedule_depthwise_conv2d_map
121 changes: 121 additions & 0 deletions topi/python/topi/cuda/conv2d_hwcn_map.py
@@ -0,0 +1,121 @@
"""Schedule for conv2d_hwcn with auto fusion"""
import tvm
from ..nn.util import get_const_tuple


def _schedule_conv2d_hwcn(op, sch):
    assert len(op.input_tensors) == 2
    Apad = op.input_tensors[0]
    W = op.input_tensors[1]
    B = op.output(0)

    sch[Apad].compute_inline()
    AA = sch.cache_read(Apad, "shared", [B])
    WW = sch.cache_read(W, "shared", [B])
    AL = sch.cache_read(AA, "local", [B])
    WL = sch.cache_read(WW, "local", [B])

    if op in sch.outputs:
        Out = op.output(0)
        BL = sch.cache_write(Out, "local")
    else:
        Out = sch.outputs[0].output(0)
        sch[B].set_scope("local")
        BL = B

    tile = 8
    num_thread = 8
    block_factor = tile * num_thread
    step = 8
    vthread = 2

    block_x = tvm.thread_axis("blockIdx.x")
    block_y = tvm.thread_axis("blockIdx.y")
    block_z = tvm.thread_axis("blockIdx.z")
    thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
    thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
    thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx")
    thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy")

    hi, wi, fi, ni = sch[Out].op.axis
    bz = sch[Out].fuse(wi, hi)
    by, fi = sch[Out].split(fi, factor=block_factor)
    bx, ni = sch[Out].split(ni, factor=block_factor)
    tyz, fi = sch[Out].split(fi, nparts=vthread)
    txz, ni = sch[Out].split(ni, nparts=vthread)
    ty, fi = sch[Out].split(fi, nparts=num_thread)
    tx, ni = sch[Out].split(ni, nparts=num_thread)
    sch[Out].reorder(bz, by, bx, tyz, txz, ty, tx, fi, ni)
    sch[Out].bind(bz, block_z)
    sch[Out].bind(by, block_y)
    sch[Out].bind(bx, block_x)
    sch[Out].bind(tyz, thread_yz)
    sch[Out].bind(txz, thread_xz)
    sch[Out].bind(ty, thread_y)
    sch[Out].bind(tx, thread_x)

    # Schedule BL local write
    sch[BL].compute_at(sch[Out], tx)
    yi, xi, fi, ni = sch[BL].op.axis
    ry, rx, rc = sch[BL].op.reduce_axis
    rco, rci = sch[BL].split(rc, factor=step)
    sch[BL].reorder(rco, ry, rx, rci, fi, ni)
    fuse_index = sch[BL].fuse(rx, ry)
    fuse_index = sch[BL].fuse(fuse_index, rco)
    rx = fuse_index

    sch[AA].compute_at(sch[BL], rx)
    sch[WW].compute_at(sch[BL], rx)
    sch[AL].compute_at(sch[BL], rci)
    sch[WL].compute_at(sch[BL], rci)
    # Schedule for A's shared memory load
    yi, xi, ci, ni = sch[AA].op.axis
    ty, ci = sch[AA].split(ci, nparts=num_thread)
    tx, ni = sch[AA].split(ni, nparts=num_thread)
    _, ni = sch[AA].split(ni, factor=4)
    sch[AA].reorder(ty, tx, yi, xi, ci, ni)
    sch[AA].bind(ty, thread_y)
    sch[AA].bind(tx, thread_x)
    sch[AA].vectorize(ni)
    # Schedule for W's shared memory load
    yi, xi, ci, fi = sch[WW].op.axis
    ty, ci = sch[WW].split(ci, nparts=num_thread)
    tx, fi = sch[WW].split(fi, nparts=num_thread)
    _, fi = sch[WW].split(fi, factor=4)
    sch[WW].reorder(ty, tx, yi, xi, ci, fi)
    sch[WW].bind(ty, thread_y)
    sch[WW].bind(tx, thread_x)
    sch[WW].vectorize(fi)

    return sch


def schedule_conv2d_hwcn_map(op):
"""Schedule for conv2d_hwcn map ops.

Parameters
----------
op: tvm.tensor.Operation
The symbolic description of the operation, should be conv2d_hwcn or
conv2d_hwcn followed by a sequence of one-to-one-mapping operators.

Returns
-------
sch: Schedule
The computation schedule for the op.
"""
def traverse(operator):
if operator.tag == 'ewise' or operator.tag == 'scale_shift':
if operator not in sch.outputs:
sch[operator].compute_inline()
for tensor in operator.input_tensors:
if tensor.op.input_tensors:
traverse(tensor.op)
elif operator.tag == 'conv2d_hwcn':
_schedule_conv2d_hwcn(operator, sch)
else:
raise RuntimeError("Unsupported operator: %s" % operator.tag)

sch = tvm.create_schedule(op)
traverse(op)
return sch
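
For reference, a minimal usage sketch of the new schedule entry point (not part of the diff; it assumes conv2d_hwcn is importable as topi.nn.conv2d_hwcn once the conv.py change below lands, and that TVM is built with CUDA enabled):

import tvm
import topi

# Declare a conv2d in HWCN layout and schedule it for CUDA.
A = tvm.placeholder((32, 32, 16, 64), name='A')    # [in_height, in_width, in_channel, batch]
W = tvm.placeholder((3, 3, 16, 128), name='W')     # [filter_height, filter_width, in_channel, num_filter]
B = topi.nn.conv2d_hwcn(A, W, stride=1, padding='same')

s = topi.cuda.schedule_conv2d_hwcn_map(B.op)       # traverse() locates the op by its 'conv2d_hwcn' tag
func = tvm.build(s, [A, W, B], 'cuda')             # compile the scheduled compute to a CUDA kernel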
67 changes: 67 additions & 0 deletions topi/python/topi/nn/conv.py
@@ -5,6 +5,73 @@
import numpy as np
from .util import get_const_tuple


@tvm.tag_scope(tag="conv2d_hwcn")
def conv2d_hwcn(Input, Filter, stride, padding):
"""Depthwise convolution operator.

Parameters
----------
Input : tvm.Tensor
4-D with shape [in_height, in_width, in_channel, batch]

Filter : tvm.Tensor
4-D with shape [filter_height, filter_width, in_channel, num_filter]

stride : int or a list/tuple of two ints
Stride size, or [stride_height, stride_width]

padding : int or str
Padding size or ['valid', 'same']

Returns
-------
Output : tvm.Tensor
4-D with shape [out_height, out_width, out_channel, batch]
"""
    assert isinstance(stride, int) or len(stride) == 2
    assert isinstance(padding, int) or padding in ['valid', 'same']
    in_height, in_width, in_channel, batch = get_const_tuple(Input.shape)
    kernel_h, kernel_w, channel, num_filter = get_const_tuple(Filter.shape)
    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride
    # compute the padding size
    if isinstance(padding, int):
        pad_h = pad_w = padding * 2
    elif padding == 'valid':
        pad_h = 0
        pad_w = 0
    else:  # 'same'
        pad_h = kernel_h - 1
        pad_w = kernel_w - 1
    pad_top = int(np.ceil(float(pad_h) / 2))
    pad_left = int(np.ceil(float(pad_w) / 2))
    # compute the output shape
    out_channel = num_filter
    out_height = (in_height - kernel_h + pad_h) // stride_h + 1
    out_width = (in_width - kernel_w + pad_w) // stride_w + 1
    # compute graph
    PaddedInput = tvm.compute(
        (in_height + pad_h, in_width + pad_w, in_channel, batch),
        lambda yy, xx, cc, nn: tvm.select(
            tvm.all(yy >= pad_top, yy - pad_top < in_height,
                    xx >= pad_left, xx - pad_left < in_width),
            Input[yy - pad_top, xx - pad_left, cc, nn], tvm.const(0.)),
        name='PaddedInput')
    rc = tvm.reduce_axis((0, in_channel), name='rc')
    ry = tvm.reduce_axis((0, kernel_h), name='ry')
    rx = tvm.reduce_axis((0, kernel_w), name='rx')
    Output = tvm.compute(
        (out_height, out_width, out_channel, batch),
        lambda yy, xx, ff, nn: tvm.sum(
            PaddedInput[yy * stride_h + ry, xx * stride_w + rx, rc, nn] * Filter[ry, rx, rc, ff],
            axis=[ry, rx, rc]),
        name='Conv2dOutput')
    return Output
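
As a quick sanity check on the shape arithmetic above (illustrative only, not part of the diff; assumes the module is importable as topi.nn.conv): with 'same' padding, pad_h = kernel_h - 1, so stride-1 convolution preserves the spatial extent.

import tvm
from topi.nn.conv import conv2d_hwcn   # module path added in this PR

Input = tvm.placeholder((32, 32, 16, 4), name='Input')    # [in_height, in_width, in_channel, batch]
Filter = tvm.placeholder((3, 3, 16, 32), name='Filter')   # [kernel_h, kernel_w, in_channel, num_filter]
Output = conv2d_hwcn(Input, Filter, stride=1, padding='same')
# out_height = (32 - 3 + 2) // 1 + 1 = 32, out_width = 32,
# so Output has shape (32, 32, 32, 4) = [out_height, out_width, num_filter, batch]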


@tvm.tag_scope(tag="depthwise_conv2d")
def depthwise_conv2d(Input, Filter, Stride, padding):
"""Depthwise convolution operator.