Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ynnpack/base/simd/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ cc_library(
"x86_avx512bw.h",
"x86_avx512f.h",
"x86_sse2.h",
"x86_sse2_only.h",
"x86_sse41.h",
],
"//ynnpack:arm": ["arm_neon.h"],
Expand Down
15 changes: 12 additions & 3 deletions ynnpack/base/simd/arm_neon.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#ifndef XNNPACK_YNNPACK_BASE_SIMD_ARM_H_
#define XNNPACK_YNNPACK_BASE_SIMD_ARM_H_
#ifndef XNNPACK_YNNPACK_BASE_SIMD_ARM_NEON_H_
#define XNNPACK_YNNPACK_BASE_SIMD_ARM_NEON_H_

#include <arm_neon.h>

Expand All @@ -19,6 +19,7 @@
#include "ynnpack/base/bfloat16.h"
#include "ynnpack/base/half.h"
#include "ynnpack/base/simd/vec.h"
#include "ynnpack/base/simd/multi_vec.h"

namespace ynn {

Expand Down Expand Up @@ -467,6 +468,14 @@ YNN_ALWAYS_INLINE s8x16 max(s8x16 a, s8x16 b) {
return s8x16{vmaxq_s8(a.v, b.v)};
}

using f32x4x2 = multi_vec<f32x4, 2>;

// Widens eight bf16 lanes to f32. Each 16-bit lane is shifted left by 16 into
// a 32-bit lane and the resulting bits are reinterpreted as float. The low
// four lanes land in `v[0]`, the high four in `v[1]`. The unnamed `float`
// argument only selects this overload.
YNN_ALWAYS_INLINE f32x4x2 convert(bf16x8 a, float) {
  const uint32x4_t lo_bits = vshll_n_u16(vget_low_u16(a.v), 16);
  const uint32x4_t hi_bits = vshll_n_u16(vget_high_u16(a.v), 16);

  f32x4x2 widened;
  widened.v[0].v = vreinterpretq_f32_u32(lo_bits);
  widened.v[1].v = vreinterpretq_f32_u32(hi_bits);
  return widened;
}

#ifdef YNN_ARCH_ARM32
YNN_ALWAYS_INLINE float vmaxvq_f32(float32x4_t a) {
float32x2_t max_halves = vmax_f32(vget_low_f32(a), vget_high_f32(a));
Expand Down Expand Up @@ -562,4 +571,4 @@ YNN_ALWAYS_INLINE std::array<vec<T, 4>, 4> transpose(

} // namespace ynn

#endif // XNNPACK_YNNPACK_BASE_SIMD_ARM_H_
#endif // XNNPACK_YNNPACK_BASE_SIMD_ARM_NEON_H_
3 changes: 3 additions & 0 deletions ynnpack/base/simd/vec.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ std::array<vec<T, 4>, 4> transpose(std::array<vec<T, 4>, 4> x);
// Generic declaration of `extract<Index>(x, SliceT{})`. `Index` is a
// compile-time selector and the second (unnamed) argument is a tag whose type
// is also the return type. Only declared here; definitions are supplied as
// overloads for concrete vector types elsewhere.
template <int Index, typename T, typename SliceT>
SliceT extract(T, SliceT);

// Catch-all for `convert(value, tag)`. Any (source, tag) combination that has
// no dedicated overload resolves to this deleted template, so the call is
// rejected at compile time.
template <typename ToType, typename FromType>
ToType convert(FromType, ToType) = delete;

namespace internal {

template <typename T, size_t N>
Expand Down
33 changes: 33 additions & 0 deletions ynnpack/base/simd/x86_avx.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include "ynnpack/base/base.h"
#include "ynnpack/base/bfloat16.h"
#include "ynnpack/base/half.h"
#include "ynnpack/base/simd/multi_vec.h"
#include "ynnpack/base/simd/vec.h"
#include "ynnpack/base/simd/x86_sse41.h" // IWYU pragma: export

Expand Down Expand Up @@ -386,6 +387,11 @@ YNN_ALWAYS_INLINE f32x4 extract(f32x8 x, f32x4) {
return f32x4{_mm256_extractf128_ps(x.v, Index)};
}
// Returns 128-bit half `Index` of a bf16x16 as a bf16x8. The integer register
// is viewed as packed floats only so that the AVX `_mm256_extractf128_ps` can
// be used; the casts do not change any bits.
template <int Index>
YNN_ALWAYS_INLINE bf16x8 extract(bf16x16 x, bf16x8) {
  const __m256 as_ps = _mm256_castsi256_ps(x.v);
  const __m128 half = _mm256_extractf128_ps(as_ps, Index);
  return bf16x8{_mm_castps_si128(half)};
}
// Returns 128-bit half `Index` of an f16x16 as an f16x8. The integer register
// is viewed as packed floats only so that the AVX `_mm256_extractf128_ps` can
// be used; the casts do not change any bits.
template <int Index>
YNN_ALWAYS_INLINE f16x8 extract(f16x16 x, f16x8) {
  const __m256 as_ps = _mm256_castsi256_ps(x.v);
  const __m128 half = _mm256_extractf128_ps(as_ps, Index);
  return f16x8{_mm_castps_si128(half)};
}
Expand All @@ -401,6 +407,33 @@ YNN_ALWAYS_INLINE u8x16 extract(u8x32 x, u8x16) {
_mm_castps_si128(_mm256_extractf128_ps(_mm256_castsi256_ps(x.v), Index))};
}

// Widens eight bf16 lanes to a single f32x8, preserving lane order.
//
// Uses only SSE2 and AVX instructions: `_mm256_cvtepu16_epi32` and
// `_mm256_slli_epi32` are AVX2 and must not be required by this header.
// Interleaving a zero word *below* each bf16 word yields `bf16 << 16` in every
// 32-bit lane, which is exactly the f32 bit pattern of that bf16 value.
YNN_ALWAYS_INLINE f32x8 convert(bf16x8 a, float) {
  const __m128i zero = _mm_setzero_si128();
  const __m128 lo = _mm_castsi128_ps(_mm_unpacklo_epi16(zero, a.v));  // 0..3
  const __m128 hi = _mm_castsi128_ps(_mm_unpackhi_epi16(zero, a.v));  // 4..7
  return f32x8{_mm256_insertf128_ps(_mm256_castps128_ps256(lo), hi, 1)};
}

using s32x8x4 = multi_vec<s32x8, 4>;

// Sign-extends 32 int8 lanes to int32. `v[k]` holds source lanes
// `8 * k .. 8 * k + 7`. The unnamed `int32_t` argument only selects this
// overload.
//
// `_mm256_cvtepi8_epi32` is an AVX2 instruction, so it is used only when the
// translation unit is built with AVX2. Otherwise each group of eight bytes is
// widened four at a time with SSE4.1 and the halves are joined with the AVX
// `vinsertf128`.
YNN_ALWAYS_INLINE s32x8x4 convert(s8x32 a, int32_t) {
  const s8x16 lo = extract<0>(a, s8x16{});
  const s8x16 hi = extract<1>(a, s8x16{});
  s32x8x4 result;
#ifdef __AVX2__
  result.v[0].v = _mm256_cvtepi8_epi32(lo.v);
  result.v[1].v = _mm256_cvtepi8_epi32(_mm_srli_si128(lo.v, 8));
  result.v[2].v = _mm256_cvtepi8_epi32(hi.v);
  result.v[3].v = _mm256_cvtepi8_epi32(_mm_srli_si128(hi.v, 8));
#else
  // Sign-extends the low eight bytes of `bytes` to eight int32 lanes.
  const auto widen8 = [](__m128i bytes) {
    const __m128i first4 = _mm_cvtepi8_epi32(bytes);
    const __m128i next4 = _mm_cvtepi8_epi32(_mm_srli_si128(bytes, 4));
    return _mm256_insertf128_si256(_mm256_castsi128_si256(first4), next4, 1);
  };
  result.v[0].v = widen8(lo.v);
  result.v[1].v = widen8(_mm_srli_si128(lo.v, 8));
  result.v[2].v = widen8(hi.v);
  result.v[3].v = widen8(_mm_srli_si128(hi.v, 8));
#endif
  return result;
}
// Zero-extends 32 uint8 lanes to int32. `v[k]` holds source lanes
// `8 * k .. 8 * k + 7`. The unnamed `int32_t` argument only selects this
// overload.
//
// `_mm256_cvtepu8_epi32` is an AVX2 instruction, so it is used only when the
// translation unit is built with AVX2. Otherwise each group of eight bytes is
// widened four at a time with SSE4.1 and the halves are joined with the AVX
// `vinsertf128`.
YNN_ALWAYS_INLINE s32x8x4 convert(u8x32 a, int32_t) {
  const u8x16 lo = extract<0>(a, u8x16{});
  const u8x16 hi = extract<1>(a, u8x16{});
  s32x8x4 result;
#ifdef __AVX2__
  result.v[0].v = _mm256_cvtepu8_epi32(lo.v);
  result.v[1].v = _mm256_cvtepu8_epi32(_mm_srli_si128(lo.v, 8));
  result.v[2].v = _mm256_cvtepu8_epi32(hi.v);
  result.v[3].v = _mm256_cvtepu8_epi32(_mm_srli_si128(hi.v, 8));
#else
  // Zero-extends the low eight bytes of `bytes` to eight int32 lanes.
  const auto widen8 = [](__m128i bytes) {
    const __m128i first4 = _mm_cvtepu8_epi32(bytes);
    const __m128i next4 = _mm_cvtepu8_epi32(_mm_srli_si128(bytes, 4));
    return _mm256_insertf128_si256(_mm256_castsi128_si256(first4), next4, 1);
  };
  result.v[0].v = widen8(lo.v);
  result.v[1].v = widen8(_mm_srli_si128(lo.v, 8));
  result.v[2].v = widen8(hi.v);
  result.v[3].v = widen8(_mm_srli_si128(hi.v, 8));
#endif
  return result;
}

} // namespace simd

} // namespace ynn
Expand Down
4 changes: 0 additions & 4 deletions ynnpack/base/simd/x86_avx512bw.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,20 +8,16 @@

