Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Enhancement]Support select_if in arm #53093

Merged
merged 5 commits into from
Dec 26, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 158 additions & 0 deletions be/src/simd/selector.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
#ifdef __AVX2__
#include <emmintrin.h>
#include <immintrin.h>
#elif defined(__ARM_NEON__) && defined(__aarch64__)
#include <arm_acle.h>
#include <arm_neon.h>
#endif

#include <cstdint>
Expand Down Expand Up @@ -216,6 +219,145 @@ inline void avx2_select_if_common_implement(uint8_t*& selector, T*& dst, const T
}
}
}

#elif defined(__ARM_NEON) && defined(__aarch64__)
template <class T>
constexpr bool neon_could_use_common_select_if() {
return sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4 || sizeof(T) == 8;
}

template <typename T, bool left_const = false, bool right_const = false>
inline void neon_select_if_common_implement(uint8_t*& selector, T*& dst, const T*& a, const T*& b, int size) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why need a const T*& ? can it be simplify as a const T* or even const void* ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why need a const T*& ? can it be simplify as a const T* or even const void* ?

because outside need to handle left data if size % 16 !=0

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you can add std::enable_if here to ensure that your template won't be instantiated with unexpected types

const T* dst_end = dst + size;
constexpr int data_size = sizeof(T);
constexpr int neon_width = 16; // NEON register width is 128 bits (16 bytes)

// Process 16 bytes of data at a time
while (dst + neon_width < dst_end) {
// Load 16 selector masks
uint8x16_t loaded_mask = vld1q_u8(selector);
// vceqq_u8: Compare each element in two NEON registers, returns 0xFF if equal, 0x00 if not
loaded_mask = vceqq_u8(loaded_mask, vdupq_n_u8(0));
// vmvnq_u8: Bitwise NOT of each element in NEON register, so non-zero becomes 0xFF, zero becomes 0x00
loaded_mask = vmvnq_u8(loaded_mask);

if constexpr (data_size == 1) { // int8/uint8/bool
// Load vector a
uint8x16_t vec_a;
if constexpr (!left_const) {
vec_a = vld1q_u8(reinterpret_cast<const uint8_t*>(a));
} else {
vec_a = vdupq_n_u8(*reinterpret_cast<const uint8_t*>(a));
}

// Load vector b
uint8x16_t vec_b;
if constexpr (!right_const) {
vec_b = vld1q_u8(reinterpret_cast<const uint8_t*>(b));
} else {
vec_b = vdupq_n_u8(*reinterpret_cast<const uint8_t*>(b));
}

// Select result based on mask
uint8x16_t result = vbslq_u8(loaded_mask, vec_a, vec_b);

// Store result
vst1q_u8(reinterpret_cast<uint8_t*>(dst), result);

} else if constexpr (data_size == 2) { // int16
// Process 2 groups, each handling 8 int16
for (int i = 0; i < 2; i++) {
// Load vector a
uint16x8_t vec_a;
if constexpr (!left_const) {
// vld1q_u16: Load 8 consecutive 16-bit values into NEON register
vec_a = vld1q_u16(reinterpret_cast<const uint16_t*>(a) + i * 8);
} else {
// vdupq_n_u16: Copy a 16-bit value to all elements in the register
vec_a = vdupq_n_u16(*reinterpret_cast<const uint16_t*>(a));
before-Sunrise marked this conversation as resolved.
Show resolved Hide resolved
}

uint16x8_t vec_b;
if constexpr (!right_const) {
vec_b = vld1q_u16(reinterpret_cast<const uint16_t*>(b) + i * 8);
} else {
vec_b = vdupq_n_u16(*reinterpret_cast<const uint16_t*>(b));
}

// Convert first 8 uint8 masks to uint16 masks using lookup table, effectively duplicating each uint8
uint8x16_t index = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7};
uint8x16_t mask = vqtbl1q_u8(loaded_mask, index);

// Select result based on mask
uint16x8_t result = vbslq_u16(vreinterpretq_u16_u8(mask), vec_a, vec_b);

vst1q_u16(reinterpret_cast<uint16_t*>(dst) + i * 8, result);
loaded_mask = vextq_u8(loaded_mask, loaded_mask, 8);
}
} else if constexpr (data_size == 4) { // int32/float
// Process 4 groups, each handling 4 int32
for (int i = 0; i < 4; i++) {
uint32x4_t vec_a;
if constexpr (!left_const) {
vec_a = vld1q_u32(reinterpret_cast<const uint32_t*>(a) + i * 4);
} else {
vec_a = vdupq_n_u32(*reinterpret_cast<const uint32_t*>(a));
}

uint32x4_t vec_b;
if constexpr (!right_const) {
vec_b = vld1q_u32(reinterpret_cast<const uint32_t*>(b) + i * 4);
} else {
vec_b = vdupq_n_u32(*reinterpret_cast<const uint32_t*>(b));
}

uint8x16_t index = {0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3};
uint8x16_t mask = vqtbl1q_u8(loaded_mask, index);

uint32x4_t result = vbslq_u32(vreinterpretq_u32_u8(mask), vec_a, vec_b);

vst1q_u32(reinterpret_cast<uint32_t*>(dst) + i * 4, result);
loaded_mask = vextq_u8(loaded_mask, loaded_mask, 4);
}
} else if constexpr (data_size == 8) { // int64/double
// Process 8 groups, each handling 2 int64
for (int i = 0; i < 8; i++) {
uint64x2_t vec_a;
if constexpr (!left_const) {
vec_a = vld1q_u64(reinterpret_cast<const uint64_t*>(a) + i * 2);
} else {
vec_a = vdupq_n_u64(*reinterpret_cast<const uint64_t*>(a));
}

uint64x2_t vec_b;
if constexpr (!right_const) {
vec_b = vld1q_u64(reinterpret_cast<const uint64_t*>(b) + i * 2);
} else {
vec_b = vdupq_n_u64(*reinterpret_cast<const uint64_t*>(b));
}

uint8x16_t index = {0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1};
uint8x16_t mask = vqtbl1q_u8(loaded_mask, index);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this piece of code is almost same exception the data_size parameter, it's better to extract the common part


uint64x2_t result = vbslq_u64(vreinterpretq_u64_u8(mask), vec_a, vec_b);

vst1q_u64(reinterpret_cast<uint64_t*>(dst) + i * 2, result);

loaded_mask = vextq_u8(loaded_mask, loaded_mask, 2);
}
}

