Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 0 additions & 10 deletions ynnpack/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -208,16 +208,6 @@ define_build_option(
default_all = [":ynn_enable_x86_avx"],
)

define_build_option(
name = "ynn_enable_x86_avx512f",
default_all = [":ynn_enable_x86_avx512"],
)

define_build_option(
name = "ynn_enable_x86_avx512bw",
default_all = [":ynn_enable_x86_avx512"],
)

define_build_option(
name = "ynn_enable_x86_avx512bf16",
default_all = [
Expand Down
2 changes: 2 additions & 0 deletions ynnpack/base/arch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ uint64_t get_supported_arch_flags() {
if (cpuinfo_has_x86_fma3()) result |= arch_flag::fma3;
if (cpuinfo_has_x86_avx512f()) result |= arch_flag::avx512f;
if (cpuinfo_has_x86_avx512bw()) result |= arch_flag::avx512bw;
if (cpuinfo_has_x86_avx512vl()) result |= arch_flag::avx512vl;
if (cpuinfo_has_x86_avx512dq()) result |= arch_flag::avx512dq;
if (cpuinfo_has_x86_avx512bf16()) result |= arch_flag::avx512bf16;
if (cpuinfo_has_x86_avx512fp16()) result |= arch_flag::avx512fp16;
if (cpuinfo_has_x86_avx512vnni()) result |= arch_flag::avx512vnni;
Expand Down
15 changes: 9 additions & 6 deletions ynnpack/base/arch.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,14 +25,17 @@ enum {
fma3 = 1 << 6,
avx512f = 1 << 7,
avx512bw = 1 << 8,
avx512bf16 = 1 << 9,
avx512fp16 = 1 << 10,
avx512vnni = 1 << 11,
amxbf16 = 1 << 12,
amxfp16 = 1 << 13,
amxint8 = 1 << 14,
avx512vl = 1 << 9,
avx512dq = 1 << 10,
avx512bf16 = 1 << 11,
avx512fp16 = 1 << 12,
avx512vnni = 1 << 13,
amxbf16 = 1 << 14,
amxfp16 = 1 << 15,
amxint8 = 1 << 16,

avx2_fma3 = avx2 | fma3,
avx512 = avx512f | avx512bw | avx512vl | avx512dq,
#endif // YNN_ARCH_X86
#ifdef YNN_ARCH_ARM
neon = 1 << 0,
Expand Down
3 changes: 1 addition & 2 deletions ynnpack/base/simd/test/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ package(default_visibility = ["//ynnpack:__subpackages__"])
"x86_sse41",
"x86_avx",
"x86_avx2",
"x86_avx512f",
"x86_avx512bw",
"x86_avx512",
"x86_f16c",
"x86_fma3",
]]
Expand Down
111 changes: 111 additions & 0 deletions ynnpack/base/simd/test/x86_avx512.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
// Copyright 2025 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include "ynnpack/base/simd/x86_avx512bw.h"

#include <cstdint>

#include "ynnpack/base/bfloat16.h"
#include "ynnpack/base/half.h"
#include "ynnpack/base/simd/test/generic.h"

namespace ynn {
namespace simd {

TEST_BROADCAST(x86_avx512, uint8_t, 64);
TEST_BROADCAST(x86_avx512, int8_t, 64);
TEST_BROADCAST(x86_avx512, int16_t, 32);
TEST_BROADCAST(x86_avx512, half, 32);
TEST_BROADCAST(x86_avx512, bfloat16, 32);
TEST_BROADCAST(x86_avx512, float, 16);
TEST_BROADCAST(x86_avx512, int32_t, 16);

TEST_LOAD_STORE(x86_avx512, uint8_t, 64);
TEST_LOAD_STORE(x86_avx512, int8_t, 64);
TEST_LOAD_STORE(x86_avx512, int16_t, 32);
TEST_LOAD_STORE(x86_avx512, half, 32);
TEST_LOAD_STORE(x86_avx512, bfloat16, 32);
TEST_LOAD_STORE(x86_avx512, float, 16);
TEST_LOAD_STORE(x86_avx512, int32_t, 16);

TEST_ALIGNED_LOAD_STORE(x86_avx512, uint8_t, 64);
TEST_ALIGNED_LOAD_STORE(x86_avx512, int8_t, 64);
TEST_ALIGNED_LOAD_STORE(x86_avx512, int16_t, 32);
TEST_ALIGNED_LOAD_STORE(x86_avx512, half, 32);
TEST_ALIGNED_LOAD_STORE(x86_avx512, bfloat16, 32);
TEST_ALIGNED_LOAD_STORE(x86_avx512, float, 16);
TEST_ALIGNED_LOAD_STORE(x86_avx512, int32_t, 16);

TEST_PARTIAL_LOAD_STORE(x86_avx512, uint8_t, 64);
TEST_PARTIAL_LOAD_STORE(x86_avx512, int8_t, 64);
TEST_PARTIAL_LOAD_STORE(x86_avx512, int16_t, 32);
TEST_PARTIAL_LOAD_STORE(x86_avx512, half, 32);
TEST_PARTIAL_LOAD_STORE(x86_avx512, bfloat16, 32);
TEST_PARTIAL_LOAD_STORE(x86_avx512, float, 16);
TEST_PARTIAL_LOAD_STORE(x86_avx512, int32_t, 16);

TEST_ADD(x86_avx512, uint8_t, 64);
TEST_ADD(x86_avx512, int8_t, 64);
TEST_ADD(x86_avx512, float, 16);
TEST_ADD(x86_avx512, int32_t, 16);

TEST_SUBTRACT(x86_avx512, uint8_t, 64);
TEST_SUBTRACT(x86_avx512, int8_t, 64);
TEST_SUBTRACT(x86_avx512, float, 16);
TEST_SUBTRACT(x86_avx512, int32_t, 16);

TEST_MULTIPLY(x86_avx512, float, 16);
TEST_MULTIPLY(x86_avx512, int32_t, 16);

TEST_MIN(x86_avx512, uint8_t, 64);
TEST_MIN(x86_avx512, int8_t, 64);
TEST_MIN(x86_avx512, int16_t, 32);
TEST_MIN(x86_avx512, float, 16);
TEST_MIN(x86_avx512, int32_t, 16);

TEST_MAX(x86_avx512, uint8_t, 64);
TEST_MAX(x86_avx512, int8_t, 64);
TEST_MAX(x86_avx512, int16_t, 32);
TEST_MAX(x86_avx512, float, 16);
TEST_MAX(x86_avx512, int32_t, 16);

TEST_FMA(x86_avx512, float, 16);

TEST_EXTRACT(x86_avx512, s32x16, 4);
TEST_EXTRACT(x86_avx512, f32x16, 4);
TEST_EXTRACT(x86_avx512, s8x64, 16);
TEST_EXTRACT(x86_avx512, u8x64, 16);

TEST_EXTRACT(x86_avx512, bf16x32, 16);
TEST_EXTRACT(x86_avx512, f16x32, 16);
TEST_EXTRACT(x86_avx512, s8x64, 32);
TEST_EXTRACT(x86_avx512, u8x64, 32);

TEST_CONCAT(x86_avx512, bf16x16);
TEST_CONCAT(x86_avx512, f16x16);
TEST_CONCAT(x86_avx512, s8x32);
TEST_CONCAT(x86_avx512, u8x32);

TEST_CONVERT(x86_avx512, int32_t, s8x16);
TEST_CONVERT(x86_avx512, int32_t, u8x16);
TEST_CONVERT(x86_avx512, int32_t, s8x32);
TEST_CONVERT(x86_avx512, int32_t, u8x32);
TEST_CONVERT(x86_avx512, float, bf16x16);
TEST_CONVERT(x86_avx512, float, f16x16);

TEST_HORIZONTAL_MIN(x86_avx512, uint8_t, 64);
TEST_HORIZONTAL_MIN(x86_avx512, int8_t, 64);
TEST_HORIZONTAL_MIN(x86_avx512, int16_t, 32);
TEST_HORIZONTAL_MIN(x86_avx512, float, 16);
TEST_HORIZONTAL_MIN(x86_avx512, int32_t, 16);

TEST_HORIZONTAL_MAX(x86_avx512, uint8_t, 64);
TEST_HORIZONTAL_MAX(x86_avx512, int8_t, 64);
TEST_HORIZONTAL_MAX(x86_avx512, int16_t, 32);
TEST_HORIZONTAL_MAX(x86_avx512, float, 16);
TEST_HORIZONTAL_MAX(x86_avx512, int32_t, 16);

} // namespace simd
} // namespace ynn
46 changes: 0 additions & 46 deletions ynnpack/base/simd/test/x86_avx512bw.cc

This file was deleted.

92 changes: 0 additions & 92 deletions ynnpack/base/simd/test/x86_avx512f.cc

This file was deleted.

4 changes: 2 additions & 2 deletions ynnpack/base/simd/x86_avx512f_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,8 @@ struct vec<int32_t, 16> {

__m512i v;

YNN_ALWAYS_INLINE s16x16 lo() const { return s16x16{internal::lo(v)}; }
YNN_ALWAYS_INLINE s16x16 hi() const { return s16x16{internal::hi(v)}; }
YNN_ALWAYS_INLINE s32x8 lo() const { return s32x8{internal::lo(v)}; }
YNN_ALWAYS_INLINE s32x8 hi() const { return s32x8{internal::hi(v)}; }
};

template <>
Expand Down
13 changes: 4 additions & 9 deletions ynnpack/build_defs.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -217,15 +217,10 @@ _YNN_PARAMS_FOR_ARCH = {
"arch_copts": _copts_for_compiler(["-mavx2", "-mfma"]),
"arch_flag": "avx2_fma3",
},
"x86_avx512f": {
"cond": "//ynnpack:ynn_enable_x86_avx512f",
"arch_copts": _copts_for_compiler(["-mavx512f"]),
"arch_flag": "avx512f",
},
"x86_avx512bw": {
"cond": "//ynnpack:ynn_enable_x86_avx512bw",
"arch_copts": _copts_for_compiler(["-mavx512bw"]),
"arch_flag": "avx512bw",
"x86_avx512": {
"cond": "//ynnpack:ynn_enable_x86_avx512",
"arch_copts": _copts_for_compiler(["-mavx512f", "-mavx512bw", "-mavx512vl", "-mavx512dq"]),
"arch_flag": "avx512",
},
"x86_avx512bf16": {
"cond": "//ynnpack:ynn_enable_x86_avx512bf16",
Expand Down
2 changes: 1 addition & 1 deletion ynnpack/kernels/binary/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ ynn_cc_library(
"arm_neon": ["arm_neon.cc"],
"x86_sse2": ["x86_sse2.cc"],
"x86_avx": ["x86_avx.cc"],
"x86_avx512f": ["x86_avx512f.cc"],
"x86_avx512": ["x86_avx512f.cc"],
"x86_avx2": ["x86_avx2.cc"],
},
visibility = ["//ynnpack:__subpackages__"],
Expand Down
4 changes: 2 additions & 2 deletions ynnpack/kernels/binary/kernels.inc
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
// clang-format off

#ifdef YNN_ARCH_X86_AVX512F
#ifdef YNN_ARCH_X86_AVX512
#include "ynnpack/kernels/binary/x86_avx512f.inc"
#endif // YNN_ARCH_X86_AVX512F
#endif // YNN_ARCH_X86_AVX512
#ifdef YNN_ARCH_X86_AVX2
#include "ynnpack/kernels/binary/x86_avx2.inc"
#endif // YNN_ARCH_X86_AVX2
Expand Down
Loading
Loading