From a280bfb33c269a2ed54371f75fb762355adfdb62 Mon Sep 17 00:00:00 2001 From: nihuini Date: Tue, 8 Oct 2024 17:12:48 +0800 Subject: [PATCH] fast path --- src/layer/arm/gemm_int8.h | 124 ++++++++++++++++++++++---------- src/layer/arm/gemm_int8_bf16s.h | 115 ++++++++++++++++++++--------- src/layer/arm/gemm_int8_fp16s.h | 115 ++++++++++++++++++++--------- 3 files changed, 250 insertions(+), 104 deletions(-) diff --git a/src/layer/arm/gemm_int8.h b/src/layer/arm/gemm_int8.h index e54c5687463..49c946d8fc0 100644 --- a/src/layer/arm/gemm_int8.h +++ b/src/layer/arm/gemm_int8.h @@ -12352,10 +12352,21 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma _f3 = vmulq_f32(_f3, _alpha); } - vst1q_f32(p0, _f0); - vst1q_f32(p0 + out_hstep * 4, _f1); - vst1q_f32(p0 + out_hstep * 8, _f2); - vst1q_f32(p0 + out_hstep * 12, _f3); + if (out_hstep == 1) + { + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + vst1q_f32(p0 + 8, _f2); + vst1q_f32(p0 + 12, _f3); + } + else + { + vst1q_f32(p0, _f0); + vst1q_f32(p0 + out_hstep * 4, _f1); + vst1q_f32(p0 + out_hstep * 8, _f2); + vst1q_f32(p0 + out_hstep * 12, _f3); + } + pp += 16; p0 += out_hstep * 16; } @@ -12401,8 +12412,17 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma _f1 = vmulq_f32(_f1, _alpha); } - vst1q_f32(p0, _f0); - vst1q_f32(p0 + out_hstep * 4, _f1); + if (out_hstep == 1) + { + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + } + else + { + vst1q_f32(p0, _f0); + vst1q_f32(p0 + out_hstep * 4, _f1); + } + pp += 8; p0 += out_hstep * 8; } @@ -12493,22 +12513,32 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma _f3 = vmulq_f32(_f3, _alpha); } - p0[0] = vgetq_lane_f32(_f0, 0); - p0[out_hstep] = vgetq_lane_f32(_f0, 1); - p0[out_hstep * 2] = vgetq_lane_f32(_f0, 2); - p0[out_hstep * 3] = vgetq_lane_f32(_f0, 3); - p0[out_hstep * 4] = vgetq_lane_f32(_f1, 0); - p0[out_hstep * 5] = vgetq_lane_f32(_f1, 1); - p0[out_hstep * 6] = vgetq_lane_f32(_f1, 2); - p0[out_hstep * 7] = vgetq_lane_f32(_f1, 3); - p0[out_hstep * 8] = vgetq_lane_f32(_f2, 0); - p0[out_hstep * 9] = vgetq_lane_f32(_f2, 1); - p0[out_hstep * 10] = vgetq_lane_f32(_f2, 2); - p0[out_hstep * 11] = vgetq_lane_f32(_f2, 3); - p0[out_hstep * 12] = vgetq_lane_f32(_f3, 0); - p0[out_hstep * 13] = vgetq_lane_f32(_f3, 1); - p0[out_hstep * 14] = vgetq_lane_f32(_f3, 2); - p0[out_hstep * 15] = vgetq_lane_f32(_f3, 3); + if (out_hstep == 1) + { + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + vst1q_f32(p0 + 8, _f2); + vst1q_f32(p0 + 12, _f3); + } + else + { + p0[0] = vgetq_lane_f32(_f0, 0); + p0[out_hstep] = vgetq_lane_f32(_f0, 1); + p0[out_hstep * 2] = vgetq_lane_f32(_f0, 2); + p0[out_hstep * 3] = vgetq_lane_f32(_f0, 3); + p0[out_hstep * 4] = vgetq_lane_f32(_f1, 0); + p0[out_hstep * 5] = vgetq_lane_f32(_f1, 1); + p0[out_hstep * 6] = vgetq_lane_f32(_f1, 2); + p0[out_hstep * 7] = vgetq_lane_f32(_f1, 3); + p0[out_hstep * 8] = vgetq_lane_f32(_f2, 0); + p0[out_hstep * 9] = vgetq_lane_f32(_f2, 1); + p0[out_hstep * 10] = vgetq_lane_f32(_f2, 2); + p0[out_hstep * 11] = vgetq_lane_f32(_f2, 3); + p0[out_hstep * 12] = vgetq_lane_f32(_f3, 0); + p0[out_hstep * 13] = vgetq_lane_f32(_f3, 1); + p0[out_hstep * 14] = vgetq_lane_f32(_f3, 2); + p0[out_hstep * 15] = vgetq_lane_f32(_f3, 3); + } pp += 16; p0 += out_hstep * 16; @@ -12555,14 +12585,22 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma _f1 = vmulq_f32(_f1, _alpha); } - p0[0] = vgetq_lane_f32(_f0, 0); - p0[out_hstep] = vgetq_lane_f32(_f0, 1); - p0[out_hstep * 2] = vgetq_lane_f32(_f0, 2); - p0[out_hstep * 3] = vgetq_lane_f32(_f0, 3); - p0[out_hstep * 4] = vgetq_lane_f32(_f1, 0); - p0[out_hstep * 5] = vgetq_lane_f32(_f1, 1); - p0[out_hstep * 6] = vgetq_lane_f32(_f1, 2); - p0[out_hstep * 7] = vgetq_lane_f32(_f1, 3); + if (out_hstep == 1) + { + vst1q_f32(p0, _f0); + vst1q_f32(p0 + 4, _f1); + } + else + { + p0[0] = vgetq_lane_f32(_f0, 0); + p0[out_hstep] = vgetq_lane_f32(_f0, 1); + p0[out_hstep * 2] = vgetq_lane_f32(_f0, 2); + p0[out_hstep * 3] = vgetq_lane_f32(_f0, 3); + p0[out_hstep * 4] = vgetq_lane_f32(_f1, 0); + p0[out_hstep * 5] = vgetq_lane_f32(_f1, 1); + p0[out_hstep * 6] = vgetq_lane_f32(_f1, 2); + p0[out_hstep * 7] = vgetq_lane_f32(_f1, 3); + } pp += 8; p0 += out_hstep * 8; @@ -12588,10 +12626,17 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma _f0 = vmulq_n_f32(_f0, alpha); - p0[0] = vgetq_lane_f32(_f0, 0); - p0[out_hstep] = vgetq_lane_f32(_f0, 1); - p0[out_hstep * 2] = vgetq_lane_f32(_f0, 2); - p0[out_hstep * 3] = vgetq_lane_f32(_f0, 3); + if (out_hstep == 1) + { + vst1q_f32(p0, _f0); + } + else + { + p0[0] = vgetq_lane_f32(_f0, 0); + p0[out_hstep] = vgetq_lane_f32(_f0, 1); + p0[out_hstep * 2] = vgetq_lane_f32(_f0, 2); + p0[out_hstep * 3] = vgetq_lane_f32(_f0, 3); + } pp += 4; p0 += out_hstep * 4; @@ -12617,8 +12662,15 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma _f0 = vmul_n_f32(_f0, alpha); - p0[0] = vget_lane_f32(_f0, 0); - p0[out_hstep] = vget_lane_f32(_f0, 1); + if (out_hstep == 1) + { + vst1_f32(p0, _f0); + } + else + { + p0[0] = vget_lane_f32(_f0, 0); + p0[out_hstep] = vget_lane_f32(_f0, 1); + } pp += 2; p0 += out_hstep * 2; diff --git a/src/layer/arm/gemm_int8_bf16s.h b/src/layer/arm/gemm_int8_bf16s.h index 298e156b0db..36a4e423031 100644 --- a/src/layer/arm/gemm_int8_bf16s.h +++ b/src/layer/arm/gemm_int8_bf16s.h @@ -11035,10 +11035,24 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _f3 = vmulq_f32(_f3, _alpha); } - vst1_u16(p0, float2bfloat(_f0)); - vst1_u16(p0 + out_hstep * 4, float2bfloat(_f1)); - vst1_u16(p0 + out_hstep * 8, float2bfloat(_f2)); - vst1_u16(p0 + out_hstep * 12, float2bfloat(_f3)); + uint16x4_t _bf0 = float2bfloat(_f0); + uint16x4_t _bf1 = float2bfloat(_f1); + uint16x4_t _bf2 = float2bfloat(_f2); + uint16x4_t _bf3 = float2bfloat(_f3); + + if (out_hstep == 1) + { + vst1q_u16(p0, vcombine_u16(_bf0, _bf1)); + vst1q_u16(p0 + 8, vcombine_u16(_bf2, _bf3)); + } + else + { + vst1_u16(p0, _bf0); + vst1_u16(p0 + out_hstep * 4, _bf1); + vst1_u16(p0 + out_hstep * 8, _bf2); + vst1_u16(p0 + out_hstep * 12, _bf3); + } + pp += 16; p0 += out_hstep * 16; } @@ -11085,8 +11099,19 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma _f1 = vmulq_f32(_f1, _alpha); } - vst1_u16(p0, float2bfloat(_f0)); - vst1_u16(p0 + out_hstep * 4, float2bfloat(_f1)); + uint16x4_t _bf0 = float2bfloat(_f0); + uint16x4_t _bf1 = float2bfloat(_f1); + + if (out_hstep == 1) + { + vst1q_u16(p0, vcombine_u16(_bf0, _bf1)); + } + else + { + vst1_u16(p0, _bf0); + vst1_u16(p0 + out_hstep * 4, _bf1); + } + pp += 8; p0 += out_hstep * 8; } @@ -11183,22 +11208,30 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma uint16x4_t _bf2 = float2bfloat(_f2); uint16x4_t _bf3 = float2bfloat(_f3); - p0[0] = vget_lane_u16(_bf0, 0); - p0[out_hstep] = vget_lane_u16(_bf0, 1); - p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); - p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); - p0[out_hstep * 4] = vget_lane_u16(_bf1, 0); - p0[out_hstep * 5] = vget_lane_u16(_bf1, 1); - p0[out_hstep * 6] = vget_lane_u16(_bf1, 2); - p0[out_hstep * 7] = vget_lane_u16(_bf1, 3); - p0[out_hstep * 8] = vget_lane_u16(_bf2, 0); - p0[out_hstep * 9] = vget_lane_u16(_bf2, 1); - p0[out_hstep * 10] = vget_lane_u16(_bf2, 2); - p0[out_hstep * 11] = vget_lane_u16(_bf2, 3); - p0[out_hstep * 12] = vget_lane_u16(_bf3, 0); - p0[out_hstep * 13] = vget_lane_u16(_bf3, 1); - p0[out_hstep * 14] = vget_lane_u16(_bf3, 2); - p0[out_hstep * 15] = vget_lane_u16(_bf3, 3); + if (out_hstep == 1) + { + vst1q_u16(p0, vcombine_u16(_bf0, _bf1)); + vst1q_u16(p0 + 8, vcombine_u16(_bf2, _bf3)); + } + else + { + p0[0] = vget_lane_u16(_bf0, 0); + p0[out_hstep] = vget_lane_u16(_bf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); + p0[out_hstep * 4] = vget_lane_u16(_bf1, 0); + p0[out_hstep * 5] = vget_lane_u16(_bf1, 1); + p0[out_hstep * 6] = vget_lane_u16(_bf1, 2); + p0[out_hstep * 7] = vget_lane_u16(_bf1, 3); + p0[out_hstep * 8] = vget_lane_u16(_bf2, 0); + p0[out_hstep * 9] = vget_lane_u16(_bf2, 1); + p0[out_hstep * 10] = vget_lane_u16(_bf2, 2); + p0[out_hstep * 11] = vget_lane_u16(_bf2, 3); + p0[out_hstep * 12] = vget_lane_u16(_bf3, 0); + p0[out_hstep * 13] = vget_lane_u16(_bf3, 1); + p0[out_hstep * 14] = vget_lane_u16(_bf3, 2); + p0[out_hstep * 15] = vget_lane_u16(_bf3, 3); + } pp += 16; p0 += out_hstep * 16; @@ -11249,14 +11282,21 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma uint16x4_t _bf0 = float2bfloat(_f0); uint16x4_t _bf1 = float2bfloat(_f1); - p0[0] = vget_lane_u16(_bf0, 0); - p0[out_hstep] = vget_lane_u16(_bf0, 1); - p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); - p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); - p0[out_hstep * 4] = vget_lane_u16(_bf1, 0); - p0[out_hstep * 5] = vget_lane_u16(_bf1, 1); - p0[out_hstep * 6] = vget_lane_u16(_bf1, 2); - p0[out_hstep * 7] = vget_lane_u16(_bf1, 3); + if (out_hstep == 1) + { + vst1q_u16(p0, vcombine_u16(_bf0, _bf1)); + } + else + { + p0[0] = vget_lane_u16(_bf0, 0); + p0[out_hstep] = vget_lane_u16(_bf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); + p0[out_hstep * 4] = vget_lane_u16(_bf1, 0); + p0[out_hstep * 5] = vget_lane_u16(_bf1, 1); + p0[out_hstep * 6] = vget_lane_u16(_bf1, 2); + p0[out_hstep * 7] = vget_lane_u16(_bf1, 3); + } pp += 8; p0 += out_hstep * 8; @@ -11284,10 +11324,17 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma uint16x4_t _bf0 = float2bfloat(_f0); - p0[0] = vget_lane_u16(_bf0, 0); - p0[out_hstep] = vget_lane_u16(_bf0, 1); - p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); - p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); + if (out_hstep == 1) + { + vst1_u16(p0, _bf0); + } + else + { + p0[0] = vget_lane_u16(_bf0, 0); + p0[out_hstep] = vget_lane_u16(_bf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_bf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_bf0, 3); + } pp += 4; p0 += out_hstep * 4; diff --git a/src/layer/arm/gemm_int8_fp16s.h b/src/layer/arm/gemm_int8_fp16s.h index 0cfb7bb8490..629bec91ad5 100644 --- a/src/layer/arm/gemm_int8_fp16s.h +++ b/src/layer/arm/gemm_int8_fp16s.h @@ -13113,10 +13113,24 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma _f3 = vmulq_f32(_f3, _alpha); } - vst1_u16(p0, (uint16x4_t)vcvt_f16_f32(_f0)); - vst1_u16(p0 + out_hstep * 4, (uint16x4_t)vcvt_f16_f32(_f1)); - vst1_u16(p0 + out_hstep * 8, (uint16x4_t)vcvt_f16_f32(_f2)); - vst1_u16(p0 + out_hstep * 12, (uint16x4_t)vcvt_f16_f32(_f3)); + uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); + uint16x4_t _hf1 = (uint16x4_t)vcvt_f16_f32(_f1); + uint16x4_t _hf2 = (uint16x4_t)vcvt_f16_f32(_f2); + uint16x4_t _hf3 = (uint16x4_t)vcvt_f16_f32(_f3); + + if (out_hstep == 1) + { + vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); + vst1q_u16(p0 + 8, vcombine_u16(_hf2, _hf3)); + } + else + { + vst1_u16(p0, _hf0); + vst1_u16(p0 + out_hstep * 4, _hf1); + vst1_u16(p0 + out_hstep * 8, _hf2); + vst1_u16(p0 + out_hstep * 12, _hf3); + } + pp += 16; p0 += out_hstep * 16; } @@ -13163,8 +13177,19 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma _f1 = vmulq_f32(_f1, _alpha); } - vst1_u16(p0, (uint16x4_t)vcvt_f16_f32(_f0)); - vst1_u16(p0 + out_hstep * 4, (uint16x4_t)vcvt_f16_f32(_f1)); + uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); + uint16x4_t _hf1 = (uint16x4_t)vcvt_f16_f32(_f1); + + if (out_hstep == 1) + { + vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); + } + else + { + vst1_u16(p0, _hf0); + vst1_u16(p0 + out_hstep * 4, _hf1); + } + pp += 8; p0 += out_hstep * 8; } @@ -13261,22 +13286,30 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma uint16x4_t _hf2 = (uint16x4_t)vcvt_f16_f32(_f2); uint16x4_t _hf3 = (uint16x4_t)vcvt_f16_f32(_f3); - p0[0] = vget_lane_u16(_hf0, 0); - p0[out_hstep] = vget_lane_u16(_hf0, 1); - p0[out_hstep * 2] = vget_lane_u16(_hf0, 2); - p0[out_hstep * 3] = vget_lane_u16(_hf0, 3); - p0[out_hstep * 4] = vget_lane_u16(_hf1, 0); - p0[out_hstep * 5] = vget_lane_u16(_hf1, 1); - p0[out_hstep * 6] = vget_lane_u16(_hf1, 2); - p0[out_hstep * 7] = vget_lane_u16(_hf1, 3); - p0[out_hstep * 8] = vget_lane_u16(_hf2, 0); - p0[out_hstep * 9] = vget_lane_u16(_hf2, 1); - p0[out_hstep * 10] = vget_lane_u16(_hf2, 2); - p0[out_hstep * 11] = vget_lane_u16(_hf2, 3); - p0[out_hstep * 12] = vget_lane_u16(_hf3, 0); - p0[out_hstep * 13] = vget_lane_u16(_hf3, 1); - p0[out_hstep * 14] = vget_lane_u16(_hf3, 2); - p0[out_hstep * 15] = vget_lane_u16(_hf3, 3); + if (out_hstep == 1) + { + vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); + vst1q_u16(p0 + 8, vcombine_u16(_hf2, _hf3)); + } + else + { + p0[0] = vget_lane_u16(_hf0, 0); + p0[out_hstep] = vget_lane_u16(_hf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_hf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_hf0, 3); + p0[out_hstep * 4] = vget_lane_u16(_hf1, 0); + p0[out_hstep * 5] = vget_lane_u16(_hf1, 1); + p0[out_hstep * 6] = vget_lane_u16(_hf1, 2); + p0[out_hstep * 7] = vget_lane_u16(_hf1, 3); + p0[out_hstep * 8] = vget_lane_u16(_hf2, 0); + p0[out_hstep * 9] = vget_lane_u16(_hf2, 1); + p0[out_hstep * 10] = vget_lane_u16(_hf2, 2); + p0[out_hstep * 11] = vget_lane_u16(_hf2, 3); + p0[out_hstep * 12] = vget_lane_u16(_hf3, 0); + p0[out_hstep * 13] = vget_lane_u16(_hf3, 1); + p0[out_hstep * 14] = vget_lane_u16(_hf3, 2); + p0[out_hstep * 15] = vget_lane_u16(_hf3, 3); + } pp += 16; p0 += out_hstep * 16; @@ -13327,14 +13360,21 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); uint16x4_t _hf1 = (uint16x4_t)vcvt_f16_f32(_f1); - p0[0] = vget_lane_u16(_hf0, 0); - p0[out_hstep] = vget_lane_u16(_hf0, 1); - p0[out_hstep * 2] = vget_lane_u16(_hf0, 2); - p0[out_hstep * 3] = vget_lane_u16(_hf0, 3); - p0[out_hstep * 4] = vget_lane_u16(_hf1, 0); - p0[out_hstep * 5] = vget_lane_u16(_hf1, 1); - p0[out_hstep * 6] = vget_lane_u16(_hf1, 2); - p0[out_hstep * 7] = vget_lane_u16(_hf1, 3); + if (out_hstep == 1) + { + vst1q_u16(p0, vcombine_u16(_hf0, _hf1)); + } + else + { + p0[0] = vget_lane_u16(_hf0, 0); + p0[out_hstep] = vget_lane_u16(_hf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_hf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_hf0, 3); + p0[out_hstep * 4] = vget_lane_u16(_hf1, 0); + p0[out_hstep * 5] = vget_lane_u16(_hf1, 1); + p0[out_hstep * 6] = vget_lane_u16(_hf1, 2); + p0[out_hstep * 7] = vget_lane_u16(_hf1, 3); + } pp += 8; p0 += out_hstep * 8; @@ -13362,10 +13402,17 @@ static void transpose_unpack_output_tile_int32_to_fp16(const Mat& topT, const Ma uint16x4_t _hf0 = (uint16x4_t)vcvt_f16_f32(_f0); - p0[0] = vget_lane_u16(_hf0, 0); - p0[out_hstep] = vget_lane_u16(_hf0, 1); - p0[out_hstep * 2] = vget_lane_u16(_hf0, 2); - p0[out_hstep * 3] = vget_lane_u16(_hf0, 3); + if (out_hstep == 1) + { + vst1_u16(p0, _hf0); + } + else + { + p0[0] = vget_lane_u16(_hf0, 0); + p0[out_hstep] = vget_lane_u16(_hf0, 1); + p0[out_hstep * 2] = vget_lane_u16(_hf0, 2); + p0[out_hstep * 3] = vget_lane_u16(_hf0, 3); + } pp += 4; p0 += out_hstep * 4;