-
Notifications
You must be signed in to change notification settings - Fork 1.9k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[Enhancement]Support select_if in arm #53093
Merged
silverbullet233
merged 5 commits into
StarRocks:main
from
before-Sunrise:support_select_if_arm
Dec 26, 2024
Merged
Changes from all commits
Commits
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,6 +17,9 @@ | |
#ifdef __AVX2__ | ||
#include <emmintrin.h> | ||
#include <immintrin.h> | ||
#elif defined(__ARM_NEON__) && defined(__aarch64__) | ||
#include <arm_acle.h> | ||
#include <arm_neon.h> | ||
#endif | ||
|
||
#include <cstdint> | ||
|
@@ -216,6 +219,145 @@ inline void avx2_select_if_common_implement(uint8_t*& selector, T*& dst, const T | |
} | ||
} | ||
} | ||
|
||
#elif defined(__ARM_NEON) && defined(__aarch64__) | ||
template <class T> | ||
constexpr bool neon_could_use_common_select_if() { | ||
return sizeof(T) == 1 || sizeof(T) == 2 || sizeof(T) == 4 || sizeof(T) == 8; | ||
} | ||
|
||
template <typename T, bool left_const = false, bool right_const = false> | ||
inline void neon_select_if_common_implement(uint8_t*& selector, T*& dst, const T*& a, const T*& b, int size) { | ||
const T* dst_end = dst + size; | ||
constexpr int data_size = sizeof(T); | ||
constexpr int neon_width = 16; // NEON register width is 128 bits (16 bytes) | ||
|
||
// Process 16 bytes of data at a time | ||
while (dst + neon_width < dst_end) { | ||
// Load 16 selector masks | ||
uint8x16_t loaded_mask = vld1q_u8(selector); | ||
// vceqq_u8: Compare each element in two NEON registers, returns 0xFF if equal, 0x00 if not | ||
loaded_mask = vceqq_u8(loaded_mask, vdupq_n_u8(0)); | ||
// vmvnq_u8: Bitwise NOT of each element in NEON register, so non-zero becomes 0xFF, zero becomes 0x00 | ||
loaded_mask = vmvnq_u8(loaded_mask); | ||
|
||
if constexpr (data_size == 1) { // int8/uint8/bool | ||
// Load vector a | ||
uint8x16_t vec_a; | ||
if constexpr (!left_const) { | ||
vec_a = vld1q_u8(reinterpret_cast<const uint8_t*>(a)); | ||
} else { | ||
vec_a = vdupq_n_u8(*reinterpret_cast<const uint8_t*>(a)); | ||
} | ||
|
||
// Load vector b | ||
uint8x16_t vec_b; | ||
if constexpr (!right_const) { | ||
vec_b = vld1q_u8(reinterpret_cast<const uint8_t*>(b)); | ||
} else { | ||
vec_b = vdupq_n_u8(*reinterpret_cast<const uint8_t*>(b)); | ||
} | ||
|
||
// Select result based on mask | ||
uint8x16_t result = vbslq_u8(loaded_mask, vec_a, vec_b); | ||
|
||
// Store result | ||
vst1q_u8(reinterpret_cast<uint8_t*>(dst), result); | ||
|
||
} else if constexpr (data_size == 2) { // int16 | ||
// Process 2 groups, each handling 8 int16 | ||
for (int i = 0; i < 2; i++) { | ||
// Load vector a | ||
uint16x8_t vec_a; | ||
if constexpr (!left_const) { | ||
// vld1q_u16: Load 8 consecutive 16-bit values into NEON register | ||
vec_a = vld1q_u16(reinterpret_cast<const uint16_t*>(a) + i * 8); | ||
} else { | ||
// vdupq_n_u16: Copy a 16-bit value to all elements in the register | ||
vec_a = vdupq_n_u16(*reinterpret_cast<const uint16_t*>(a)); | ||
before-Sunrise marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
|
||
uint16x8_t vec_b; | ||
if constexpr (!right_const) { | ||
vec_b = vld1q_u16(reinterpret_cast<const uint16_t*>(b) + i * 8); | ||
} else { | ||
vec_b = vdupq_n_u16(*reinterpret_cast<const uint16_t*>(b)); | ||
} | ||
|
||
// Convert first 8 uint8 masks to uint16 masks using lookup table, effectively duplicating each uint8 | ||
uint8x16_t index = {0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7}; | ||
uint8x16_t mask = vqtbl1q_u8(loaded_mask, index); | ||
|
||
// Select result based on mask | ||
uint16x8_t result = vbslq_u16(vreinterpretq_u16_u8(mask), vec_a, vec_b); | ||
|
||
vst1q_u16(reinterpret_cast<uint16_t*>(dst) + i * 8, result); | ||
loaded_mask = vextq_u8(loaded_mask, loaded_mask, 8); | ||
} | ||
} else if constexpr (data_size == 4) { // int32/float | ||
// Process 4 groups, each handling 4 int32 | ||
for (int i = 0; i < 4; i++) { | ||
uint32x4_t vec_a; | ||
if constexpr (!left_const) { | ||
vec_a = vld1q_u32(reinterpret_cast<const uint32_t*>(a) + i * 4); | ||
} else { | ||
vec_a = vdupq_n_u32(*reinterpret_cast<const uint32_t*>(a)); | ||
} | ||
|
||
uint32x4_t vec_b; | ||
if constexpr (!right_const) { | ||
vec_b = vld1q_u32(reinterpret_cast<const uint32_t*>(b) + i * 4); | ||
} else { | ||
vec_b = vdupq_n_u32(*reinterpret_cast<const uint32_t*>(b)); | ||
} | ||
|
||
uint8x16_t index = {0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3}; | ||
uint8x16_t mask = vqtbl1q_u8(loaded_mask, index); | ||
|
||
uint32x4_t result = vbslq_u32(vreinterpretq_u32_u8(mask), vec_a, vec_b); | ||
|
||
vst1q_u32(reinterpret_cast<uint32_t*>(dst) + i * 4, result); | ||
loaded_mask = vextq_u8(loaded_mask, loaded_mask, 4); | ||
} | ||
} else if constexpr (data_size == 8) { // int64/double | ||
// Process 8 groups, each handling 2 int64 | ||
for (int i = 0; i < 8; i++) { | ||
uint64x2_t vec_a; | ||
if constexpr (!left_const) { | ||
vec_a = vld1q_u64(reinterpret_cast<const uint64_t*>(a) + i * 2); | ||
} else { | ||
vec_a = vdupq_n_u64(*reinterpret_cast<const uint64_t*>(a)); | ||
} | ||
|
||
uint64x2_t vec_b; | ||
if constexpr (!right_const) { | ||
vec_b = vld1q_u64(reinterpret_cast<const uint64_t*>(b) + i * 2); | ||
} else { | ||
vec_b = vdupq_n_u64(*reinterpret_cast<const uint64_t*>(b)); | ||
} | ||
|
||
uint8x16_t index = {0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1}; | ||
uint8x16_t mask = vqtbl1q_u8(loaded_mask, index); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this piece of code is almost same exception the data_size parameter, it's better to extract the common part |
||
|
||
uint64x2_t result = vbslq_u64(vreinterpretq_u64_u8(mask), vec_a, vec_b); | ||
|
||
vst1q_u64(reinterpret_cast<uint64_t*>(dst) + i * 2, result); | ||
|
||
loaded_mask = vextq_u8(loaded_mask, loaded_mask, 2); | ||
} | ||
} | ||
|
||
dst += 16; | ||
selector += 16; | ||
if (!left_const) { | ||
a += 16; | ||
} | ||
if (!right_const) { | ||
b += 16; | ||
} | ||
} | ||
} | ||
|
||
#endif | ||
|
||
// SIMD selector | ||
|
@@ -245,6 +387,10 @@ class SIMD_selector { | |
} else if constexpr (could_use_common_select_if<CppType>()) { | ||
avx2_select_if_common_implement(select_vec, start_dst, start_a, start_b, size); | ||
} | ||
#elif defined(__ARM_NEON) && defined(__aarch64__) | ||
if constexpr (neon_could_use_common_select_if<CppType>()) { | ||
neon_select_if_common_implement(select_vec, start_dst, start_a, start_b, size); | ||
} | ||
#endif | ||
|
||
while (start_dst < end_dst) { | ||
|
@@ -272,6 +418,10 @@ class SIMD_selector { | |
} else if constexpr (could_use_common_select_if<CppType>()) { | ||
avx2_select_if_common_implement<CppType, true, false>(select_vec, start_dst, start_a, start_b, size); | ||
} | ||
#elif defined(__ARM_NEON) && defined(__aarch64__) | ||
if constexpr (neon_could_use_common_select_if<CppType>()) { | ||
neon_select_if_common_implement<CppType, true, false>(select_vec, start_dst, start_a, start_b, size); | ||
} | ||
#endif | ||
|
||
while (start_dst < end_dst) { | ||
|
@@ -298,6 +448,10 @@ class SIMD_selector { | |
} else if constexpr (could_use_common_select_if<CppType>()) { | ||
avx2_select_if_common_implement<CppType, false, true>(select_vec, start_dst, start_a, start_b, size); | ||
} | ||
#elif defined(__ARM_NEON) && defined(__aarch64__) | ||
if constexpr (neon_could_use_common_select_if<CppType>()) { | ||
neon_select_if_common_implement<CppType, false, true>(select_vec, start_dst, start_a, start_b, size); | ||
} | ||
#endif | ||
|
||
while (start_dst < end_dst) { | ||
|
@@ -324,6 +478,10 @@ class SIMD_selector { | |
} else if constexpr (could_use_common_select_if<CppType>()) { | ||
avx2_select_if_common_implement<CppType, true, true>(select_vec, start_dst, start_a, start_b, size); | ||
} | ||
#elif defined(__ARM_NEON) && defined(__aarch64__) | ||
if constexpr (neon_could_use_common_select_if<CppType>()) { | ||
neon_select_if_common_implement<CppType, true, true>(select_vec, start_dst, start_a, start_b, size); | ||
} | ||
#endif | ||
while (start_dst < end_dst) { | ||
*start_dst = *select_vec ? a : b; | ||
|
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why need a
const T*&
? can it be simplify as aconst T*
or evenconst void*
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
because outside need to handle left data if size % 16 !=0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you can add std::enable_if here to ensure that your template won't be instantiated with unexpected types