Skip to content

Commit

Permalink
Add the acc16 intrinsic support (apache#3081)
Browse files Browse the repository at this point in the history
  • Loading branch information
lly-zero-one authored and wweic committed Jun 27, 2019
1 parent c3c8b23 commit c391a3f
Show file tree
Hide file tree
Showing 2 changed files with 169 additions and 0 deletions.
90 changes: 90 additions & 0 deletions tests/python/contrib/test_gemm_acc16.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=import-self, invalid-name, unused-argument, too-many-lines, len-as-condition
import tvm
import numpy as np
from topi.x86.tensor_intrin import dot_16x1x16_int8_int8_int16


def benchmark_fc_int8_acc16():
m = 128
n = 128
k = 128

X = tvm.placeholder((m, k), name='X', dtype="uint8")
W = tvm.placeholder((n, k), name='W', dtype="int8")

peak = 512/16*2*2*2
gops_per_mm = 2*n*m*k
print("Peak {} Gops/s \n".format(peak))

def verify(target="llvm -mcpu=skylake-avx512"):
if not tvm.module.enabled(target):
print("skip because %s is not enabled..." % target)
return

ctx = tvm.context(target, 0)
X = tvm.placeholder((m, k), name='X', dtype="uint8")
W = tvm.placeholder((n, k), name='W', dtype="int8")
pc = dot_16x1x16_int8_int8_int16()
ak = tvm.reduce_axis((0, k), name='k')

packedW = tvm.placeholder((n/128, 128*(k/2), 2), name='packedW', dtype="int8")
t_fc = tvm.compute((m, n), lambda i, j: tvm.sum(X[i, ak].astype("int16") * packedW[j/128, (ak/2)*128+j%128, ak%2].astype("int16"), axis=ak), name="F")

t_sch = tvm.create_schedule(t_fc.op)
a_x, a_y = t_fc.op.axis
a_k, = t_fc.op.reduce_axis

a_yo, a_yi = t_sch[t_fc].split(a_y, factor=128)
a_ko, a_ki = t_sch[t_fc].split(a_k, factor=2)

a_xo, a_xi = t_sch[t_fc].split(a_x, factor=128)
a_koo, a_koi = t_sch[t_fc].split(a_ko, factor=32)
t_sch[t_fc].reorder(a_yo, a_xo, a_koo, a_xi, a_koi, a_yi, a_ki)

t_sch[t_fc].tensorize(a_yi, pc)
# print(tvm.lower(t_sch, [X, packedW, t_fc], simple_mode=True))
t_func = tvm.build(t_sch, [X, packedW, t_fc], target, name="intrinsic")
t_evaluator = t_func.time_evaluator(t_func.entry_name, ctx, number=10)

# generate the plain data
a_ = np.random.uniform(1, 10, size=(m, k)).astype("uint8")
b_ = np.random.uniform(1, 10, size=(n, k)).astype("int8")

packW = np.random.uniform(1, 10, size=(n/128, 128*(k/2), 2)).astype("int8")
# This occurs in pre_compute stage
for r_idx in range(n/128):
for s_idx in range(128*(k/2)):
for t_idx in range(2):
packW[r_idx][s_idx][t_idx] = b_[r_idx*128+s_idx%128][s_idx/128*2+t_idx]

x = tvm.nd.array(a_, ctx)
w = tvm.nd.array(packW, ctx)
y = tvm.nd.array(np.zeros((m, n), dtype="int16"), ctx)

result = t_evaluator(x, w, y)
gops_per_sec = gops_per_mm/result.mean/1e9
tvm.testing.assert_allclose(
y.asnumpy(), np.dot(a_, b_.T), rtol=1e-5)
print('Tensorization: running time: {:.3f} ms, {:.2f} Gops/s, effiency: {:.2f}.'.format(result.mean*1000, gops_per_sec, gops_per_sec/peak))
t_func.export_library("gemm_tensorize.o")

verify()

if __name__ == "__main__":
benchmark_fc_int8_acc16()
79 changes: 79 additions & 0 deletions topi/python/topi/x86/tensor_intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,82 @@ def _instr(index):

with tvm.build_config(offset_factor=1, partition_const_loop=True):
return tvm.decl_tensor_intrin(C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer})


def dot_16x1x16_int8_int8_int16():
"""
Int8 dot product by every 2 elements using AVX2 Skylake instructions.
This function takes two arrays of int8 datatype -- data[2] and
kernel[4][32][2] -- and computes a dot product of data[2] with every
2 elements of kernels, resulting in output[4][32] of int16 datatype.
The pseudo code is as follows.
.. code-block:: c
void dot_16x1x16_int8_int8_int16(int8 data[2], int8 kernel[32*4][2],
int16 output[32*4]){
for (int i = 0; i< 4; i++){
for (int j = 0; j < 32; j++){
out[i][i] = 0;
for (int k = 0; k < 2; k++){
out[i][j][k] += data[k] * kernel[i][j][k]
}
}
}
}
Physically, the kernel array sits in four AVX512 vector registers and
the data[2] is broadcasted to another AVX512 vector register. This
function returns a TensorIntrin that can be used to tensorize
a schedule.
Returns
-------
intrin : TensorIntrin
The Skylake int8 TensorIntrin that can be used in tensorizing schedule
"""

num_int8_elements = 2 # 2 int8 elements in int32
data = tvm.placeholder((num_int8_elements,), dtype='uint8', name='data')
kernel = tvm.placeholder((128, num_int8_elements), dtype='int8', name='kernel')
k = tvm.reduce_axis((0, num_int8_elements), name='k')
C = tvm.compute((128, ),
lambda i: tvm.sum(data[k].astype('int16') *
kernel[i, k].astype('int16'),
axis=k),
name="C")

a_buffer = tvm.decl_buffer(data.shape, dtype='uint8', name="a_buffer",
offset_factor=1,
strides=[1])
b_buffer = tvm.decl_buffer(kernel.shape, dtype='int8', name="b_buffer",
offset_factor=1)
# strides=[tvm.var('ldw'), 1, 1])

def _intrin_func(ins, outs):
def _instr(index):
ib = tvm.ir_builder.create()
if index == 1:
for i in range(4):
ib.emit(outs[0].vstore([i*32], tvm.const(0, 'int16x32')))
return ib.get()

a_int8 = ins[0].vload([0], "uint8x2")
re_int16 = tvm.call_pure_intrin('int16', 'reinterpret', a_int8)
vec_ai16 = re_int16.astype('int16x32')
vec_a = tvm.call_pure_intrin('int8x64', 'reinterpret', vec_ai16)

for i in range(4):
vec_b = ins[1].vload([i*32, 0], "int8x64")
pair_reduction = tvm.call_llvm_intrin('int16x32',
'llvm.x86.avx512.pmaddubs.w.512',
tvm.const(0, 'uint32'),
vec_a, vec_b)
if index == 0:
ib.emit(outs[0].vstore([i*32], pair_reduction))
else:
ib.emit(outs[0].vstore([i*32], pair_reduction + outs[0].vload([i*32],
'int16x32')))
return ib.get()

# body, reset, update
return _instr(0), _instr(1), _instr(2)

with tvm.build_config(offset_factor=1, partition_const_loop=True):
return tvm.decl_tensor_intrin(C.op, _intrin_func, binds={data:a_buffer, kernel:b_buffer})

0 comments on commit c391a3f

Please sign in to comment.