
Commit 4bb9b54

Author: Siyuan Feng (committed)
Message: extend into BatchMatMul
Parent: d6795f1


tests/python/unittest/test_schedule_tensor_core.py

Lines changed: 24 additions & 21 deletions
@@ -98,20 +98,21 @@ def intrin_func(ins, outs):
     return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})


-def test_tensor_core_gemm():
-    n = 4096
+def test_tensor_core_batch_matmal():
+    batch_size = 20
+    n = 2048
     m, l = n, n
     assert (n % 16 == 0)
     assert (m % 16 == 0)
     assert (l % 16 == 0)
     nn, mm, ll = n // 16, m // 16, l // 16
-    A = tvm.placeholder((nn, ll, 16, 16), name='A', dtype='float16')
-    B = tvm.placeholder((ll, mm, 16, 16), name='B', dtype='float16')
+    A = tvm.placeholder((batch_size, nn, ll, 16, 16), name='A', dtype='float16')
+    B = tvm.placeholder((batch_size, ll, mm, 16, 16), name='B', dtype='float16')
     k1 = tvm.reduce_axis((0, ll), name='k1')
     k2 = tvm.reduce_axis((0, 16), name='k2')
-    C = tvm.compute((nn, mm, 16, 16),
-                    lambda i, j, ii, jj:
-                    tvm.sum(A[i, k1, ii, k2].astype('float') * B[k1, j, k2, jj].astype('float'), axis=[k1, k2]),
+    C = tvm.compute((batch_size, nn, mm, 16, 16),
+                    lambda b, i, j, ii, jj:
+                    tvm.sum(A[b, i, k1, ii, k2].astype('float') * B[b, k1, j, k2, jj].astype('float'), axis=[k1, k2]),
                     name='Fragment_C')
     s = tvm.create_schedule(C.op)

@@ -125,6 +126,7 @@ def test_tensor_core_gemm():

     block_x = tvm.thread_axis('blockIdx.x')
     block_y = tvm.thread_axis('blockIdx.y')
+    block_z = tvm.thread_axis('blockIdx.z')
     thread_x = tvm.thread_axis('threadIdx.x')
     thread_y = tvm.thread_axis('threadIdx.y')
     thread_z = tvm.thread_axis('threadIdx.z')
@@ -135,19 +137,20 @@ def test_tensor_core_gemm():
     BF = s.cache_read(BS, 'wmma.matrix_b', [C])
     CF = s.cache_write(C, 'wmma.accumulator')

-    i, j, kernel_i, kernel_j = s[C].op.axis
+    b, i, j, kernel_i, kernel_j = s[C].op.axis
     i, ii = s[C].split(i, factor=warp_row_tiles)
     block_i, i = s[C].split(i, factor=block_row_warps)
     j, jj = s[C].split(j, factor=warp_col_tiles)
     block_j, j = s[C].split(j, factor=block_col_warps)
     s[C].reorder(block_i, block_j, i, j, ii, jj, kernel_i, kernel_j)
+    s[C].bind(b, block_z)
     s[C].bind(block_i, block_x)
     s[C].bind(block_j, block_y)
     s[C].bind(i, thread_y)
     s[C].bind(j, thread_z)

     s[CF].compute_at(s[C], j)
-    warp_i, warp_j, _i, _j = s[CF].op.axis
+    b, warp_i, warp_j, _i, _j = s[CF].op.axis
     k, _k = CF.op.reduce_axis
     ko, ki = s[CF].split(k, factor=chunk)
     s[CF].reorder(ko, ki, warp_i, warp_j, _i, _j, _k)
@@ -156,7 +159,7 @@ def test_tensor_core_gemm():
     s[BF].compute_at(s[CF], ki)

     s[AS].compute_at(s[CF], ko)
-    xo, yo, xi, yi = AS.op.axis
+    b, xo, yo, xi, yi = AS.op.axis
     tx, xo = s[AS].split(xo, nparts=block_row_warps)
     ty, yo = s[AS].split(yo, nparts=block_col_warps)
     t = s[AS].fuse(xi, yi)
@@ -167,7 +170,7 @@ def test_tensor_core_gemm():
     s[AS].vectorize(ti)

     s[BS].compute_at(s[CF], ko)
-    xo, yo, xi, yi = BS.op.axis
+    b, xo, yo, xi, yi = BS.op.axis
     tx, xo = s[BS].split(xo, nparts=block_row_warps)
     ty, yo = s[BS].split(yo, nparts=block_col_warps)
     t = s[BS].fuse(xi, yi)
@@ -184,23 +187,23 @@ def test_tensor_core_gemm():
     func = tvm.build(s, [A, B, C], 'cuda')

     ctx = tvm.gpu(0)
-    a_np = np.random.uniform(size=(nn, nn, 16, 16)).astype(A.dtype)
-    b_np = np.random.uniform(size=(nn, nn, 16, 16)).astype(B.dtype)
+    a_np = np.random.uniform(size=(batch_size, nn, nn, 16, 16)).astype(A.dtype)
+    b_np = np.random.uniform(size=(batch_size, nn, nn, 16, 16)).astype(B.dtype)
     a = tvm.nd.array(a_np, ctx)
     b = tvm.nd.array(b_np, ctx)
-    c = tvm.nd.array(np.zeros((nn, nn, 16, 16), dtype=C.dtype), ctx)
+    c = tvm.nd.array(np.zeros((batch_size, nn, nn, 16, 16), dtype=C.dtype), ctx)
     evaluator = func.time_evaluator(func.entry_name, ctx, number=3)
     print('gemm with tensor core: %f ms' % (evaluator(a, b, c).mean * 1e3))

     if VERIFY:
         func(a, b, c)
-        a_np = a_np.transpose(0, 2, 1, 3).reshape(n, n)
-        b_np = b_np.transpose(0, 2, 1, 3).reshape(n, n)
-        c_np = c.asnumpy().transpose(0, 2, 1, 3).reshape(n, n)
-        np.testing.assert_allclose(c_np, np.dot(a_np.astype(C.dtype), b_np.astype(C.dtype)), rtol=1e-4, atol=1e-4)
+        a_np = a_np.transpose((0, 1, 3, 2, 4)).reshape(batch_size, n, n)
+        b_np = b_np.transpose((0, 1, 3, 2, 4)).reshape(batch_size, n, n)
+        c_np = c.asnumpy().transpose((0, 1, 3, 2, 4)).reshape(batch_size, n, n)
+        np.testing.assert_allclose(c_np, np.matmul(a_np.astype(C.dtype), b_np.astype(C.dtype)), rtol=1e-4, atol=1e-4)


-def test_tensor_core_conv():
+def test_tensor_core_batch_conv():
     # The sizes of inputs and filters
     batch_size = 256
     height = 14
@@ -364,5 +367,5 @@ def test_tensor_core_conv():
     if not nvcc.have_tensorcore(ctx.compute_version):
         print("skip because gpu does not support tensor core")
     else:
-        test_tensor_core_gemm()
-        test_tensor_core_conv()
+        test_tensor_core_batch_matmal()
+        test_tensor_core_batch_conv()
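For reference, the new verification step flattens the (batch_size, nn, nn, 16, 16) fragment layout back into ordinary (batch_size, n, n) matrices before comparing against np.matmul. A minimal NumPy sketch of that round-trip, using toy sizes (batch_size=2, n=32, chosen here only so it runs quickly on CPU) in place of the test's batch_size=20, n=2048:

import numpy as np

# Toy sizes standing in for the test's batch_size = 20, n = 2048 (assumed small here).
batch_size, n = 2, 32
nn = n // 16

# Random inputs in the (batch, nn, nn, 16, 16) fragment layout used by the test.
a_np = np.random.uniform(size=(batch_size, nn, nn, 16, 16)).astype('float16')
b_np = np.random.uniform(size=(batch_size, nn, nn, 16, 16)).astype('float16')

# transpose((0, 1, 3, 2, 4)) interleaves the tile index with the intra-tile index,
# so the following reshape recovers plain (batch, n, n) matrices.
a_full = a_np.transpose((0, 1, 3, 2, 4)).reshape(batch_size, n, n)
b_full = b_np.transpose((0, 1, 3, 2, 4)).reshape(batch_size, n, n)

# Batched reference result in float32, the same comparison target the test builds
# for np.testing.assert_allclose.
c_ref = np.matmul(a_full.astype('float32'), b_full.astype('float32'))
print(c_ref.shape)  # (2, 32, 32)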
