
Commit 3db8d64

Author: Siyuan Feng
Commit message: add TensorCore tutorial
1 parent 0c6b378 commit 3db8d64

File tree

2 files changed: +354 -6 lines changed

tests/python/unittest/test_schedule_tensor_core.py

Lines changed: 6 additions & 6 deletions
@@ -19,7 +19,7 @@
 from topi.testing import conv2d_nhwc_python
 from tvm.contrib import nvcc
 
-VERIFY = False
+VERIFY = True
 
 
 def intrin_wmma_load_matrix(scope):
@@ -99,8 +99,8 @@ def intrin_func(ins, outs):
 
 
 def test_tensor_core_batch_matmal():
-    batch_size = 20
-    n = 2048
+    batch_size = 4
+    n = 512
     m, l = n, n
     assert (n % 16 == 0)
     assert (m % 16 == 0)
@@ -205,11 +205,11 @@ def test_tensor_core_batch_matmal():
 
 def test_tensor_core_batch_conv():
     # The sizes of inputs and filters
-    batch_size = 256
+    batch_size = 32
     height = 14
     width = 14
-    in_channels = 256
-    out_channels = 512
+    in_channels = 32
+    out_channels = 64
     kernel_h = 3
     kernel_w = 3
     pad_h = 1
Lines changed: 348 additions & 0 deletions

@@ -0,0 +1,348 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
"""
.. _opt-conv-tensorcore:

How to optimize convolution using TensorCores
=============================================
**Author**: `Siyuan Feng <https://github.com/Hzfengsy>`_

In this tutorial, we will demonstrate how to write a high-performance convolution
schedule using TensorCores in TVM. In this example, we assume the input to the
convolution has a large batch. We strongly recommend covering the :ref:`opt-conv-gpu` tutorial first.

"""

################################################################
# TensorCore Introduction
# -----------------------
# Each Tensor Core provides a 4x4x4 matrix processing array that performs
# :code:`D = A * B + C`, where A, B, C and D are 4x4 matrices as the figure shows.
# The matrix multiplication inputs A and B are FP16 matrices, while the accumulation
# matrices C and D may be FP16 or FP32 matrices.
#
# However, CUDA programmers can only use the warp-level primitive
# :code:`wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag)` to perform
# a 16x16x16 half-precision matrix multiplication on Tensor Cores. Before invoking
# the matrix multiplication, programmers must explicitly load data from memory into
# registers with the primitive :code:`wmma::load_matrix_sync`. The NVCC compiler
# translates that primitive into multiple memory load instructions. At run time,
# every thread loads 16 elements from matrix A and 16 elements from B.
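#
# As a mental model only, here is a numpy sketch (our illustration, not part of
# the WMMA API) of the arithmetic a single 16x16x16 :code:`mma_sync` call
# performs: FP16 operands multiplied, with the product accumulated in FP32.

import numpy as np

a_frag = np.random.uniform(size=(16, 16)).astype('float16')   # wmma::matrix_a
b_frag = np.random.uniform(size=(16, 16)).astype('float16')   # wmma::matrix_b
acc_frag = np.zeros((16, 16), dtype='float32')                # wmma::accumulator
# D = A * B + C, with the product accumulated in FP32
acc_frag += a_frag.astype('float32') @ b_frag.astype('float32')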

################################################################
# Preparation and Algorithm
# --------------------------
# We use a fixed size for the input tensors, with 256 channels and 14 x 14 dimensions.
# The batch size is 256. The convolution filters contain 512 filters of size 3 x 3.
# We use stride size 1 and padding size 1 for the convolution. In this example, we use
# the NHWCnc memory layout. The following code defines the convolution algorithm in TVM.

import tvm
import numpy as np
from tvm.contrib import nvcc

# The sizes of inputs and filters
batch_size = 256
height = 14
width = 14
in_channels = 256
out_channels = 512
kernel_h = 3
kernel_w = 3
pad_h = 1
pad_w = 1
stride_h = 1
stride_w = 1

# TensorCore shape
block_size = 16

assert (batch_size % block_size == 0)
assert (in_channels % block_size == 0)
assert (out_channels % block_size == 0)

# Input feature map: (N, H, W, IC, n, ic)
data_shape = (batch_size // block_size,
              height,
              width,
              in_channels // block_size,
              block_size,
              block_size)
# Kernel: (H, W, IC, OC, ic, oc)
kernel_shape = (kernel_h,
                kernel_w,
                in_channels // block_size,
                out_channels // block_size,
                block_size,
                block_size)
# Output feature map: (N, H, W, OC, n, oc)
output_shape = (batch_size // block_size,
                height,
                width,
                out_channels // block_size,
                block_size,
                block_size)
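
# For intuition about the packed NHWCnc layout: the logical NHWC element
# (n, h, w, c) of the input lives at the packed index computed below. The
# helper is purely illustrative (our addition; the schedule does not use it).
def nhwc_to_packed(n, h, w, c, block=block_size):
    """Map a logical NHWC index to its packed NHWCnc index (illustrative)."""
    return (n // block, h, w, c // block, n % block, c % block)

assert nhwc_to_packed(17, 3, 5, 40) == (1, 3, 5, 2, 1, 8)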

# Reduction axes
kh = tvm.reduce_axis((0, kernel_h), name='kh')
kw = tvm.reduce_axis((0, kernel_w), name='kw')
ic = tvm.reduce_axis((0, in_channels // block_size), name='ic')
ii = tvm.reduce_axis((0, block_size), name='ii')

# Algorithm
A = tvm.placeholder(data_shape, name='A', dtype="float16")
W = tvm.placeholder(kernel_shape, name='W', dtype="float16")
Apad = tvm.compute(
    (batch_size // block_size, height + 2 * pad_h, width + 2 * pad_w, in_channels // block_size, block_size,
     block_size),
    lambda n, h, w, i, nn, ii: tvm.if_then_else(
        tvm.all(h >= pad_h, h - pad_h < height,
                w >= pad_w, w - pad_w < width),
        A[n, h - pad_h, w - pad_w, i, nn, ii], tvm.const(0., "float16")),
    name='Apad')
Conv = tvm.compute(output_shape,
                   lambda n, h, w, o, nn, oo: tvm.sum(
                       Apad[n, h * stride_h + kh, w * stride_w + kw, ic, nn, ii].astype("float32") *
                       W[kh, kw, ic, o, ii, oo].astype("float32"),
                       axis=[ic, kh, kw, ii]),
                   name="Conv")

s = tvm.create_schedule(Conv.op)
s[Apad].compute_inline()
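# Inlining Apad folds the padding computation into Conv's body, so no separate
# padded intermediate buffer is ever materialized.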

###############################################################################
# Memory Scope
# ------------
#
# In a traditional GPU schedule, we have the global, shared and local memory scopes.
# To support TensorCores, we add three more special memory scopes: :code:`wmma.matrix_a`,
# :code:`wmma.matrix_b` and :code:`wmma.accumulator`. On hardware, all fragment scopes
# are stored at the on-chip register level, the same place as local memory.

# Designate the memory hierarchy
AS = s.cache_read(Apad, 'shared', [Conv])
WS = s.cache_read(W, 'shared', [Conv])
AF = s.cache_read(AS, 'wmma.matrix_a', [Conv])
WF = s.cache_read(WS, 'wmma.matrix_b', [Conv])
ConvF = s.cache_write(Conv, 'wmma.accumulator')
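
# The resulting data flow through the memory hierarchy, for reference:
#   A (global) -> AS (shared) -> AF (wmma.matrix_a) --+
#                                                     +-> ConvF (wmma.accumulator) -> Conv (global)
#   W (global) -> WS (shared) -> WF (wmma.matrix_b) --+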

###############################################################################
# Define Tensor Intrinsic
# -----------------------
# In fact, TensorCore is a special hardware operation. So, we can just use tensorize
# to replace a unit of computation with the TensorCore instruction. The first thing we
# need to do is to define the tensor intrinsics.
#
# There are four basic operations in TensorCore: :code:`fill_fragment`, :code:`load_matrix`,
# :code:`mma_sync` and :code:`store_matrix`. Since :code:`fill_fragment` and :code:`mma_sync`
# are both used in matrix multiplication, we can just write the following three intrinsics.

def intrin_wmma_load_matrix(scope):
    n = 16
    A = tvm.placeholder((n, n), name='A', dtype='float16')
    BA = tvm.decl_buffer(A.shape, A.dtype, scope='shared', data_alignment=32, offset_factor=256)
    C = tvm.compute((n, n), lambda i, j: A[i, j], name='C')
    BC = tvm.decl_buffer(C.shape, C.dtype, scope=scope, data_alignment=32, offset_factor=256)

    def intrin_func(ins, outs):
        ib = tvm.ir_builder.create()

        BA = ins[0]
        BC = outs[0]
        ib.emit(tvm.call_intrin('handle', 'tvm_load_matrix_sync',
                                BC.data, n, n, n, BC.elem_offset // 256,
                                BA.access_ptr('r'), n, 'row_major'))
        return ib.get()

    return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})
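
# Note: the copy pattern C[i, j] = A[i, j] declared above is what tensorize
# will match later; the scope argument decides whether the destination fragment
# is declared as wmma.matrix_a or wmma.matrix_b.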


def intrin_wmma_gemm():
    n = 16
    A = tvm.placeholder((n, n), name='A', dtype='float16')
    B = tvm.placeholder((n, n), name='B', dtype='float16')
    k = tvm.reduce_axis((0, n), name="k")
    C = tvm.compute((n, n),
                    lambda ii, jj:
                    tvm.sum(A[ii, k].astype('float') * B[k, jj].astype('float'), axis=k),
                    name='C')
    BA = tvm.decl_buffer(A.shape, A.dtype, name='BA', scope='wmma.matrix_a', data_alignment=32, offset_factor=256)
    BB = tvm.decl_buffer(B.shape, B.dtype, name='BB', scope='wmma.matrix_b', data_alignment=32, offset_factor=256)
    BC = tvm.decl_buffer(C.shape, C.dtype, name='BC', scope='wmma.accumulator', data_alignment=32, offset_factor=256)

    def intrin_func(ins, outs):
        BA, BB = ins
        BC, = outs

        def init():
            ib = tvm.ir_builder.create()
            ib.emit(tvm.call_intrin('handle', 'tvm_fill_fragment', BC.data, n, n, n, BC.elem_offset // 256, 0.0))
            return ib.get()

        def update():
            ib = tvm.ir_builder.create()
            ib.emit(tvm.call_intrin('handle', 'tvm_mma_sync',
                                    BC.data, BC.elem_offset // 256,
                                    BA.data, BA.elem_offset // 256,
                                    BB.data, BB.elem_offset // 256,
                                    BC.data, BC.elem_offset // 256))
            return ib.get()

        return update(), init(), update()

    return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC})


def intrin_wmma_store_matrix():
    n = 16
    A = tvm.placeholder((n, n), name='A', dtype='float32')
    BA = tvm.decl_buffer(A.shape, A.dtype, scope='wmma.accumulator', data_alignment=32, offset_factor=256)
    C = tvm.compute((n, n), lambda i, j: A[i, j], name='C')
    BC = tvm.decl_buffer(C.shape, C.dtype, scope='global', data_alignment=32, offset_factor=256)

    def intrin_func(ins, outs):
        ib = tvm.ir_builder.create()
        BA = ins[0]
        BC = outs[0]
        ib.emit(tvm.call_intrin('handle', 'tvm_store_matrix_sync',
                                BA.data, n, n, n, BA.elem_offset // 256,
                                BC.access_ptr('w'), n, 'row_major'))
        return ib.get()

    return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})
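
# Each intrinsic above pairs a small compute declaration (what one 16x16 tile
# computes) with an intrin_func that emits the corresponding hardware
# instruction. In intrin_wmma_gemm, the returned (update, init, update) triple
# tells TVM how to compute a tile, how to initialize the accumulator, and how
# to update it across reduction iterations.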

###############################################################################
# Scheduling the Computation
# --------------------------
# To use TensorCores in TVM, we must schedule the computation into a specific
# structure that matches the tensor intrinsic. As in traditional GPU programs, we
# can also use shared memory to boost the speed. If you have any questions about
# blocking and shared memory, please refer to :ref:`opt-conv-gpu`.
#
# In this example, each block contains 2x4 warps, and each warp calls 4x2 TensorCore
# instructions. Thus, the output shape of each warp is 64x32 and each block outputs
# a 128x128 tile. Due to the limited shared memory space, we only load 2 blocks
# (2x128x128 tiles) at one time.
#
# .. note::
#
#   *Warp-level Operation*
#
#   Note that all TensorCore instructions are warp-level instructions, which means all 32 threads
#   in a warp should execute the instruction simultaneously. Making threadIdx.x extent=32 is one of the
#   easiest ways to achieve this. We can then bind threadIdx.x to any loop except those that contain
#   TensorCore intrinsics directly or indirectly. Also note that this is not the only solution.
#   The only thing we must do is to make sure all threads in a warp can call TensorCore at the same time.
#


# Define tiling sizes
block_row_warps = 2
block_col_warps = 4
warp_row_tiles = 4
warp_col_tiles = 2
warp_size = 32
chunk = 2
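
# Sanity-check the tile arithmetic stated above (each wmma tile is 16x16):
assert warp_row_tiles * 16 == 64 and warp_col_tiles * 16 == 32   # per-warp output
assert block_row_warps * warp_row_tiles * 16 == 128              # block output rows
assert block_col_warps * warp_col_tiles * 16 == 128              # block output cols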

block_x = tvm.thread_axis('blockIdx.x')
block_y = tvm.thread_axis('blockIdx.y')
block_z = tvm.thread_axis('blockIdx.z')
thread_x = tvm.thread_axis('threadIdx.x')
thread_y = tvm.thread_axis('threadIdx.y')
thread_z = tvm.thread_axis('threadIdx.z')

nc, hc, wc, oc, nnc, ooc = Conv.op.axis
block_k = s[Conv].fuse(hc, wc)
s[Conv].bind(block_k, block_z)
nc, nci = s[Conv].split(nc, factor=warp_row_tiles)
block_i, nc = s[Conv].split(nc, factor=block_row_warps)
oc, oci = s[Conv].split(oc, factor=warp_col_tiles)
block_j, oc = s[Conv].split(oc, factor=block_col_warps)
s[Conv].reorder(block_k, block_i, block_j, nc, oc, nci, oci, nnc, ooc)
s[Conv].bind(block_i, block_x)
s[Conv].bind(block_j, block_y)
s[Conv].bind(nc, thread_y)
s[Conv].bind(oc, thread_z)
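
# To summarize the bindings above: blockIdx.z covers the fused height * width,
# blockIdx.x and blockIdx.y tile the batch and output-channel blocks, and
# threadIdx.y / threadIdx.z index the 2x4 warps within each block.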

# Schedule local computation
s[ConvF].compute_at(s[Conv], oc)
n, h, w, o, nnf, oof = ConvF.op.axis
ko, ki = s[ConvF].split(ic, factor=chunk)
s[ConvF].reorder(ko, kh, ki, kw, n, o, nnf, oof, ii)

# Move intermediate computation into each output compute tile
s[AF].compute_at(s[ConvF], kw)
s[WF].compute_at(s[ConvF], kw)
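
# AF and WF are attached at kw, so fragments are refilled on every kw step,
# while the shared-memory stages below are attached at the outer kh loop and
# are therefore reused across the inner ki and kw iterations.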

# Schedule for A's shared memory
s[AS].compute_at(s[ConvF], kh)
n, h, w, i, nn, ii = AS.op.axis
tx, xo = s[AS].split(n, nparts=block_row_warps)
ty, yo = s[AS].split(xo, nparts=block_col_warps)
t = s[AS].fuse(nn, ii)
to, ti = s[AS].split(t, factor=warp_size)
s[AS].bind(tx, thread_y)
s[AS].bind(ty, thread_z)
s[AS].bind(ti, thread_x)

# Schedule for W's shared memory
s[WS].compute_at(s[ConvF], kh)
kh, kw, ic, o, ii, oo = WS.op.axis
tx, xo = s[WS].split(o, nparts=block_row_warps)
ty, yo = s[WS].split(xo, nparts=block_col_warps)
t = s[WS].fuse(ii, oo)
to, ti = s[WS].split(t, nparts=warp_size)
s[WS].bind(tx, thread_y)
s[WS].bind(ty, thread_z)
s[WS].bind(to, thread_x)
s[WS].vectorize(ti)
print(tvm.lower(s, [A, W, Conv], simple_mode=True))

###############################################################################
# Lowering Computation to Intrinsics
# ----------------------------------
# The last phase is to lower the computation loops down to TensorCore hardware intrinsics
# by mapping the 2D convolution to the tensor intrinsics.
#

s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_a'))
s[WF].tensorize(WF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_b'))
s[Conv].tensorize(nnc, intrin_wmma_store_matrix())
s[ConvF].tensorize(nnf, intrin_wmma_gemm())
print(tvm.lower(s, [A, W, Conv], simple_mode=True))

###############################################################################
# Generate CUDA Kernel
# --------------------
# Finally, we use TVM to generate and compile the CUDA kernel and evaluate the latency of the convolution.
# Since TensorCores are only supported in NVIDIA GPUs with Compute Capability 7.0 or higher, it may not
# be able to run on our build server.

ctx = tvm.gpu(0)
if nvcc.have_tensorcore(ctx.compute_version):
    func = tvm.build(s, [A, W, Conv], 'cuda')
    a_np = np.random.uniform(size=data_shape).astype(A.dtype)
    w_np = np.random.uniform(size=kernel_shape).astype(W.dtype)
    a = tvm.nd.array(a_np, ctx)
    w = tvm.nd.array(w_np, ctx)
    c = tvm.nd.array(np.zeros(output_shape, dtype=Conv.dtype), ctx)
    evaluator = func.time_evaluator(func.entry_name, ctx, number=10)
    print('conv2d with tensor core: %f ms' % (evaluator(a, w, c).mean * 1e3))
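    # To verify correctness, one could (as the unit test above does with
    # VERIFY enabled) compare against a reference convolution after unpacking
    # the NHWCnc result back to plain NHWC; a sketch, assuming topi is available:
    #
    #   from topi.testing import conv2d_nhwc_python
    #   a_nhwc = a_np.transpose(0, 4, 1, 2, 3, 5).reshape(batch_size, height, width, in_channels)
    #   w_hwio = w_np.transpose(0, 1, 2, 4, 3, 5).reshape(kernel_h, kernel_w, in_channels, out_channels)
    #   ref = conv2d_nhwc_python(a_nhwc, w_hwio, stride_h, pad_h)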

###############################################################################
# Summary
# -------
# This tutorial demonstrates how TVM scheduling primitives can be used to
# call TensorCores on specific GPUs.
