# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""
.. _opt-conv-tensorcore:

How to optimize convolution using TensorCores
==============================================
**Author**: `Siyuan Feng <https://github.com/Hzfengsy>`_

In this tutorial, we will demonstrate how to write a high-performance convolution
schedule using TensorCores in TVM. In this example, we assume the input to the
convolution has a large batch. We strongly recommend covering the :ref:`opt-conv-gpu` tutorial first.

"""

################################################################
# TensorCore Introduction
# -------------------------
# Each Tensor Core provides a 4x4x4 matrix processing array that performs
# :code:`D = A * B + C`, where A, B, C and D are 4x4 matrices. The matrix
# multiplication inputs A and B are FP16 matrices, while the accumulation
# matrices C and D may be FP16 or FP32 matrices.
#
# However, CUDA programmers can only use the warp-level primitive
# :code:`wmma::mma_sync(acc_frag, a_frag, b_frag, acc_frag)` to perform
# 16x16x16 half-precision matrix multiplication on Tensor Cores. Before invoking
# the matrix multiplication, programmers must explicitly load data from memory into
# registers with the primitive :code:`wmma::load_matrix_sync`. The NVCC compiler translates
# that primitive into multiple memory load instructions. At run time, every thread loads
# 16 elements from matrix A and 16 elements from B.

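# As a small illustration of the mixed-precision semantics described above, we can
# emulate one 16x16x16 WMMA operation in NumPy: FP16 inputs, FP32 accumulation.
# This is a sketch for intuition only, not the hardware API; the ``*_demo`` names
# are used only in this snippet.
import numpy as np

a_demo = np.random.uniform(size=(16, 16)).astype("float16")
b_demo = np.random.uniform(size=(16, 16)).astype("float16")
c_demo = np.random.uniform(size=(16, 16)).astype("float32")
d_demo = a_demo.astype("float32") @ b_demo.astype("float32") + c_demo  # D = A * B + C
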
################################################################
# Preparation and Algorithm
# --------------------------
# We use fixed-size input tensors with 256 channels and 14 x 14 dimensions.
# The batch size is 256. The convolution filters contain 512 filters of size 3 x 3.
# We use stride size 1 and padding size 1 for the convolution. In this example, we use
# the NHWCnc memory layout. The following code defines the convolution algorithm in TVM.

import tvm
import numpy as np
from tvm.contrib import nvcc

# The sizes of inputs and filters
batch_size = 256
height = 14
width = 14
in_channels = 256
out_channels = 512
kernel_h = 3
kernel_w = 3
pad_h = 1
pad_w = 1
stride_h = 1
stride_w = 1

# TensorCore shape
block_size = 16

assert (batch_size % block_size == 0)
assert (in_channels % block_size == 0)
assert (out_channels % block_size == 0)

# Input feature map: (N, H, W, IC, n, ic)
data_shape = (batch_size // block_size,
              height,
              width,
              in_channels // block_size,
              block_size,
              block_size)
# Kernel: (H, W, IC, OC, ic, oc)
kernel_shape = (kernel_h,
                kernel_w,
                in_channels // block_size,
                out_channels // block_size,
                block_size,
                block_size)
# Output feature map: (N, H, W, OC, n, oc)
output_shape = (batch_size // block_size,
                height,
                width,
                out_channels // block_size,
                block_size,
                block_size)

# Reduction axes
kh = tvm.reduce_axis((0, kernel_h), name='kh')
kw = tvm.reduce_axis((0, kernel_w), name='kw')
ic = tvm.reduce_axis((0, in_channels // block_size), name='ic')
ii = tvm.reduce_axis((0, block_size), name='ii')

# Algorithm
A = tvm.placeholder(data_shape, name='A', dtype="float16")
W = tvm.placeholder(kernel_shape, name='W', dtype="float16")
Apad = tvm.compute(
    (batch_size // block_size, height + 2 * pad_h, width + 2 * pad_w, in_channels // block_size, block_size,
     block_size),
    lambda n, h, w, i, nn, ii: tvm.if_then_else(
        tvm.all(h >= pad_h, h - pad_h < height,
                w >= pad_w, w - pad_w < width),
        A[n, h - pad_h, w - pad_w, i, nn, ii], tvm.const(0., "float16")),
    name='Apad')
Conv = tvm.compute(output_shape,
                   lambda n, h, w, o, nn, oo: tvm.sum(
                       Apad[n, h * stride_h + kh, w * stride_w + kw, ic, nn, ii].astype("float32") *
                       W[kh, kw, ic, o, ii, oo].astype("float32"),
                       axis=[ic, kh, kw, ii]),
                   name="Conv")

s = tvm.create_schedule(Conv.op)
s[Apad].compute_inline()

###############################################################################
# Memory Scope
# ----------------
#
# In a traditional GPU schedule, we have global, shared and local memory scopes.
# To support TensorCores, we add three more special memory scopes: :code:`wmma.matrix_a`,
# :code:`wmma.matrix_b` and :code:`wmma.accumulator`. On hardware, all fragment scopes
# are stored at the on-chip register level, the same place as local memory.

# Designate the memory hierarchy
AS = s.cache_read(Apad, 'shared', [Conv])
WS = s.cache_read(W, 'shared', [Conv])
AF = s.cache_read(AS, 'wmma.matrix_a', [Conv])
WF = s.cache_read(WS, 'wmma.matrix_b', [Conv])
ConvF = s.cache_write(Conv, 'wmma.accumulator')

###############################################################################
# Define Tensor Intrinsic
# -----------------------
# In fact, TensorCore is a special hardware operation. So, we can just use tensorize
# to replace a unit of computation with the TensorCore instruction. The first step is
# to define the tensor intrinsics.
#
# There are four basic operations in TensorCore: :code:`fill_fragment`, :code:`load_matrix`,
# :code:`mma_sync` and :code:`store_matrix`. Since :code:`fill_fragment` and :code:`mma_sync`
# are both used in matrix multiplication, we can just write the following three intrinsics.

def intrin_wmma_load_matrix(scope):
    n = 16
    A = tvm.placeholder((n, n), name='A', dtype='float16')
    BA = tvm.decl_buffer(A.shape, A.dtype, scope='shared', data_alignment=32, offset_factor=256)
    C = tvm.compute((n, n), lambda i, j: A[i, j], name='C')
    BC = tvm.decl_buffer(C.shape, C.dtype, scope=scope, data_alignment=32, offset_factor=256)
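    # Note: a WMMA fragment holds 16x16 = 256 elements, so offset_factor=256 above
    # keeps elem_offset a multiple of 256, and elem_offset // 256 below is the
    # fragment index passed to the intrinsic.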

    def intrin_func(ins, outs):
        ib = tvm.ir_builder.create()

        BA = ins[0]
        BC = outs[0]
        ib.emit(tvm.call_intrin('handle', 'tvm_load_matrix_sync',
                                BC.data, n, n, n, BC.elem_offset // 256,
                                BA.access_ptr('r'), n, 'row_major'))
        return ib.get()

    return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})


def intrin_wmma_gemm():
    n = 16
    A = tvm.placeholder((n, n), name='A', dtype='float16')
    B = tvm.placeholder((n, n), name='B', dtype='float16')
    k = tvm.reduce_axis((0, n), name="k")
    C = tvm.compute((n, n),
                    lambda ii, jj:
                    tvm.sum(A[ii, k].astype('float') * B[k, jj].astype('float'), axis=k),
                    name='C')
    BA = tvm.decl_buffer(A.shape, A.dtype, name='BA', scope='wmma.matrix_a', data_alignment=32, offset_factor=256)
    BB = tvm.decl_buffer(B.shape, B.dtype, name='BB', scope='wmma.matrix_b', data_alignment=32, offset_factor=256)
    BC = tvm.decl_buffer(C.shape, C.dtype, name='BC', scope='wmma.accumulator', data_alignment=32, offset_factor=256)

    def intrin_func(ins, outs):
        BA, BB = ins
        BC, = outs

        def init():
            ib = tvm.ir_builder.create()
            ib.emit(tvm.call_intrin('handle', 'tvm_fill_fragment', BC.data, n, n, n, BC.elem_offset // 256, 0.0))
            return ib.get()

        def update():
            ib = tvm.ir_builder.create()
            ib.emit(tvm.call_intrin('handle', 'tvm_mma_sync',
                                    BC.data, BC.elem_offset // 256,
                                    BA.data, BA.elem_offset // 256,
                                    BB.data, BB.elem_offset // 256,
                                    BC.data, BC.elem_offset // 256))
            return ib.get()

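        # tensorize expects (body, reduce_init, reduce_update): fill_fragment
        # initializes the accumulator, while mma_sync performs the update.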
        return update(), init(), update()

    return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, B: BB, C: BC})


def intrin_wmma_store_matrix():
    n = 16
    A = tvm.placeholder((n, n), name='A', dtype='float32')
    BA = tvm.decl_buffer(A.shape, A.dtype, scope='wmma.accumulator', data_alignment=32, offset_factor=256)
    C = tvm.compute((n, n), lambda i, j: A[i, j], name='C')
    BC = tvm.decl_buffer(C.shape, C.dtype, scope='global', data_alignment=32, offset_factor=256)

    def intrin_func(ins, outs):
        ib = tvm.ir_builder.create()
        BA = ins[0]
        BC = outs[0]
        ib.emit(tvm.call_intrin('handle', 'tvm_store_matrix_sync',
                                BA.data, n, n, n, BA.elem_offset // 256,
                                BC.access_ptr('w'), n, 'row_major'))
        return ib.get()

    return tvm.decl_tensor_intrin(C.op, intrin_func, binds={A: BA, C: BC})

###############################################################################
# Scheduling the Computation
# --------------------------
# To use TensorCores in TVM, we must schedule the computation into a specific structure
# to match the tensor intrinsic. As in traditional GPU programs, we can also use shared
# memory to boost the speed. If you have any questions about blocking and shared
# memory, please refer to :ref:`opt-conv-gpu`.
#
# In this example, each block contains 2x4 warps, and each warp calls 4x2 TensorCore
# instructions. Thus, the output shape of each warp is 64x32 and each block outputs
# 128x128 tiles. Due to the limited shared memory space, we only load 2 blocks (2x128x128 tiles)
# at a time.
#
# .. note::
#
#   *Warp-level Operation*
#
#   Note that all TensorCore instructions are warp-level instructions, which means all 32 threads
#   in a warp should execute the instruction simultaneously. Making threadIdx.x extent=32 is one of the
#   easiest ways to achieve this. Then we can bind threadIdx.x to any loops except those that contain
#   TensorCore intrinsics directly or indirectly. Also note that this is not the only solution.
#   The only thing we must do is to make sure all threads in a warp can call TensorCore at the same time.
#


# Define tiling sizes
block_row_warps = 2
block_col_warps = 4
warp_row_tiles = 4
warp_col_tiles = 2
warp_size = 32
chunk = 2
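
# Sanity-check the tiling arithmetic described above: each warp computes a
# 64x32 output tile and each block computes a 128x128 output tile.
assert warp_row_tiles * block_size == 64
assert warp_col_tiles * block_size == 32
assert block_row_warps * warp_row_tiles * block_size == 128
assert block_col_warps * warp_col_tiles * block_size == 128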

block_x = tvm.thread_axis('blockIdx.x')
block_y = tvm.thread_axis('blockIdx.y')
block_z = tvm.thread_axis('blockIdx.z')
thread_x = tvm.thread_axis('threadIdx.x')
thread_y = tvm.thread_axis('threadIdx.y')
thread_z = tvm.thread_axis('threadIdx.z')

nc, hc, wc, oc, nnc, ooc = Conv.op.axis
block_k = s[Conv].fuse(hc, wc)
s[Conv].bind(block_k, block_z)
nc, nci = s[Conv].split(nc, factor=warp_row_tiles)
block_i, nc = s[Conv].split(nc, factor=block_row_warps)
oc, oci = s[Conv].split(oc, factor=warp_col_tiles)
block_j, oc = s[Conv].split(oc, factor=block_col_warps)
s[Conv].reorder(block_k, block_i, block_j, nc, oc, nci, oci, nnc, ooc)
s[Conv].bind(block_i, block_x)
s[Conv].bind(block_j, block_y)
s[Conv].bind(nc, thread_y)
s[Conv].bind(oc, thread_z)

# Schedule local computation
s[ConvF].compute_at(s[Conv], oc)
n, h, w, o, nnf, oof = ConvF.op.axis
ko, ki = s[ConvF].split(ic, factor=chunk)
s[ConvF].reorder(ko, kh, ki, kw, n, o, nnf, oof, ii)

# Move intermediate computation into each output compute tile
s[AF].compute_at(s[ConvF], kw)
s[WF].compute_at(s[ConvF], kw)

# Schedule for A's shared memory
s[AS].compute_at(s[ConvF], kh)
n, h, w, i, nn, ii = AS.op.axis
tx, xo = s[AS].split(n, nparts=block_row_warps)
ty, yo = s[AS].split(xo, nparts=block_col_warps)
t = s[AS].fuse(nn, ii)
to, ti = s[AS].split(t, factor=warp_size)
s[AS].bind(tx, thread_y)
s[AS].bind(ty, thread_z)
s[AS].bind(ti, thread_x)

# Schedule for W's shared memory
s[WS].compute_at(s[ConvF], kh)
kh, kw, ic, o, ii, oo = WS.op.axis
tx, xo = s[WS].split(o, nparts=block_row_warps)
ty, yo = s[WS].split(xo, nparts=block_col_warps)
t = s[WS].fuse(ii, oo)
to, ti = s[WS].split(t, nparts=warp_size)
s[WS].bind(tx, thread_y)
s[WS].bind(ty, thread_z)
s[WS].bind(to, thread_x)
s[WS].vectorize(ti)
print(tvm.lower(s, [A, W, Conv], simple_mode=True))

###############################################################################
# Lowering Computation to Intrinsics
# ----------------------------------
# The last phase is to lower the computation loops down to TensorCore hardware intrinsics
# by mapping the 2D convolution to the tensor intrinsics.
#

s[AF].tensorize(AF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_a'))
s[WF].tensorize(WF.op.axis[-2], intrin_wmma_load_matrix('wmma.matrix_b'))
s[Conv].tensorize(nnc, intrin_wmma_store_matrix())
s[ConvF].tensorize(nnf, intrin_wmma_gemm())
print(tvm.lower(s, [A, W, Conv], simple_mode=True))

###############################################################################
# Generate CUDA Kernel
# --------------------
# Finally we use TVM to generate and compile the CUDA kernel, and evaluate the latency of convolution.
# Since TensorCores are only supported by NVIDIA GPUs with Compute Capability 7.0 or higher,
# the kernel may not run on our build server.

ctx = tvm.gpu(0)
if nvcc.have_tensorcore(ctx.compute_version):
    func = tvm.build(s, [A, W, Conv], 'cuda')
    a_np = np.random.uniform(size=data_shape).astype(A.dtype)
    w_np = np.random.uniform(size=kernel_shape).astype(W.dtype)
    a = tvm.nd.array(a_np, ctx)
    w = tvm.nd.array(w_np, ctx)
    c = tvm.nd.array(np.zeros(output_shape, dtype=Conv.dtype), ctx)
    evaluator = func.time_evaluator(func.entry_name, ctx, number=10)
    print('conv2d with tensor core: %f ms' % (evaluator(a, w, c).mean * 1e3))
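
    # A sketch of an optional correctness check (not part of the original tutorial):
    # compare the GPU result against a naive NumPy reference, assuming stride 1 as in
    # this example. All names below are local to this check.
    a_flat = a_np.transpose(0, 4, 1, 2, 3, 5).reshape(batch_size, height, width, in_channels)
    w_flat = w_np.transpose(0, 1, 2, 4, 3, 5).reshape(kernel_h, kernel_w, in_channels, out_channels)
    a_pad_np = np.pad(a_flat, ((0, 0), (pad_h, pad_h), (pad_w, pad_w), (0, 0)),
                      mode='constant').astype("float32")
    ref = np.zeros((batch_size, height, width, out_channels), dtype="float32")
    for r in range(kernel_h):
        for q in range(kernel_w):
            # Accumulate the contribution of kernel position (r, q) for all pixels at once
            ref += np.tensordot(a_pad_np[:, r:r + height, q:q + width, :],
                                w_flat[r, q].astype("float32"), axes=([3], [0]))
    # Repack the NHWC reference into the NHWCnc output layout used by Conv
    ref = ref.reshape(batch_size // block_size, block_size, height, width,
                      out_channels // block_size, block_size).transpose(0, 2, 3, 4, 1, 5)
    np.testing.assert_allclose(c.asnumpy(), ref, rtol=1e-2, atol=1e-2)
    print('verification against NumPy reference passed')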

###############################################################################
# Summary
# -------
# This tutorial demonstrates how TVM scheduling primitives can be used to
# call TensorCores on specific GPUs.