Skip to content

Commit

Permalink
fast path
Browse files (browse the repository at this point in the history)
  • Loading branch information
nihui committed Oct 8, 2024
1 parent 03168d7 commit a280bfb
Show file tree
Hide file tree
Showing 3 changed files with 250 additions and 104 deletions.
124 changes: 88 additions & 36 deletions src/layer/arm/gemm_int8.h
Original file line number | Diff line number | Diff line change
Expand Up @@ -12352,10 +12352,21 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
_f3 = vmulq_f32(_f3, _alpha);
}

vst1q_f32(p0, _f0);
vst1q_f32(p0 + out_hstep * 4, _f1);
vst1q_f32(p0 + out_hstep * 8, _f2);
vst1q_f32(p0 + out_hstep * 12, _f3);
if (out_hstep == 1)
{
vst1q_f32(p0, _f0);
vst1q_f32(p0 + 4, _f1);
vst1q_f32(p0 + 8, _f2);
vst1q_f32(p0 + 12, _f3);
}
else
{
vst1q_f32(p0, _f0);
vst1q_f32(p0 + out_hstep * 4, _f1);
vst1q_f32(p0 + out_hstep * 8, _f2);
vst1q_f32(p0 + out_hstep * 12, _f3);
}

pp += 16;
p0 += out_hstep * 16;
}
Expand Down Expand Up @@ -12401,8 +12412,17 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
_f1 = vmulq_f32(_f1, _alpha);
}

vst1q_f32(p0, _f0);
vst1q_f32(p0 + out_hstep * 4, _f1);
if (out_hstep == 1)
{
vst1q_f32(p0, _f0);
vst1q_f32(p0 + 4, _f1);
}
else
{
vst1q_f32(p0, _f0);
vst1q_f32(p0 + out_hstep * 4, _f1);
}

pp += 8;
p0 += out_hstep * 8;
}
Expand Down Expand Up @@ -12493,22 +12513,32 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
_f3 = vmulq_f32(_f3, _alpha);
}

p0[0] = vgetq_lane_f32(_f0, 0);
p0[out_hstep] = vgetq_lane_f32(_f0, 1);
p0[out_hstep * 2] = vgetq_lane_f32(_f0, 2);
p0[out_hstep * 3] = vgetq_lane_f32(_f0, 3);
p0[out_hstep * 4] = vgetq_lane_f32(_f1, 0);
p0[out_hstep * 5] = vgetq_lane_f32(_f1, 1);
p0[out_hstep * 6] = vgetq_lane_f32(_f1, 2);
p0[out_hstep * 7] = vgetq_lane_f32(_f1, 3);
p0[out_hstep * 8] = vgetq_lane_f32(_f2, 0);
p0[out_hstep * 9] = vgetq_lane_f32(_f2, 1);
p0[out_hstep * 10] = vgetq_lane_f32(_f2, 2);
p0[out_hstep * 11] = vgetq_lane_f32(_f2, 3);
p0[out_hstep * 12] = vgetq_lane_f32(_f3, 0);
p0[out_hstep * 13] = vgetq_lane_f32(_f3, 1);
p0[out_hstep * 14] = vgetq_lane_f32(_f3, 2);
p0[out_hstep * 15] = vgetq_lane_f32(_f3, 3);
if (out_hstep == 1)
{
vst1q_f32(p0, _f0);
vst1q_f32(p0 + 4, _f1);
vst1q_f32(p0 + 8, _f2);
vst1q_f32(p0 + 12, _f3);
}
else
{
p0[0] = vgetq_lane_f32(_f0, 0);
p0[out_hstep] = vgetq_lane_f32(_f0, 1);
p0[out_hstep * 2] = vgetq_lane_f32(_f0, 2);
p0[out_hstep * 3] = vgetq_lane_f32(_f0, 3);
p0[out_hstep * 4] = vgetq_lane_f32(_f1, 0);
p0[out_hstep * 5] = vgetq_lane_f32(_f1, 1);
p0[out_hstep * 6] = vgetq_lane_f32(_f1, 2);
p0[out_hstep * 7] = vgetq_lane_f32(_f1, 3);
p0[out_hstep * 8] = vgetq_lane_f32(_f2, 0);
p0[out_hstep * 9] = vgetq_lane_f32(_f2, 1);
p0[out_hstep * 10] = vgetq_lane_f32(_f2, 2);
p0[out_hstep * 11] = vgetq_lane_f32(_f2, 3);
p0[out_hstep * 12] = vgetq_lane_f32(_f3, 0);
p0[out_hstep * 13] = vgetq_lane_f32(_f3, 1);
p0[out_hstep * 14] = vgetq_lane_f32(_f3, 2);
p0[out_hstep * 15] = vgetq_lane_f32(_f3, 3);
}

pp += 16;
p0 += out_hstep * 16;
Expand Down Expand Up @@ -12555,14 +12585,22 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma
_f1 = vmulq_f32(_f1, _alpha);
}

