
Commit c3cac46

masahi authored and tqchen committed
enable rocm target for topi/recipes. add timing util to gemm test. (#554)
1 parent 592a1f6 commit c3cac46

File tree

3 files changed: +34 -23 lines changed


topi/recipe/conv/depthwise_conv2d_test.py

Lines changed: 16 additions & 14 deletions
@@ -69,7 +69,7 @@ def check_device(device):
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
             return
-        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        ctx = tvm.context(device, 0)
         # Build the kernel
         f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
         f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device)
@@ -111,12 +111,13 @@ def check_device(device):
         np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)
         print("success")
 
-    with tvm.build_config(auto_unroll_max_step=32,
-                          auto_unroll_min_depth=0,
-                          unroll_explicit=False,
-                          detect_global_barrier=False,
-                          restricted_func=True):
-        check_device("cuda")
+    for device in ['cuda', 'opencl', 'rocm']:
+        with tvm.build_config(auto_unroll_max_step=32,
+                              auto_unroll_min_depth=0,
+                              unroll_explicit=device == 'rocm',
+                              detect_global_barrier=False,
+                              restricted_func=True):
+            check_device(device)
 
 def test_depthwise_conv2d_nhwc():
     """You may test different settings."""
@@ -159,7 +160,7 @@ def check_device(device):
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
             return
-        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        ctx = tvm.context(device, 0)
         # Build the kernel
         f1 = tvm.build(s1, [Input, Filter, DepthwiseConv2d], device)
         f2 = tvm.build(s2, [Input, Filter, Scale, Shift, ScaleShift], device)
@@ -200,12 +201,13 @@ def check_device(device):
         np.testing.assert_allclose(relu_tvm.asnumpy(), relu_scipy, rtol=1e-5)
         print("success")
 
-    with tvm.build_config(auto_unroll_max_step=32,
-                          auto_unroll_min_depth=0,
-                          unroll_explicit=False,
-                          detect_global_barrier=False,
-                          restricted_func=True):
-        check_device("cuda")
+    for device in ['cuda', 'opencl', 'rocm']:
+        with tvm.build_config(auto_unroll_max_step=32,
+                              auto_unroll_min_depth=0,
+                              unroll_explicit=device == 'rocm',
+                              detect_global_barrier=False,
+                              restricted_func=True):
+            check_device(device)
 
 if __name__ == "__main__":
     test_depthwise_conv2d_nchw()
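
Note: the change repeated across all three recipes swaps the hard-coded context ternary tvm.gpu(0) if device == "cuda" else tvm.cl(0) for tvm.context(device, 0), which dispatches on the device string and therefore covers the new rocm target with no extra branching. A minimal sketch of the resulting per-device pattern, assuming the TVM 0.x API used here (tvm.module.enabled, tvm.context); the run_on name and its elided body are illustrative, not part of the diff:

    import tvm

    def run_on(device):
        # Skip backends this TVM build was not compiled with.
        if not tvm.module.enabled(device):
            print("Skip because %s is not enabled" % device)
            return
        ctx = tvm.context(device, 0)  # string dispatch: cuda, opencl, rocm, ...
        # ... allocate tvm.nd.array buffers on ctx and run the built kernel ...

    for device in ['cuda', 'opencl', 'rocm']:
        run_on(device)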

topi/recipe/conv/test_conv2d_hwcn_map.py

Lines changed: 4 additions & 4 deletions
@@ -5,7 +5,7 @@
 import tvm
 from tvm.contrib import nvcc
 import topi
-from topi.nn.util import get_const_tuple
+from topi.util import get_const_tuple
 
 TASK = "conv2d_hwcn_map"
 USE_MANUAL_CODE = False
@@ -55,22 +55,22 @@ def check_device(device):
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
             return
-        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        ctx = tvm.context(device, 0)
         a = tvm.nd.array(a_np, ctx)
         w = tvm.nd.array(w_np, ctx)
         b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), ctx)
         c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
         with tvm.build_config(auto_unroll_max_step=32,
                               auto_unroll_min_depth=0,
-                              unroll_explicit=False):
+                              unroll_explicit=device == 'rocm'):
             func1 = tvm.build(s1, [A, W, B], device)
             func1(a, w, b)
             np.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-5)
             func2 = tvm.build(s2, [A, W, C], device)
             func2(a, w, c)
             np.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
 
-    for device in ['cuda', 'opencl']:
+    for device in ['cuda', 'opencl', 'rocm']:
         check_device(device)
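
Note: unroll_explicit=device == 'rocm' parses as unroll_explicit=(device == 'rocm'), so explicit unrolling is turned on only for the rocm backend; cuda and opencl keep the previous unroll_explicit=False behavior. An equivalent spelling with the comparison made explicit (same semantics; names taken from the diff):

    unroll = (device == 'rocm')  # True only when building for rocm
    with tvm.build_config(auto_unroll_max_step=32,
                          auto_unroll_min_depth=0,
                          unroll_explicit=unroll):
        ...  # build and run the kernels exactly as in the diff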

topi/recipe/gemm/cuda_gemm_square.py

Lines changed: 14 additions & 5 deletions
@@ -100,11 +100,12 @@ def test_gemm():
     s[BB].double_buffer()
     # correctness
     def check_device(device):
+        print("Device %s" % device)
         if not tvm.module.enabled(device):
             print("Skip because %s is not enabled" % device)
             return
         f = tvm.build(s, [A, B, C], device)
-        ctx = tvm.gpu(0) if device == "cuda" else tvm.cl(0)
+        ctx = tvm.context(device, 0)
         # launch the kernel.
         n, m, l = nn, nn, nn
         a_np = np.random.uniform(size=(n, l)).astype(A.dtype)
@@ -117,10 +118,18 @@ def check_device(device):
         np.testing.assert_allclose(
             c.asnumpy(), np.dot(b_np.T, a_np), rtol=1e-5)
 
-    with tvm.build_config(auto_unroll_max_step=32,
-                          auto_unroll_min_depth=0,
-                          unroll_explicit=False):
-        check_device("cuda")
+        num_flops = 2 * nn * nn * nn
+        num_runs = 10
+        timer_f = f.time_evaluator(f.entry_name, ctx, number=num_runs)
+        t = timer_f(a, b, c).mean
+        GFLOPS = num_flops / (t * 1e3) / 1e6
+        print("average time cost of %d runs = %g ms, %g GFLOPS." % (num_runs, t * 1e3, GFLOPS))
+
+    for device in ['cuda', 'opencl', 'rocm']:
+        with tvm.build_config(auto_unroll_max_step=32,
+                              auto_unroll_min_depth=0,
+                              unroll_explicit=device == 'rocm'):
+            check_device(device)
 
 if __name__ == "__main__":
     test_gemm()
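
Note: this is the timing util mentioned in the commit message. f.time_evaluator(f.entry_name, ctx, number=num_runs) returns a callable that launches the compiled kernel num_runs times on ctx and reports statistics, and .mean is the average wall time in seconds. With num_flops = 2 * nn**3 (one multiply plus one add per inner-product term of a square nn matmul), the expression num_flops / (t * 1e3) / 1e6 simplifies to num_flops / (t * 1e9), i.e. GFLOPS. A standalone sketch of the same measurement, assuming f, ctx, a, b, c, and nn are already defined as in the recipe:

    num_flops = 2 * nn ** 3                  # 2*n^3 flops for an n x n square matmul
    num_runs = 10
    timer_f = f.time_evaluator(f.entry_name, ctx, number=num_runs)
    t = timer_f(a, b, c).mean                # mean runtime in seconds over num_runs launches
    gflops = num_flops / (t * 1e9)           # same value as num_flops / (t * 1e3) / 1e6
    print("average time cost of %d runs = %g ms, %g GFLOPS." % (num_runs, t * 1e3, gflops))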
