[TOPI] Example for convolution in GPU (apache#212)
* [TOPI] Example for convolution
* update conv ex
* fix submodule HalideIR
* update conv impl
* python3
* minor fix
* fix pylint error
* Add test code
* x
* fix
* fix
* move python helper function into topi.testing
* fix pylint
Showing 8 changed files with 401 additions and 0 deletions.
@@ -9,3 +9,4 @@
from .math import *
from . import nn
from . import cuda
from . import testing
@@ -0,0 +1,121 @@
# pylint: disable=invalid-name
"""Schedule for conv2d_hwcn with auto fusion"""
import tvm


def _schedule_conv2d_hwcn(op, sch):
    assert len(op.input_tensors) == 2
    Apad = op.input_tensors[0]
    W = op.input_tensors[1]
    B = op.output(0)

    sch[Apad].compute_inline()
    AA = sch.cache_read(Apad, "shared", [B])
    WW = sch.cache_read(W, "shared", [B])
    AL = sch.cache_read(AA, "local", [B])
    WL = sch.cache_read(WW, "local", [B])

    if op in sch.outputs:
        Out = op.output(0)
        BL = sch.cache_write(Out, "local")
    else:
        Out = sch.outputs[0].output(0)
        sch[B].set_scope("local")
        BL = B

    tile = 8
    num_thread = 8
    block_factor = tile * num_thread
    step = 8
    vthread = 2

    block_x = tvm.thread_axis("blockIdx.x")
    block_y = tvm.thread_axis("blockIdx.y")
    block_z = tvm.thread_axis("blockIdx.z")
    thread_x = tvm.thread_axis((0, num_thread), "threadIdx.x")
    thread_y = tvm.thread_axis((0, num_thread), "threadIdx.y")
    thread_xz = tvm.thread_axis((0, vthread), "vthread", name="vx")
    thread_yz = tvm.thread_axis((0, vthread), "vthread", name="vy")

    hi, wi, fi, ni = sch[Out].op.axis
    bz = sch[Out].fuse(wi, hi)
    by, fi = sch[Out].split(fi, factor=block_factor)
    bx, ni = sch[Out].split(ni, factor=block_factor)
    tyz, fi = sch[Out].split(fi, nparts=vthread)
    txz, ni = sch[Out].split(ni, nparts=vthread)
    ty, fi = sch[Out].split(fi, nparts=num_thread)
    tx, ni = sch[Out].split(ni, nparts=num_thread)
    sch[Out].reorder(bz, by, bx, tyz, txz, ty, tx, fi, ni)
    sch[Out].bind(bz, block_z)
    sch[Out].bind(by, block_y)
    sch[Out].bind(bx, block_x)
    sch[Out].bind(tyz, thread_yz)
    sch[Out].bind(txz, thread_xz)
    sch[Out].bind(ty, thread_y)
    sch[Out].bind(tx, thread_x)

    # Schedule BL local write
    sch[BL].compute_at(sch[Out], tx)
    yi, xi, fi, ni = sch[BL].op.axis
    ry, rx, rc = sch[BL].op.reduce_axis
    rco, rci = sch[BL].split(rc, factor=step)
    sch[BL].reorder(rco, ry, rx, rci, fi, ni)
    fuse_index = sch[BL].fuse(rx, ry)
    fuse_index = sch[BL].fuse(fuse_index, rco)
    rx = fuse_index

    sch[AA].compute_at(sch[BL], rx)
    sch[WW].compute_at(sch[BL], rx)
    sch[AL].compute_at(sch[BL], rci)
    sch[WL].compute_at(sch[BL], rci)
    # Schedule for A's shared memory load
    yi, xi, ci, ni = sch[AA].op.axis
    ty, ci = sch[AA].split(ci, nparts=num_thread)
    tx, ni = sch[AA].split(ni, nparts=num_thread)
    _, ni = sch[AA].split(ni, factor=4)
    sch[AA].reorder(ty, tx, yi, xi, ci, ni)
    sch[AA].bind(ty, thread_y)
    sch[AA].bind(tx, thread_x)
    sch[AA].vectorize(ni)
    # Schedule for W's shared memory load
    yi, xi, ci, fi = sch[WW].op.axis
    ty, ci = sch[WW].split(ci, nparts=num_thread)
    tx, fi = sch[WW].split(fi, nparts=num_thread)
    _, fi = sch[WW].split(fi, factor=4)
    sch[WW].reorder(ty, tx, yi, xi, ci, fi)
    sch[WW].bind(ty, thread_y)
    sch[WW].bind(tx, thread_x)
    sch[WW].vectorize(fi)

    return sch


def schedule_conv2d_hwcn_map(op):
    """Schedule for conv2d_hwcn map ops.
    Parameters
    ----------
    op: tvm.tensor.Operation
        The symbolic description of the operation, should be conv2d_hwcn or
        conv2d_hwcn followed by a sequence of one-to-one-mapping operators.
    Returns
    -------
    sch: Schedule
        The computation schedule for the op.
    """
    def traverse(operator):
        if operator.tag == 'ewise' or operator.tag == 'scale_shift':
            if operator not in sch.outputs:
                sch[operator].compute_inline()
            for tensor in operator.input_tensors:
                if tensor.op.input_tensors:
                    traverse(tensor.op)
        elif operator.tag == 'conv2d_hwcn':
            _schedule_conv2d_hwcn(operator, sch)
        else:
            raise RuntimeError("Unsupported operator: %s" % operator.tag)

    sch = tvm.create_schedule(op)
    traverse(op)
    return sch
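Note (not part of the commit): the tiling constants at the top of _schedule_conv2d_hwcn are easiest to read as plain arithmetic. A minimal sketch of what the fi/ni splits imply, assuming only the factors defined above:

# Illustrative arithmetic only: how one block-sized slice of the fi (or ni)
# axis is carved up by the splits in _schedule_conv2d_hwcn.
tile, num_thread, vthread = 8, 8, 2
block_factor = tile * num_thread           # 64 outputs per thread block along fi/ni
after_vthread = block_factor // vthread    # 32 left after the nparts=vthread split
per_thread = after_vthread // num_thread   # 4 outputs computed serially per thread
assert vthread * num_thread * per_thread == block_factor
# So each thread accumulates a 4 x 4 (fi x ni) register tile per (vx, vy)
# virtual-thread iteration, fed from the AA/WW shared-memory stages.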
@@ -0,0 +1,7 @@
"""TOPI Testing Util functions.
Used to verify the correctness of operators in TOPI.
"""
from __future__ import absolute_import as _abs

from .conv2d_hwcn_python import conv2d_hwcn_python
@@ -0,0 +1,67 @@
# pylint: disable=invalid-name, line-too-long, unused-variable
"""Convolution in python"""
import numpy as np
import scipy.signal


def conv2d_hwcn_python(a_np, w_np, stride, padding):
    """Convolution operator in HWCN layout.
    Parameters
    ----------
    a_np : numpy.ndarray
        4-D with shape [in_height, in_width, in_channel, batch]
    w_np : numpy.ndarray
        4-D with shape [filter_height, filter_width, in_channel, num_filter]
    stride : int or a list/tuple of two ints
        Stride size, or [stride_height, stride_width]
    padding : int or str
        Padding size, or ['VALID', 'SAME']
    Returns
    -------
    b_np : np.ndarray
        4-D with shape [out_height, out_width, out_channel, batch]
    """
    in_height, in_width, in_channel, batch = a_np.shape
    kernel_h, kernel_w, _, num_filter = w_np.shape
    if isinstance(stride, int):
        stride_h = stride_w = stride
    else:
        stride_h, stride_w = stride
    if isinstance(padding, int):
        pad_h = pad_w = padding * 2
    elif padding == 'VALID':
        pad_h = 0
        pad_w = 0
    else:  # 'SAME'
        pad_h = kernel_h - 1
        pad_w = kernel_w - 1
    pad_top = int(np.ceil(float(pad_h) / 2))
    pad_bottom = pad_h - pad_top
    pad_left = int(np.ceil(float(pad_w) / 2))
    pad_right = pad_w - pad_left
    # compute the output shape
    out_channel = num_filter
    out_height = (in_height - kernel_h + pad_h) // stride_h + 1
    out_width = (in_width - kernel_w + pad_w) // stride_w + 1
    # change the layout from HWCN to NCHW
    at = a_np.transpose((3, 2, 0, 1))
    wt = w_np.transpose((3, 2, 0, 1))
    bt = np.zeros((batch, out_channel, out_height, out_width))
    # computation
    for n in range(batch):
        for f in range(out_channel):
            for c in range(in_channel):
                if pad_h > 0:
                    apad = np.zeros((in_height + pad_h, in_width + pad_w))
                    apad[pad_top:-pad_bottom, pad_left:-pad_right] = at[n, c]
                else:
                    apad = at[n, c]
                out = scipy.signal.convolve2d(
                    apad, np.rot90(np.rot90(wt[f, c])), mode='valid')
                bt[n, f] += out[::stride_h, ::stride_w]
    return bt.transpose((2, 3, 1, 0))
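A quick usage sketch for the reference kernel above (not part of the commit; the shapes simply mirror the GPU test later in this diff):

import numpy as np

a_np = np.random.uniform(size=(16, 16, 128, 64)).astype('float32')   # HWCN input
w_np = np.random.uniform(size=(3, 3, 128, 128)).astype('float32')    # HWIO filter
b_np = conv2d_hwcn_python(a_np, w_np, stride=2, padding='SAME')
# 'SAME' with a 3x3 kernel pads by kernel - 1 = 2, so each spatial dim becomes
# (16 - 3 + 2) // 2 + 1 = 8.
print(b_np.shape)   # (8, 8, 128, 64): [out_height, out_width, out_channel, batch]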
@@ -0,0 +1,78 @@
"""Example code to do convolution."""
import os
import numpy as np
import scipy.signal
import tvm
from tvm.contrib import nvcc
import topi
from topi.nn.util import get_const_tuple

TASK = "conv2d_hwcn_map"
USE_MANUAL_CODE = False

@tvm.register_func
def tvm_callback_cuda_compile(code):
    ptx = nvcc.compile_cuda(code, target="ptx", options=["-arch=sm_52"])
    return ptx

def write_code(code, fname):
    with open(fname, "w") as f:
        f.write(code)

@tvm.register_func
def tvm_callback_cuda_postproc(code):
    if not os.path.exists("perf"):
        os.mkdir("perf")
    write_code(code, "perf/%s_generated.cu" % TASK)
    if USE_MANUAL_CODE:
        code = open("perf/%s_manual.cu" % TASK).read()
    return code


def test_conv2d_hwcn_map():
    batch = 64
    in_channel = 128
    in_height = 16
    in_width = 16
    num_filter = 128
    kernel = 3
    stride = 2
    padding = 'SAME'

    A = tvm.placeholder((in_height, in_width, in_channel, batch), name='A')
    W = tvm.placeholder((kernel, kernel, in_channel, num_filter), name='W')
    B = topi.nn.conv2d_hwcn(A, W, stride, padding)
    C = topi.nn.relu(B)
    s1 = topi.cuda.schedule_conv2d_hwcn_map(B.op)
    s2 = topi.cuda.schedule_conv2d_hwcn_map(C.op)

    a_np = np.random.uniform(size=get_const_tuple(A.shape)).astype(A.dtype)
    w_np = np.random.uniform(size=get_const_tuple(W.shape)).astype(W.dtype)
    b_np = topi.testing.conv2d_hwcn_python(a_np, w_np, stride, padding)
    c_np = np.maximum(b_np, 0)

    def check_device(device):
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
        a = tvm.nd.array(a_np, ctx)
        w = tvm.nd.array(w_np, ctx)
        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
        with tvm.build_config(auto_unroll_max_step=32,
                              auto_unroll_min_depth=0,
                              unroll_explicit=False):
            func1 = tvm.build(s1, [A, W, B], device)
            func1(a, w, b)
            np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
            func2 = tvm.build(s2, [A, W, C], device)
            func2(a, w, c)
            np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)

    for device in ['cuda', 'opencl']:
        check_device(device)


if __name__ == "__main__":
    test_conv2d_hwcn_map()
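A usage note on the compile hooks above (a workflow assumption, not something stated in the commit): after a CUDA run, tvm_callback_cuda_postproc has already dumped the generated kernel, which makes it easy to inspect or hand-tune:

# Sketch: peek at the kernel source written by the postproc callback.
with open("perf/conv2d_hwcn_map_generated.cu") as f:
    print(f.read()[:400])
# Saving an edited copy as perf/conv2d_hwcn_map_manual.cu and setting
# USE_MANUAL_CODE = True makes the callback substitute it on the next build.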