From 8fd794a66a3bea14f67e137c5d8e1d421502cda3 Mon Sep 17 00:00:00 2001
From: Giuseppe Rossini
Date: Tue, 29 Sep 2020 14:50:05 +0100
Subject: [PATCH] Fix black linting and address comments

Change-Id: I857b28b6f9b23307d8c1eebc509de6ad2783c756
---
 python/tvm/topi/arm_cpu/arm_utils.py     |   2 +-
 python/tvm/topi/arm_cpu/conv2d_gemm.py   |  12 +--
 python/tvm/topi/arm_cpu/tensor_intrin.py | 111 +++++++++++++++++------
 3 files changed, 88 insertions(+), 37 deletions(-)

diff --git a/python/tvm/topi/arm_cpu/arm_utils.py b/python/tvm/topi/arm_cpu/arm_utils.py
index fb47b29d6f6a..7e0f566b96f4 100644
--- a/python/tvm/topi/arm_cpu/arm_utils.py
+++ b/python/tvm/topi/arm_cpu/arm_utils.py
@@ -22,7 +22,7 @@


 def get_arch_version(target_mattr):
-    """ Parse the LLVM target -mattr, and return
+    """Parse the LLVM target -mattr, and return
     the architecture version in a decimal representation
     (e.g., if -mattr=v8.4a, return 8.4)
     """
diff --git a/python/tvm/topi/arm_cpu/conv2d_gemm.py b/python/tvm/topi/arm_cpu/conv2d_gemm.py
index 6896fd03289a..b40fb89b5d33 100644
--- a/python/tvm/topi/arm_cpu/conv2d_gemm.py
+++ b/python/tvm/topi/arm_cpu/conv2d_gemm.py
@@ -26,8 +26,8 @@
 from .tensor_intrin import (
     gemm_quantized,
     gemm_quantized_impl,
-    mmla_4x4_int8_int8_int32,
-    mmla_16x4_int8_int8_int32,
+    gemm_acc_4x4_int8_int8_int32,
+    gemm_acc_nx16_int8_int8_int32,
 )
 from .arm_utils import is_aarch64_arm, is_dotprod_available

@@ -272,7 +272,7 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out):
     k = C_interleaved.op.reduce_axis[0]
     _, M, N = C.shape
     if is_dotprod_available():
-        mmla = mmla_4x4_int8_int8_int32(in_type)
+        gemm_acc = gemm_acc_4x4_int8_int8_int32(in_type)
         xi_outer, yi_outer, xi_inner, yi_inner = s[C_interleaved].tile(
             xi, yi, x_factor=8, y_factor=4
         )
@@ -289,7 +289,7 @@ def schedule_conv2d_gemm_interleaved(cfg, s, out, final_out):
             yi_inner,
             k_inner,
         )
-        s[C_interleaved].tensorize(xi_inner_inner, mmla)
+        s[C_interleaved].tensorize(xi_inner_inner, gemm_acc)
         s[C_interleaved].unroll(xi_inner_outer)

     elif is_aarch64_arm():
@@ -327,9 +327,9 @@ def schedule_conv2d_gemm_native(cfg, s, out, final_out):
     k_outer, k_inner = s[C].split(k, 16)
     x_outer, y_outer, x_inner, y_inner = s[C].tile(x, y, x_factor=4, y_factor=16)
     s[C].reorder(b, x_outer, y_outer, k_outer, x_inner, y_inner, k_inner)
-    mmla = mmla_16x4_int8_int8_int32(in_type, rows=1)
+    gemm_acc = gemm_acc_nx16_int8_int8_int32(in_type, rows=1)
     s[C].unroll(x_inner)
-    s[C].tensorize(y_inner, mmla)
+    s[C].tensorize(y_inner, gemm_acc)
     s[C].parallel(x_outer)

     # Input transform
diff --git a/python/tvm/topi/arm_cpu/tensor_intrin.py b/python/tvm/topi/arm_cpu/tensor_intrin.py
index 79b9bf8fd732..9ed4c591da0f 100644
--- a/python/tvm/topi/arm_cpu/tensor_intrin.py
+++ b/python/tvm/topi/arm_cpu/tensor_intrin.py
@@ -622,28 +622,28 @@ def select_word(vec, lane, dtype_vec):
     return vec_int8_broadcast


-def mmla_4x4_int8_int8_int32(dtype):
+def gemm_acc_4x4_int8_int8_int32(dtype):
     """
-    Int8 4x4 matrix multiplication using sdot/udot instructions
-    This function takes two arrays of int8 datatype -- A[4][4] and
-    B[4][4] and produces a 4x4 matrix which is equal to A*B
+    Int8 4x4 matrix multiplication and accumulation using sdot/udot
+    instructions. This function takes two arrays of int8 datatype
+    -- A[4][4] and B[4][4] and produces a 4x4 matrix
+    which is equal to A*B.
+
     The pseudo code is as follows.
     .. code-block:: c

-        void mmla_4x4_int8_int8_int32(int8 A[4][4], int8 B[4][4], int32 output[4][4]){
+        void gemm_acc_4x4_int8_int8_int32(int8 A[4][4], int8 B[4][4], int32 C[4][4]){
             for (int i = 0; i < 4; i++){
                 for (int j = 0; j < 4; j++){
-                    out[i][j] = 0;
                     for (int k = 0; k < 4; k++){
-                        out[i][j] += A[i][k] * B[j][k]
+                        C[i][j] += A[i][k] * B[j][k]
                     }
                 }
             }
         }

     Notes:
         * The rows of matrix B are transposed
-        * Matrix A is interleaved

     This function returns a TensorIntrin that can be used to tensorize a schedule.

     Parameters
@@ -656,22 +656,25 @@ def mmla_4x4_int8_int8_int32(dtype):
     intrin : TensorIntrin
         The Arm TensorIntrin that can be used in a tensorizing schedule
     """
-    data = te.placeholder((te.var("rows"), 4), dtype, name="data")
-    kernel = te.placeholder((4, 4), dtype, name="kernel")
+    # The number of rows has to be a variable: because of the padding,
+    # TVM might otherwise assume that only one row needs to be computed
+    A = te.placeholder((te.var("rows"), 4), dtype, name="data")
+    B = te.placeholder((4, 4), dtype, name="kernel")
     dtype_vec = dtype + "x16"

     k = te.reduce_axis((0, 4), name="k")
     C = te.compute(
         (te.var("rows"), 4),
-        lambda i, j: te.sum(data[i, k].astype("int32") * kernel[j, k].astype("int32"), axis=k),
+        lambda i, j: te.sum(A[i, k].astype("int32") * B[j, k].astype("int32"), axis=k),
         name="C",
     )
     aa_buffer = tvm.tir.decl_buffer(
-        data.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]
+        A.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]
     )
     bb_buffer = tvm.tir.decl_buffer(
-        kernel.shape, dtype, name="bb_buffer", offset_factor=1, strides=[te.var("sb"), 1]
+        B.shape, dtype, name="bb_buffer", offset_factor=1, strides=[te.var("sb"), 1]
     )
     cc_buffer = tvm.tir.decl_buffer(
         C.shape, dtype="int32", name="cc_buffer", offset_factor=1, strides=[te.var("sc"), 1]
     )
@@ -686,16 +689,49 @@ def _instr(index):
                 for i in range(0, 4):
                     ib.emit(outs[0].vstore([i, 0], tvm.tir.const(0, "int32x4")))
                 return ib.get()
-
+            # Load all the elements of tile A.
+            # vec_a = [a, b, c, d,
+            #          e, f, g, h,
+            #          i, l, m, n,
+            #          o, p, q, r];
             vec_a = ins[0].vload([0, 0], dtype_vec)
+
+            # Replicate the i-th row of A four times. For instance,
+            # vec_aa[0] = [a, b, c, d,
+            #              a, b, c, d,
+            #              a, b, c, d,
+            #              a, b, c, d];
             vec_aa = [select_word(vec_a, i, dtype_vec) for i in range(0, 4)]
+
+            # Load all the elements of B. Remember that B
+            # is transposed:
+            # vec_b = [0, 4, 8, 12,
+            #          1, 5, 9, 13,
+            #          2, 6, 10, 14,
+            #          3, 7, 11, 15];
             vec_b = ins[1].vload([0, 0], dtype_vec)

             # Execute the dot product
             for i in range(0, 4):
                 vec_c = outs[0].vload([i, 0], "int32x4")
+                # Compute the product between the i-th row of A
+                # and all the rows of B. Remember that sdot/udot
+                # subdivide the 16-element input vectors into four
+                # groups of four elements and take the dot product
+                # within each group.
+                # The result is stored in an int32x4 register.
+                #
+                # For instance, for i=0, we have:
+                # sdot(vec_aa[0], vec_b) = [a*0+b*4+c*8+d*12,
+                #                           a*1+b*5+c*9+d*13,
+                #                           a*2+b*6+c*10+d*14,
+                #                           a*3+b*7+c*11+d*15]
                 vdot = tvm.tir.call_llvm_intrin(
-                    "int32x4", llvm_intrin, tvm.tir.const(3, "uint32"), vec_c, vec_b, vec_aa[i],
+                    "int32x4",
+                    llvm_intrin,
+                    tvm.tir.const(3, "uint32"),
+                    vec_c,
+                    vec_b,
+                    vec_aa[i],
                 )

                 # Store the result
@@ -710,33 +746,36 @@ def _instr(index):
     return te.decl_tensor_intrin(
         C.op,
         _intrin_func,
-        binds={data: aa_buffer, kernel: bb_buffer, C: cc_buffer},
+        binds={A: aa_buffer, B: bb_buffer, C: cc_buffer},
         default_buffer_params=buffer_params,
     )


-def mmla_16x4_int8_int8_int32(dtype, rows):
+def gemm_acc_nx16_int8_int8_int32(dtype, rows):
     """
-    Int8 16x4 matrix multiplication using sdot/udot instructions
+    Int8 nx16 matrix multiplication and accumulation using sdot/udot instructions
     This function takes two arrays of int8 datatype -- A[rows][16] and
     B[4][16][4] and produces a rowsx16 matrix which is equal to A*B
     The pseudo code is as follows.
     .. code-block:: c

-        void mmla_16x4_int8_int8_int32(int8 A[rows][4], int8 B[4][16], int32 output[rows][16]){
+        void gemm_acc_nx16_int8_int8_int32(int8 A[rows][16], int8 B[4][16][4], int32 output[rows][16]){
             for (int i = 0; i < rows; i++){
                 for (int j = 0; j < 16; j++){
-                    out[i][j] = 0;
-                    for (int k = 0; k < 4; k++){
-                        out[i][j] += A[i][k] * B[j][k]
+                    for (int k = 0; k < 16; k++){
+                        out[i][j] += A[i][k] * B[k/4][j][k%4]
                     }
+                }
             }
         }

     Notes:
         * The rows of matrix B are transposed
-        * A is not interleaved, but used in its native form
+        * The tile size of B is 16x4. Since the reduction variable k ranges from 0 to 15,
+          we need 4 tiles of B to compute a single row of the output. The first 4 values
+          of k are fetched from B[0][j][k%4], the second batch of 4 from B[1][j][k%4],
+          and so on

     This function returns a TensorIntrin that can be used to tensorize a schedule.
     Parameters
@@ -749,24 +788,24 @@ def mmla_16x4_int8_int8_int32(dtype, rows):
     intrin : TensorIntrin
         The Arm TensorIntrin that can be used in a tensorizing schedule
     """
-    data = te.placeholder((rows, 16), dtype, name="data")
-    kernel = te.placeholder((4, 16, 4), dtype, name="kernel")
+    A = te.placeholder((rows, 16), dtype, name="data")
+    B = te.placeholder((4, 16, 4), dtype, name="kernel")
     dtype_vec = dtype + "x16"
     idxm = tvm.tir.indexmod
     k = te.reduce_axis((0, 16), name="k")
     C = te.compute(
         (rows, 16),
         lambda i, j: te.sum(
-            data[i, k].astype("int32") * kernel[k // 4, j, idxm(k, 4)].astype("int32"), axis=k
+            A[i, k].astype("int32") * B[k // 4, j, idxm(k, 4)].astype("int32"), axis=k
         ),
         name="C",
     )
     aa_buffer = tvm.tir.decl_buffer(
-        data.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]
+        A.shape, dtype, name="aa_buffer", offset_factor=1, strides=[te.var("sa"), 1]
     )
     bb_buffer = tvm.tir.decl_buffer(
-        kernel.shape,
+        B.shape,
         dtype,
         name="bb_buffer",
         offset_factor=1,
@@ -785,14 +824,26 @@ def _instr(index):
             for i in range(0, rows):
                 ib.emit(outs[0].vstore([i, 0], tvm.tir.const(0, "int32x16")))
             return ib.get()
-
+        # Iterate over the rows of the output
         for k in range(0, rows):
+            # Load the 16 elements of the k-th row of A
+            # vec_a = [a, b, c, d, e, f, g, h, i, l, m, n, o, p, q, r];
             vec_a = ins[0].vload([k, 0], dtype_vec)
+            # Iterate over the four 4-column blocks of the output
             for j in range(0, 4):
+                # Accumulate over each of the four 16x4 tiles of B
                 for i in range(0, 4):
+                    # As before, replicate a single 4-element group of A
                     vec_aa = select_word(vec_a, i, dtype_vec)
+                    # Load 4 rows (each with 4 elements) from B
+                    # vec_b = [0, 16, 32, 48,
+                    #          1, 17, 33, 49,
+                    #          2, 18, 34, 50,
+                    #          3, 19, 35, 51];
                     vec_b = ins[1].vload([i, 4 * j, 0], dtype_vec)
+                    # Load the accumulator from the correct
+                    # part of the output
                     vec_c = outs[0].vload([k, 4 * j], "int32x4")
                     vdot = tvm.tir.call_llvm_intrin(
                         "int32x4",
@@ -812,7 +863,7 @@ def _instr(index):
     return te.decl_tensor_intrin(
         C.op,
         _intrin_func,
-        binds={data: aa_buffer, kernel: bb_buffer, C: cc_buffer},
+        binds={A: aa_buffer, B: bb_buffer, C: cc_buffer},
         default_buffer_params=buffer_params,
     )
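
The two data layouts described above (B stored with its rows transposed for the 4x4
intrinsic, and B tiled as [4][16][4] with the reduction index split into k/4 and k%4
for the nx16 one) can be sanity-checked against a plain matrix multiplication. Below
is a minimal NumPy sketch of the semantics the docstrings describe; it is not part of
the patch, and the *_ref helper names are invented for illustration only.

    import numpy as np

    def gemm_acc_4x4_ref(A, B, C):
        # C[i][j] += dot(A[i][:], B[j][:]). B holds transposed rows,
        # so in the plain layout this is C += A @ B.T
        for i in range(4):
            for j in range(4):
                for k in range(4):
                    C[i, j] += np.int32(A[i, k]) * np.int32(B[j, k])
        return C

    def gemm_acc_nx16_ref(A, B, C):
        # A is (rows, 16); B is (4, 16, 4): tile k // 4 holds, for each
        # output column j, the reduction values with offsets k % 4 = 0..3
        rows = A.shape[0]
        for i in range(rows):
            for j in range(16):
                for k in range(16):
                    C[i, j] += np.int32(A[i, k]) * np.int32(B[k // 4, j, k % 4])
        return C

    rng = np.random.default_rng(0)

    A4 = rng.integers(-128, 128, size=(4, 4), dtype=np.int8)
    B4 = rng.integers(-128, 128, size=(4, 4), dtype=np.int8)
    C4 = gemm_acc_4x4_ref(A4, B4, np.zeros((4, 4), dtype=np.int32))
    assert np.array_equal(C4, A4.astype(np.int32) @ B4.astype(np.int32).T)

    An = rng.integers(-128, 128, size=(3, 16), dtype=np.int8)
    Bn = rng.integers(-128, 128, size=(4, 16, 4), dtype=np.int8)
    # Undo the tiling: B_plain[k, j] == Bn[k // 4, j, k % 4]
    B_plain = Bn.transpose(0, 2, 1).reshape(16, 16)
    Cn = gemm_acc_nx16_ref(An, Bn, np.zeros((3, 16), dtype=np.int32))
    assert np.array_equal(Cn, An.astype(np.int32) @ B_plain.astype(np.int32))

Both layouts exist so that each sdot/udot call sees its 4-element reduction groups as
contiguous lanes of the 16-lane input vectors.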