#include <immintrin.h>

#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <type_traits>

#include "ynnpack/base/base.h"
#include "ynnpack/base/bfloat16.h"
#include "ynnpack/base/half.h"
#include "ynnpack/base/simd/vec.h"
#include "ynnpack/base/simd/x86_avx.h"
#include "ynnpack/base/simd/x86_avx512f.h" // IWYU pragma: export
#include "ynnpack/base/simd/x86_sse2.h"

namespace ynn {

Expand Down
12 changes: 10 additions & 2 deletions ynnpack/base/simd/x86_avx512f.h
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,14 @@ YNN_ALWAYS_INLINE s32x16 max(s32x16 a, s32x16 b) {
return s32x16{_mm512_max_epi32(a.v, b.v)};
}

// Widens sixteen f16 lanes to f32 with `_mm512_cvtph_ps`. The unnamed `float`
// argument only selects this overload.
YNN_ALWAYS_INLINE f32x16 convert(f16x16 x, float) {
  const __m512 widened = _mm512_cvtph_ps(x.v);
  return f32x16{widened};
}
// Widens sixteen bf16 lanes to f32: zero-extend each 16-bit lane to 32 bits,
// shift it into the upper half, and reinterpret the bits as float. The unnamed
// `float` argument only selects this overload.
YNN_ALWAYS_INLINE f32x16 convert(bf16x16 a, float) {
  const __m512i zero_extended = _mm512_cvtepu16_epi32(a.v);
  const __m512i f32_bits = _mm512_slli_epi32(zero_extended, 16);
  return f32x16{_mm512_castsi512_ps(f32_bits)};
}

YNN_ALWAYS_INLINE float horizontal_max(f32x16 a) {
const __m512 swapped = _mm512_shuffle_f32x4(a.v, a.v, 0x4E);
const __m512 max512 = _mm512_max_ps(a.v, swapped);
Expand Down Expand Up @@ -496,8 +504,8 @@ YNN_ALWAYS_INLINE u8x16 extract(u8x64 x, u8x16) {

// Returns 256-bit half `Index` of a bf16x32 as a bf16x16.
//
// The register is viewed as packed doubles so that `_mm512_extractf64x4_pd`,
// which needs only AVX512F, can be used; the 32x8 float extract requires
// AVX512DQ. The casts do not change any bits.
template <int Index>
YNN_ALWAYS_INLINE bf16x16 extract(bf16x32 x, bf16x16) {
  return bf16x16{_mm256_castpd_si256(
      _mm512_extractf64x4_pd(_mm512_castsi512_pd(x.v), Index))};
}
template <int Index>
YNN_ALWAYS_INLINE f16x16 extract(f16x32 x, f16x16) {
Expand Down
6 changes: 3 additions & 3 deletions ynnpack/base/simd/x86_sse2.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#ifndef XNNPACK_YNNPACK_BASE_SIMD_X86_SSE_H_
#define XNNPACK_YNNPACK_BASE_SIMD_X86_SSE_H_
#ifndef XNNPACK_YNNPACK_BASE_SIMD_X86_SSE2_H_
#define XNNPACK_YNNPACK_BASE_SIMD_X86_SSE2_H_

#include <immintrin.h>

Expand Down Expand Up @@ -525,4 +525,4 @@ YNN_ALWAYS_INLINE std::array<vec<T, 4>, 4> transpose(

} // namespace ynn

#endif // XNNPACK_YNNPACK_BASE_SIMD_X86_SSE_H_
#endif // XNNPACK_YNNPACK_BASE_SIMD_X86_SSE2_H_
63 changes: 63 additions & 0 deletions ynnpack/base/simd/x86_sse2_only.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
// Copyright 2025 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#ifndef XNNPACK_YNNPACK_BASE_SIMD_X86_SSE2_ONLY_H_
#define XNNPACK_YNNPACK_BASE_SIMD_X86_SSE2_ONLY_H_

#include <immintrin.h>

#include <array>
#include <cassert>
#include <cstdint>

#include "ynnpack/base/base.h"
#include "ynnpack/base/simd/multi_vec.h"
#include "ynnpack/base/simd/x86_sse2.h"

namespace ynn {

namespace simd {

using f32x4x2 = multi_vec<f32x4, 2>;

// Widens eight bf16 lanes to f32 using SSE2 only. Each 16-bit lane is
// zero-extended to 32 bits by interleaving with zero, then shifted into the
// upper half of the lane and reinterpreted as float. The low four lanes land
// in `v[0]`, the high four in `v[1]`.
YNN_ALWAYS_INLINE f32x4x2 convert(bf16x8 a, float) {
  const __m128i zeros = _mm_setzero_si128();
  const __m128i lo_u32 = _mm_unpacklo_epi16(a.v, zeros);
  const __m128i hi_u32 = _mm_unpackhi_epi16(a.v, zeros);

  f32x4x2 widened;
  widened.v[0].v = _mm_castsi128_ps(_mm_slli_epi32(lo_u32, 16));
  widened.v[1].v = _mm_castsi128_ps(_mm_slli_epi32(hi_u32, 16));
  return widened;
}

using s32x4x4 = multi_vec<s32x4, 4>;

// Sign-extends sixteen int8 lanes to int32 using SSE2 only. `v[k]` holds
// source lanes `4 * k .. 4 * k + 3`.
YNN_ALWAYS_INLINE s32x4x4 convert(s8x16 a, int32_t) {
  // Duplicate every byte into both halves of a 16-bit lane...
  const __m128i pairs_lo = _mm_unpacklo_epi8(a.v, a.v);
  const __m128i pairs_hi = _mm_unpackhi_epi8(a.v, a.v);
  // ...and then into all four bytes of a 32-bit lane, so the source byte ends
  // up in the top byte of its lane.
  const __m128i quads_0 = _mm_unpacklo_epi16(pairs_lo, pairs_lo);
  const __m128i quads_1 = _mm_unpackhi_epi16(pairs_lo, pairs_lo);
  const __m128i quads_2 = _mm_unpacklo_epi16(pairs_hi, pairs_hi);
  const __m128i quads_3 = _mm_unpackhi_epi16(pairs_hi, pairs_hi);

  // An arithmetic shift by 24 brings the top byte down and fills the rest of
  // the lane with its sign bit.
  s32x4x4 widened;
  widened.v[0].v = _mm_srai_epi32(quads_0, 24);
  widened.v[1].v = _mm_srai_epi32(quads_1, 24);
  widened.v[2].v = _mm_srai_epi32(quads_2, 24);
  widened.v[3].v = _mm_srai_epi32(quads_3, 24);
  return widened;
}
// Zero-extends sixteen uint8 lanes to int32 using SSE2 only, by interleaving
// with zero twice (8 -> 16 bits, then 16 -> 32 bits). `v[k]` holds source
// lanes `4 * k .. 4 * k + 3`.
YNN_ALWAYS_INLINE s32x4x4 convert(u8x16 a, int32_t) {
  const __m128i zeros = _mm_setzero_si128();
  const __m128i words_lo = _mm_unpacklo_epi8(a.v, zeros);
  const __m128i words_hi = _mm_unpackhi_epi8(a.v, zeros);

  s32x4x4 widened;
  widened.v[0].v = _mm_unpacklo_epi16(words_lo, zeros);
  widened.v[1].v = _mm_unpackhi_epi16(words_lo, zeros);
  widened.v[2].v = _mm_unpacklo_epi16(words_hi, zeros);
  widened.v[3].v = _mm_unpackhi_epi16(words_hi, zeros);
  return widened;
}

} // namespace simd

} // namespace ynn

#endif // XNNPACK_YNNPACK_BASE_SIMD_X86_SSE2_ONLY_H_
20 changes: 20 additions & 0 deletions ynnpack/base/simd/x86_sse41.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <cstdint>

#include "ynnpack/base/base.h"
#include "ynnpack/base/simd/multi_vec.h"
#include "ynnpack/base/simd/x86_sse2.h" // IWYU pragma: export

namespace ynn {
Expand All @@ -38,6 +39,25 @@ YNN_ALWAYS_INLINE s32x4 max(s32x4 a, s32x4 b) {
return s32x4{_mm_max_epi32(a.v, b.v)};
}

using s32x4x4 = multi_vec<s32x4, 4>;

// Sign-extends sixteen int8 lanes to int32 with SSE4.1. `_mm_cvtepi8_epi32`
// reads only the low four bytes, so the input is byte-shifted right by 4, 8
// and 12 to bring each later group of four into position. `v[k]` holds source
// lanes `4 * k .. 4 * k + 3`.
YNN_ALWAYS_INLINE s32x4x4 convert(s8x16 a, int32_t) {
  const __m128i group_1 = _mm_srli_si128(a.v, 4);
  const __m128i group_2 = _mm_srli_si128(a.v, 8);
  const __m128i group_3 = _mm_srli_si128(a.v, 12);

  s32x4x4 widened;
  widened.v[0].v = _mm_cvtepi8_epi32(a.v);
  widened.v[1].v = _mm_cvtepi8_epi32(group_1);
  widened.v[2].v = _mm_cvtepi8_epi32(group_2);
  widened.v[3].v = _mm_cvtepi8_epi32(group_3);
  return widened;
}

// Zero-extends sixteen uint8 lanes to int32 with SSE4.1. `_mm_cvtepu8_epi32`
// reads only the low four bytes, so the input is byte-shifted right by 4, 8
// and 12 to bring each later group of four into position. `v[k]` holds source
// lanes `4 * k .. 4 * k + 3`.
YNN_ALWAYS_INLINE s32x4x4 convert(u8x16 a, int32_t) {
  const __m128i group_1 = _mm_srli_si128(a.v, 4);
  const __m128i group_2 = _mm_srli_si128(a.v, 8);
  const __m128i group_3 = _mm_srli_si128(a.v, 12);

  s32x4x4 widened;
  widened.v[0].v = _mm_cvtepu8_epi32(a.v);
  widened.v[1].v = _mm_cvtepu8_epi32(group_1);
  widened.v[2].v = _mm_cvtepu8_epi32(group_2);
  widened.v[3].v = _mm_cvtepu8_epi32(group_3);
  return widened;
}

YNN_ALWAYS_INLINE int8_t horizontal_max(s8x16 a) {
const __m128i max8 = _mm_max_epi8(a.v, _mm_srli_si128(a.v, 8));
const __m128i max4 = _mm_max_epi8(max8, _mm_srli_si128(max8, 4));
Expand Down
17 changes: 6 additions & 11 deletions ynnpack/kernels/reduce/arm_neon.cc
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,9 @@ static f32x4x16 reduce_add(
std::integral_constant<size_t, 1> /*horizontal_factor*/) {
YNN_UNROLL
for (int i = 0; i < 8; ++i) {
f32x4 lo(vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(b.v[i].v), 16)));
f32x4 hi(vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(b.v[i].v), 16)));

a.v[2 * i + 0] += lo;
a.v[2 * i + 1] += hi;
f32x4x2 b_f32 = convert(b.v[i], float{});
a.v[2 * i + 0] += b_f32.v[0];
a.v[2 * i + 1] += b_f32.v[1];
}

return a;
Expand All @@ -60,12 +58,9 @@ static f32x4x16 reduce_add(
std::integral_constant<size_t, 1> /*horizontal_factor*/) {
YNN_UNROLL
for (int i = 0; i < 8; ++i) {
float32x4_t lo =
vreinterpretq_f32_u32(vshll_n_u16(vget_low_u16(b.v[i].v), 16));
float32x4_t hi =
vreinterpretq_f32_u32(vshll_n_u16(vget_high_u16(b.v[i].v), 16));
a.v[2 * i + 0].v = vmlaq_f32(a.v[2 * i + 0].v, lo, lo);
a.v[2 * i + 1].v = vmlaq_f32(a.v[2 * i + 1].v, hi, hi);
f32x4x2 b_f32 = convert(b.v[i], float{});
a.v[2 * i + 0].v = vmlaq_f32(a.v[2 * i + 0].v, b_f32.v[0].v, b_f32.v[0].v);
a.v[2 * i + 1].v = vmlaq_f32(a.v[2 * i + 1].v, b_f32.v[1].v, b_f32.v[1].v);
}

return a;
Expand Down
49 changes: 11 additions & 38 deletions ynnpack/kernels/reduce/x86_avx2.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include "ynnpack/base/simd/x86_avx2.h"

#include <immintrin.h>

#include <cassert>
Expand All @@ -16,6 +14,8 @@
#include "ynnpack/base/base.h"
#include "ynnpack/base/bfloat16.h"
#include "ynnpack/base/half.h"
#include "ynnpack/base/simd/x86_sse2.h"
#include "ynnpack/base/simd/x86_avx2.h"
#include "ynnpack/base/simd/multi_vec.h"
#include "ynnpack/kernels/reduce/generic.h"
#include "ynnpack/kernels/reduce/min_max_accumulator.h"
Expand All @@ -29,36 +29,15 @@ namespace simd {
using f32x8x8 = multi_vec<f32x8, 8>;
using f32x8x16 = multi_vec<f32x8, 16>;
using s32x8x2 = multi_vec<s32x8, 2>;
using s32x8x4 = multi_vec<s32x8, 4>;
using bf16x16x8 = multi_vec<bf16x16, 8>;

// Accumulates the 32 int8 lanes of `b` into the int32 accumulators in `a`.
// The lanes are widened exactly once, via `convert(b, int32_t{})`, and then
// added, so each source lane contributes a single time. Returns `a`.
static s32x8x4& operator+=(s32x8x4& a, s8x32 b) {
  a += convert(b, int32_t{});
  return a;
}

// Accumulates the 32 uint8 lanes of `b` into the int32 accumulators in `a`.
// The lanes are widened exactly once, via `convert(b, int32_t{})`, and then
// added, so each source lane contributes a single time. Returns `a`.
static s32x8x4& operator+=(s32x8x4& a, u8x32 b) {
  a += convert(b, int32_t{});
  return a;
}

Expand Down Expand Up @@ -123,11 +102,8 @@ static f32x8x16 reduce_add(
std::integral_constant<size_t, 1> /*horizontal_factor*/) {
YNN_UNROLL
for (int i = 0; i < 8; ++i) {
__m256i lo = _mm256_cvtepu16_epi32(_mm256_castsi256_si128(b.v[i].v));
__m256i hi = _mm256_cvtepu16_epi32(_mm256_extracti128_si256(b.v[i].v, 1));

a.v[2 * i + 0] += f32x8{_mm256_castsi256_ps(_mm256_slli_epi32(lo, 16))};
a.v[2 * i + 1] += f32x8{_mm256_castsi256_ps(_mm256_slli_epi32(hi, 16))};
a.v[2 * i + 0] += convert(extract<0>(b.v[i], bf16x8{}), float{});
a.v[2 * i + 1] += convert(extract<1>(b.v[i], bf16x8{}), float{});
}

return a;
Expand All @@ -150,14 +126,11 @@ static f32x8x16 reduce_add(
std::integral_constant<size_t, 1> /*horizontal_factor*/) {
YNN_UNROLL
for (int i = 0; i < 8; ++i) {
__m256i lo_u32 = _mm256_cvtepu16_epi32(_mm256_castsi256_si128(b.v[i].v));
__m256i hi_u32 =
_mm256_cvtepu16_epi32(_mm256_extracti128_si256(b.v[i].v, 1));
f32x8 lo_f32(_mm256_castsi256_ps(_mm256_slli_epi32(lo_u32, 16)));
f32x8 hi_f32(_mm256_castsi256_ps(_mm256_slli_epi32(hi_u32, 16)));

a.v[2 * i + 0] += lo_f32 * lo_f32;
a.v[2 * i + 1] += hi_f32 * hi_f32;
f32x8 b_lo = convert(extract<0>(b.v[i], bf16x8{}), float{});
f32x8 b_hi = convert(extract<1>(b.v[i], bf16x8{}), float{});

a.v[2 * i + 0] += b_lo * b_lo;
a.v[2 * i + 1] += b_hi * b_hi;
}

return a;
Expand Down
Loading
Loading