Commit 400baf2

refactor optimize GEMM on CPU tutorial (#8825)
* refactor optimize GEMM on CPU tutorial
* fix lint errors
* fix more lint errors
* fix typo
* fix problem with redefinition of `k`; add TODO and comments around loop unrolling; clarify note on the array packing figure
* reword general description of array packing
* grab kaxis from compute definition
* remove duplicate comments on unrolling
1 parent 6df070a commit 400baf2

File tree

1 file changed: +72 −61 lines changed

tutorials/optimize/opt_gemm.py

Lines changed: 72 additions & 61 deletions
@@ -101,7 +101,7 @@
 k = te.reduce_axis((0, K), "k")
 A = te.placeholder((M, K), name="A")
 B = te.placeholder((K, N), name="B")
-C = te.compute((M, N), lambda x, y: te.sum(A[x, k] * B[k, y], axis=k), name="C")
+C = te.compute((M, N), lambda m, n: te.sum(A[m, k] * B[k, n], axis=k), name="C")
 
 # Default schedule
 s = te.create_schedule(C.op)
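
A plain-NumPy reference for the contraction this compute definition describes (a sketch only; M = N = K = 1024 and float32 inputs are assumed from the tutorial):

import numpy as np

M = N = K = 1024
A = np.random.rand(M, K).astype("float32")
B = np.random.rand(K, N).astype("float32")

# C[m, n] = sum over k of A[m, k] * B[k, n] -- the same contraction the te.compute above expresses
C = np.einsum("mk,kn->mn", A, B)
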
@@ -130,15 +130,16 @@
 # fill 32 * 32 * sizeof(float) which is 4KB in the cache whose total size is 32KB (L1 data cache)
 
 bn = 32
+kfactor = 4
 s = te.create_schedule(C.op)
 
 # Blocking by loop tiling
-xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
-(k,) = s[C].op.reduce_axis
-ko, ki = s[C].split(k, factor=4)
+mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
+(kaxis,) = s[C].op.reduce_axis
+ko, ki = s[C].split(kaxis, factor=kfactor)
 
 # Hoist reduction domain outside the blocking loop
-s[C].reorder(xo, yo, ko, ki, xi, yi)
+s[C].reorder(mo, no, ko, ki, mi, ni)
 
 func = tvm.build(s, [A, B, C], target=target, name="mmult")
 assert func
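
A minimal plain-Python sketch of the loop nest this blocked schedule asks for (illustrative only; the loop variable names mirror the schedule, and M = N = K = 1024 is assumed from the tutorial):

import numpy as np

M = N = K = 1024
bn, kfactor = 32, 4
A = np.random.rand(M, K).astype("float32")
B = np.random.rand(K, N).astype("float32")
C = np.zeros((M, N), dtype="float32")

# Loop order mirrors s[C].reorder(mo, no, ko, ki, mi, ni):
# outer 32x32 tiles of C, then the split reduction axis, then the inner tile.
for mo in range(M // bn):
    for no in range(N // bn):
        for ko in range(K // kfactor):
            for ki in range(kfactor):
                k = ko * kfactor + ki
                for mi in range(bn):
                    m = mo * bn + mi
                    # the innermost ni loop is written here as one vector update over the bn-wide tile
                    C[m, no * bn : (no + 1) * bn] += A[m, k] * B[k, no * bn : (no + 1) * bn]

np.testing.assert_allclose(C, A @ B, rtol=1e-3)
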
@@ -162,19 +163,20 @@
 # -------------
 # Another important trick is vectorization. When the memory access pattern is uniform,
 # the compiler can detect this pattern and pass the continuous memory to vector processor. In TVM,
-# we can use `vectorize` interface to hint the compiler this pattern, so that we can accelerate it vastly.
+# we can use `vectorize` interface to hint the compiler this pattern, so that we can accelerate it
+# vastly.
 #
 # In this tutorial, we chose to vectorize the inner loop row data since it is cache friendly.
 
 s = te.create_schedule(C.op)
-xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
-(k,) = s[C].op.reduce_axis
-ko, ki = s[C].split(k, factor=4)
+mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
+(kaxis,) = s[C].op.reduce_axis
+ko, ki = s[C].split(kaxis, factor=kfactor)
 
-s[C].reorder(xo, yo, ko, ki, xi, yi)
+s[C].reorder(mo, no, ko, ki, mi, ni)
 
 # Vectorization
-s[C].vectorize(yi)
+s[C].vectorize(ni)
 
 func = tvm.build(s, [A, B, C], target=target, name="mmult")
 assert func
@@ -194,20 +196,19 @@
 ###################################################################################################
 # Loop Permutation
 # ----------------
-# If we look at the above IR, we can see the inner loop row data is vectorized and
-# B is transformed into PackedB. The traversal of PackedB is sequential now.
-# So we will look at the access pattern of A. In current schedule, A is accessed column by column
-# which is not cache friendly. If we change the nested loop order of ki and inner axes xi,
+# If we look at the above IR, we can see the inner loop row data is vectorized for both B and C.
+# Next we will look at the access pattern of A. In the current schedule, A is accessed column by column
+# which is not cache friendly. If we change the nested loop order of ki and inner axes mi,
 # the access pattern for A matrix is more cache friendly.
 
 s = te.create_schedule(C.op)
-xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
-(k,) = s[C].op.reduce_axis
-ko, ki = s[C].split(k, factor=4)
+mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
+(kaxis,) = s[C].op.reduce_axis
+ko, ki = s[C].split(kaxis, factor=kfactor)
 
 # re-ordering
-s[C].reorder(xo, yo, ko, xi, ki, yi)
-s[C].vectorize(yi)
+s[C].reorder(mo, no, ko, mi, ki, ni)
+s[C].vectorize(ni)
 
 func = tvm.build(s, [A, B, C], target=target, name="mmult")
 assert func
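
The cache-friendliness argument for swapping ki and mi can be seen from the flattened row-major offsets into A that the two innermost loops touching A generate (a small sketch; the other axes are held at 0):

K = 1024
kfactor = 4

# A[m, k] lives at flattened offset m * K + k.
before = [mi * K + ki for ki in range(kfactor) for mi in range(4)]  # ki outside mi, as before the reorder
after = [mi * K + ki for mi in range(4) for ki in range(kfactor)]   # mi outside ki, as after the reorder

print(before[:8])  # [0, 1024, 2048, 3072, 1, 1025, 2049, 3073] -> strides of K, cache unfriendly
print(after[:8])   # [0, 1, 2, 3, 1024, 1025, 1026, 1027]       -> short sequential runs
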
@@ -227,43 +228,48 @@
 ###################################################################################################
 # Array Packing
 # -------------
-# Another important trick is array packing. This trick is to reorder the storage dimension of the
-# array to convert the continuous access pattern on certain dimension to a sequential pattern after
-# flattening.
+# Another important trick is array packing. The trick is to reorder the storage of a multi-
+# dimensional array so that it is accessed sequentially after it is flattened and stored in one-
+# dimensional memory.
 #
 # .. image:: https://github.com/dmlc/web-data/raw/main/tvm/tutorial/array-packing.png
 #    :align: center
 #
+# NOTE: This figure is a general illustration of how array packing works.
 
 
 ###################################################################################################
-# Just as it is shown in the figure above, after blocking the computations, we can observe the array
-# access pattern of B (after flattening), which is regular but discontinuous. We expect that after
-# some transformation we can get continuous access pattern. We can reorder a [16][16] array to
-# a [16/4][16][4] array, so that the access pattern of B will be sequential when grabing
-# the corresponding value from the packed array.
-#
+# We can use array packing to address the access pattern for B. Observe the array access pattern of
+# B after flattening which is not sequential as we iterate over the K dimension. We can reorder B
+# with dimensions [K][N] so that it has dimensions [N/bn][K][bn] where bn is the blocking factor and
+# also the vector size for B in the inner loop. This reorder splits N into two dimensions ---
+# bigN (N/bn) and littleN (bn) --- and the new dimensions [N/bn][K][bn] match the indexing of B
+# from outer to inner loops (no, ko, ki, ni) resulting in a sequential access pattern for B after
+# flattening.
+
 
 # We have to re-write the algorithm slightly.
-packedB = te.compute((N / bn, K, bn), lambda x, y, z: B[y, x * bn + z], name="packedB")
+packedB = te.compute(
+    (N / bn, K, bn), lambda bigN, k, littleN: B[k, bigN * bn + littleN], name="packedB"
+)
 C = te.compute(
     (M, N),
-    lambda x, y: te.sum(A[x, k] * packedB[y // bn, k, tvm.tir.indexmod(y, bn)], axis=k),
+    lambda m, n: te.sum(A[m, k] * packedB[n // bn, k, tvm.tir.indexmod(n, bn)], axis=k),
     name="C",
 )
 
 s = te.create_schedule(C.op)
 
-xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
-(k,) = s[C].op.reduce_axis
-ko, ki = s[C].split(k, factor=4)
+mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
+(kaxis,) = s[C].op.reduce_axis
+ko, ki = s[C].split(kaxis, factor=kfactor)
 
-s[C].reorder(xo, yo, ko, xi, ki, yi)
-s[C].vectorize(yi)
+s[C].reorder(mo, no, ko, mi, ki, ni)
+s[C].vectorize(ni)
 
-x, y, z = s[packedB].op.axis
-s[packedB].vectorize(z)
-s[packedB].parallel(x)
+bigN, _, littleN = s[packedB].op.axis
+s[packedB].vectorize(littleN)
+s[packedB].parallel(bigN)
 
 func = tvm.build(s, [A, B, C], target=target, name="mmult")
 assert func
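
The layout that the new packedB compute describes can be checked with a short NumPy sketch (illustrative only; bn = 32 and the 1024-wide shapes are assumptions matching the tutorial):

import numpy as np

K = N = 1024
bn = 32
B = np.random.rand(K, N).astype("float32")

# packedB[bigN, k, littleN] = B[k, bigN * bn + littleN], i.e. B stored as [N/bn][K][bn]
packedB = np.ascontiguousarray(B.reshape(K, N // bn, bn).transpose(1, 0, 2))

# Grabbing B[k, n] from the packed array uses n // bn and n % bn, matching
# packedB[n // bn, k, tvm.tir.indexmod(n, bn)] in the compute definition above.
k, n = 123, 456
assert packedB[n // bn, k, n % bn] == B[k, n]

# Because the loop nest indexes packedB as (no, ko/ki, ni) from outer to inner,
# it walks this contiguous buffer sequentially after flattening.
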
@@ -293,23 +299,28 @@
 # Allocate write cache
 CC = s.cache_write(C, "global")
 
-xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
+mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
 
-# Write cache is computed at yo
-s[CC].compute_at(s[C], yo)
+# Write cache is computed at no
+s[CC].compute_at(s[C], no)
 
 # New inner axes
-xc, yc = s[CC].op.axis
+mc, nc = s[CC].op.axis
+
+(kaxis,) = s[CC].op.reduce_axis
+ko, ki = s[CC].split(kaxis, factor=kfactor)
+s[CC].reorder(ko, mc, ki, nc)
+s[CC].vectorize(nc)
 
-(k,) = s[CC].op.reduce_axis
-ko, ki = s[CC].split(k, factor=4)
-s[CC].reorder(ko, xc, ki, yc)
+# TODO: Add separate optimization step to discuss loop unrolling
+# unrolling is a loop optimization strategy which can reduce branch
+# prediction failures and increase the chance of concurrent execution
+# unroll kfactor loops
 s[CC].unroll(ki)
-s[CC].vectorize(yc)
 
-x, y, z = s[packedB].op.axis
-s[packedB].vectorize(z)
-s[packedB].parallel(x)
+bigN, _, littleN = s[packedB].op.axis
+s[packedB].vectorize(littleN)
+s[packedB].parallel(bigN)
 
 func = tvm.build(s, [A, B, C], target=target, name="mmult")
 assert func
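
As a rough, hand-written illustration of what s[CC].unroll(ki) asks the code generator to do (not TVM output; plain Python is used only to show the shape of the transformation):

# Before unrolling: an explicit ki loop of length kfactor (= 4) with a loop-back branch.
def dot_block(acc, a_row, b_col, ko, kfactor=4):
    for ki in range(kfactor):
        acc += a_row[ko * kfactor + ki] * b_col[ko * kfactor + ki]
    return acc

# After unrolling: the body is repeated kfactor times with constant offsets, removing
# the loop-back branch that the TODO comment above refers to.
def dot_block_unrolled(acc, a_row, b_col, ko, kfactor=4):
    base = ko * kfactor
    acc += a_row[base + 0] * b_col[base + 0]
    acc += a_row[base + 1] * b_col[base + 1]
    acc += a_row[base + 2] * b_col[base + 2]
    acc += a_row[base + 3] * b_col[base + 3]
    return acc
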
@@ -335,24 +346,24 @@
 
 CC = s.cache_write(C, "global")
 
-xo, yo, xi, yi = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
+mo, no, mi, ni = s[C].tile(C.op.axis[0], C.op.axis[1], bn, bn)
 
-s[CC].compute_at(s[C], yo)
+s[CC].compute_at(s[C], no)
 
-xc, yc = s[CC].op.axis
+mc, nc = s[CC].op.axis
 
-(k,) = s[CC].op.reduce_axis
-ko, ki = s[CC].split(k, factor=4)
-s[CC].reorder(ko, xc, ki, yc)
+(kaxis,) = s[CC].op.reduce_axis
+ko, ki = s[CC].split(kaxis, factor=kfactor)
+s[CC].reorder(ko, mc, ki, nc)
+s[CC].vectorize(nc)
 s[CC].unroll(ki)
-s[CC].vectorize(yc)
 
 # parallel
-s[C].parallel(xo)
+s[C].parallel(mo)
 
-x, y, z = s[packedB].op.axis
-s[packedB].vectorize(z)
-s[packedB].parallel(x)
+bigN, _, littleN = s[packedB].op.axis
+s[packedB].vectorize(littleN)
+s[packedB].parallel(bigN)
 
 func = tvm.build(s, [A, B, C], target=target, name="mmult")
 assert func