dst += 16;
selector += 16;
if (!left_const) {
a += 16;
}
if (!right_const) {
b += 16;
}
}
}

#endif

// SIMD selector
Expand Down Expand Up @@ -245,6 +387,10 @@ class SIMD_selector {
} else if constexpr (could_use_common_select_if<CppType>()) {
avx2_select_if_common_implement(select_vec, start_dst, start_a, start_b, size);
}
#elif defined(__ARM_NEON) && defined(__aarch64__)
if constexpr (neon_could_use_common_select_if<CppType>()) {
neon_select_if_common_implement(select_vec, start_dst, start_a, start_b, size);
}
#endif

while (start_dst < end_dst) {
Expand Down Expand Up @@ -272,6 +418,10 @@ class SIMD_selector {
} else if constexpr (could_use_common_select_if<CppType>()) {
avx2_select_if_common_implement<CppType, true, false>(select_vec, start_dst, start_a, start_b, size);
}
#elif defined(__ARM_NEON) && defined(__aarch64__)
if constexpr (neon_could_use_common_select_if<CppType>()) {
neon_select_if_common_implement<CppType, true, false>(select_vec, start_dst, start_a, start_b, size);
}
#endif

while (start_dst < end_dst) {
Expand All @@ -298,6 +448,10 @@ class SIMD_selector {
} else if constexpr (could_use_common_select_if<CppType>()) {
avx2_select_if_common_implement<CppType, false, true>(select_vec, start_dst, start_a, start_b, size);
}
#elif defined(__ARM_NEON) && defined(__aarch64__)
if constexpr (neon_could_use_common_select_if<CppType>()) {
neon_select_if_common_implement<CppType, false, true>(select_vec, start_dst, start_a, start_b, size);
}
#endif

while (start_dst < end_dst) {
Expand All @@ -324,6 +478,10 @@ class SIMD_selector {
} else if constexpr (could_use_common_select_if<CppType>()) {
avx2_select_if_common_implement<CppType, true, true>(select_vec, start_dst, start_a, start_b, size);
}
#elif defined(__ARM_NEON) && defined(__aarch64__)
if constexpr (neon_could_use_common_select_if<CppType>()) {
neon_select_if_common_implement<CppType, true, true>(select_vec, start_dst, start_a, start_b, size);
}
#endif
while (start_dst < end_dst) {
*start_dst = *select_vec ? a : b;
Expand Down
Loading