p0[0] = vgetq_lane_f32(_f0, 0);
p0[out_hstep] = vgetq_lane_f32(_f0, 1);
p0[out_hstep * 2] = vgetq_lane_f32(_f0, 2);
p0[out_hstep * 3] = vgetq_lane_f32(_f0, 3);
p0[out_hstep * 4] = vgetq_lane_f32(_f1, 0);
p0[out_hstep * 5] = vgetq_lane_f32(_f1, 1);
p0[out_hstep * 6] = vgetq_lane_f32(_f1, 2);
p0[out_hstep * 7] = vgetq_lane_f32(_f1, 3);
if (out_hstep == 1)
{
vst1q_f32(p0, _f0);
vst1q_f32(p0 + 4, _f1);
}
else
{
p0[0] = vgetq_lane_f32(_f0, 0);
p0[out_hstep] = vgetq_lane_f32(_f0, 1);
p0[out_hstep * 2] = vgetq_lane_f32(_f0, 2);
p0[out_hstep * 3] = vgetq_lane_f32(_f0, 3);
p0[out_hstep * 4] = vgetq_lane_f32(_f1, 0);
p0[out_hstep * 5] = vgetq_lane_f32(_f1, 1);
p0[out_hstep * 6] = vgetq_lane_f32(_f1, 2);
p0[out_hstep * 7] = vgetq_lane_f32(_f1, 3);
}

pp += 8;
p0 += out_hstep * 8;
Expand All @@ -12588,10 +12626,17 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma

_f0 = vmulq_n_f32(_f0, alpha);

p0[0] = vgetq_lane_f32(_f0, 0);
p0[out_hstep] = vgetq_lane_f32(_f0, 1);
p0[out_hstep * 2] = vgetq_lane_f32(_f0, 2);
p0[out_hstep * 3] = vgetq_lane_f32(_f0, 3);
if (out_hstep == 1)
{
vst1q_f32(p0, _f0);
}
else
{
p0[0] = vgetq_lane_f32(_f0, 0);
p0[out_hstep] = vgetq_lane_f32(_f0, 1);
p0[out_hstep * 2] = vgetq_lane_f32(_f0, 2);
p0[out_hstep * 3] = vgetq_lane_f32(_f0, 3);
}

pp += 4;
p0 += out_hstep * 4;
Expand All @@ -12617,8 +12662,15 @@ static void transpose_unpack_output_tile_int32_to_fp32(const Mat& topT, const Ma

_f0 = vmul_n_f32(_f0, alpha);

p0[0] = vget_lane_f32(_f0, 0);
p0[out_hstep] = vget_lane_f32(_f0, 1);
if (out_hstep == 1)
{
vst1_f32(p0, _f0);
}
else
{
p0[0] = vget_lane_f32(_f0, 0);
p0[out_hstep] = vget_lane_f32(_f0, 1);
}

pp += 2;
p0 += out_hstep * 2;
Expand Down
115 changes: 81 additions & 34 deletions src/layer/arm/gemm_int8_bf16s.h
Original file line number | Diff line number | Diff line change
Expand Up @@ -11035,10 +11035,24 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
_f3 = vmulq_f32(_f3, _alpha);
}

vst1_u16(p0, float2bfloat(_f0));
vst1_u16(p0 + out_hstep * 4, float2bfloat(_f1));
vst1_u16(p0 + out_hstep * 8, float2bfloat(_f2));
vst1_u16(p0 + out_hstep * 12, float2bfloat(_f3));
uint16x4_t _bf0 = float2bfloat(_f0);
uint16x4_t _bf1 = float2bfloat(_f1);
uint16x4_t _bf2 = float2bfloat(_f2);
uint16x4_t _bf3 = float2bfloat(_f3);

if (out_hstep == 1)
{
vst1q_u16(p0, vcombine_u16(_bf0, _bf1));
vst1q_u16(p0 + 8, vcombine_u16(_bf2, _bf3));
}
else
{
vst1_u16(p0, _bf0);
vst1_u16(p0 + out_hstep * 4, _bf1);
vst1_u16(p0 + out_hstep * 8, _bf2);
vst1_u16(p0 + out_hstep * 12, _bf3);
}

pp += 16;
p0 += out_hstep * 16;
}
Expand Down Expand Up @@ -11085,8 +11099,19 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
_f1 = vmulq_f32(_f1, _alpha);
}

vst1_u16(p0, float2bfloat(_f0));
vst1_u16(p0 + out_hstep * 4, float2bfloat(_f1));
uint16x4_t _bf0 = float2bfloat(_f0);
uint16x4_t _bf1 = float2bfloat(_f1);

if (out_hstep == 1)
{
vst1q_u16(p0, vcombine_u16(_bf0, _bf1));
}
else
{
vst1_u16(p0, _bf0);
vst1_u16(p0 + out_hstep * 4, _bf1);
}

pp += 8;
p0 += out_hstep * 8;
}
Expand Down Expand Up @@ -11183,22 +11208,30 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
uint16x4_t _bf2 = float2bfloat(_f2);
uint16x4_t _bf3 = float2bfloat(_f3);

p0[0] = vget_lane_u16(_bf0, 0);
p0[out_hstep] = vget_lane_u16(_bf0, 1);
p0[out_hstep * 2] = vget_lane_u16(_bf0, 2);
p0[out_hstep * 3] = vget_lane_u16(_bf0, 3);
p0[out_hstep * 4] = vget_lane_u16(_bf1, 0);
p0[out_hstep * 5] = vget_lane_u16(_bf1, 1);
p0[out_hstep * 6] = vget_lane_u16(_bf1, 2);
p0[out_hstep * 7] = vget_lane_u16(_bf1, 3);
p0[out_hstep * 8] = vget_lane_u16(_bf2, 0);
p0[out_hstep * 9] = vget_lane_u16(_bf2, 1);
p0[out_hstep * 10] = vget_lane_u16(_bf2, 2);
p0[out_hstep * 11] = vget_lane_u16(_bf2, 3);
p0[out_hstep * 12] = vget_lane_u16(_bf3, 0);
p0[out_hstep * 13] = vget_lane_u16(_bf3, 1);
p0[out_hstep * 14] = vget_lane_u16(_bf3, 2);
p0[out_hstep * 15] = vget_lane_u16(_bf3, 3);
if (out_hstep == 1)
{
vst1q_u16(p0, vcombine_u16(_bf0, _bf1));
vst1q_u16(p0 + 8, vcombine_u16(_bf2, _bf3));
}
else
{
p0[0] = vget_lane_u16(_bf0, 0);
p0[out_hstep] = vget_lane_u16(_bf0, 1);
p0[out_hstep * 2] = vget_lane_u16(_bf0, 2);
p0[out_hstep * 3] = vget_lane_u16(_bf0, 3);
p0[out_hstep * 4] = vget_lane_u16(_bf1, 0);
p0[out_hstep * 5] = vget_lane_u16(_bf1, 1);
p0[out_hstep * 6] = vget_lane_u16(_bf1, 2);
p0[out_hstep * 7] = vget_lane_u16(_bf1, 3);
p0[out_hstep * 8] = vget_lane_u16(_bf2, 0);
p0[out_hstep * 9] = vget_lane_u16(_bf2, 1);
p0[out_hstep * 10] = vget_lane_u16(_bf2, 2);
p0[out_hstep * 11] = vget_lane_u16(_bf2, 3);
p0[out_hstep * 12] = vget_lane_u16(_bf3, 0);
p0[out_hstep * 13] = vget_lane_u16(_bf3, 1);
p0[out_hstep * 14] = vget_lane_u16(_bf3, 2);
p0[out_hstep * 15] = vget_lane_u16(_bf3, 3);
}

pp += 16;
p0 += out_hstep * 16;
Expand Down Expand Up @@ -11249,14 +11282,21 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma
uint16x4_t _bf0 = float2bfloat(_f0);
uint16x4_t _bf1 = float2bfloat(_f1);

p0[0] = vget_lane_u16(_bf0, 0);
p0[out_hstep] = vget_lane_u16(_bf0, 1);
p0[out_hstep * 2] = vget_lane_u16(_bf0, 2);
p0[out_hstep * 3] = vget_lane_u16(_bf0, 3);
p0[out_hstep * 4] = vget_lane_u16(_bf1, 0);
p0[out_hstep * 5] = vget_lane_u16(_bf1, 1);
p0[out_hstep * 6] = vget_lane_u16(_bf1, 2);
p0[out_hstep * 7] = vget_lane_u16(_bf1, 3);
if (out_hstep == 1)
{
vst1q_u16(p0, vcombine_u16(_bf0, _bf1));
}
else
{
p0[0] = vget_lane_u16(_bf0, 0);
p0[out_hstep] = vget_lane_u16(_bf0, 1);
p0[out_hstep * 2] = vget_lane_u16(_bf0, 2);
p0[out_hstep * 3] = vget_lane_u16(_bf0, 3);
p0[out_hstep * 4] = vget_lane_u16(_bf1, 0);
p0[out_hstep * 5] = vget_lane_u16(_bf1, 1);
p0[out_hstep * 6] = vget_lane_u16(_bf1, 2);
p0[out_hstep * 7] = vget_lane_u16(_bf1, 3);
}

pp += 8;
p0 += out_hstep * 8;
Expand Down Expand Up @@ -11284,10 +11324,17 @@ static void transpose_unpack_output_tile_int32_to_bf16(const Mat& topT, const Ma

uint16x4_t _bf0 = float2bfloat(_f0);

p0[0] = vget_lane_u16(_bf0, 0);
p0[out_hstep] = vget_lane_u16(_bf0, 1);
p0[out_hstep * 2] = vget_lane_u16(_bf0, 2);
p0[out_hstep * 3] = vget_lane_u16(_bf0, 3);
if (out_hstep == 1)
{
vst1_u16(p0, _bf0);
}
else
{
p0[0] = vget_lane_u16(_bf0, 0);
p0[out_hstep] = vget_lane_u16(_bf0, 1);
p0[out_hstep * 2] = vget_lane_u16(_bf0, 2);
p0[out_hstep * 3] = vget_lane_u16(_bf0, 3);
}

pp += 4;
p0 += out_hstep * 4;
Expand Down
Loading

0 comments on commit a280bfb

Please sign in to comment.