Skip to content

Commit

Permalink
[optimize](string) optimize char_length function by SIMD (apache#18925)
Browse files Browse the repository at this point in the history
Optimize char_length function by SIMD
(1) optimize utf8_len compute
(2) 840% up
  • Loading branch information
ZhangYu0123 authored Apr 28, 2023
1 parent aef9355 commit 6626f26
Show file tree
Hide file tree
Showing 3 changed files with 51 additions and 37 deletions.
41 changes: 41 additions & 0 deletions be/src/util/simd/vstring_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,47 @@ class VStringFunctions {
LowerUpperImpl<'a', 'z'> lowerUpper;
lowerUpper.transfer(src, src + len, dst);
}

static inline size_t get_char_len(const char* src, size_t len, std::vector<size_t>& str_index) {
size_t char_len = 0;
for (size_t i = 0, char_size = 0; i < len; i += char_size) {
char_size = UTF8_BYTE_LENGTH[(unsigned char)src[i]];
str_index.push_back(i);
++char_len;
}
return char_len;
}

// utf8-encoding:
// - 1-byte: 0xxx_xxxx;
// - 2-byte: 110x_xxxx 10xx_xxxx;
// - 3-byte: 1110_xxxx 10xx_xxxx 10xx_xxxx;
// - 4-byte: 1111_0xxx 10xx_xxxx 10xx_xxxx 10xx_xxxx.
// Counting utf8 chars in a byte string is equivalent to counting first byte of utf chars, that
// is to say, counting bytes which do not match 10xx_xxxx pattern.
// All 0xxx_xxxx, 110x_xxxx, 1110_xxxx and 1111_0xxx are greater than 1011_1111 when use int8_t arithmetic,
// so just count bytes greater than 1011_1111 in a byte string as the result of utf8_length.
static inline size_t get_char_len(const char* src, size_t len) {
size_t char_len = 0;
const char* p = src;
const char* end = p + len;
#if defined(__SSE2__) || defined(__aarch64__)
constexpr auto bytes_sse2 = sizeof(__m128i);
const auto src_end_sse2 = p + (len & ~(bytes_sse2 - 1));
// threshold = 1011_1111
const auto threshold = _mm_set1_epi8(0xBF);
for (; p < src_end_sse2; p += bytes_sse2) {
char_len += __builtin_popcount(_mm_movemask_epi8(_mm_cmpgt_epi8(
_mm_loadu_si128(reinterpret_cast<const __m128i*>(p)), threshold)));
}
#endif
// process remaining bytes the number of which not exceed bytes_sse2 at the
// tail of string, one by one.
for (; p < end; ++p) {
char_len += static_cast<int8_t>(*p) > static_cast<int8_t>(0xBF);
}
return char_len;
}
};
} // namespace simd
} // namespace doris
8 changes: 5 additions & 3 deletions be/src/vec/functions/function_string.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ struct StringUtf8LengthImpl {
for (int i = 0; i < size; ++i) {
const char* raw_str = reinterpret_cast<const char*>(&data[offsets[i - 1]]);
int str_size = offsets[i] - offsets[i - 1];
res[i] = get_char_len(StringRef(raw_str, str_size), str_size);
res[i] = simd::VStringFunctions::get_char_len(raw_str, str_size);
}
return Status::OK();
}
Expand Down Expand Up @@ -223,7 +223,8 @@ struct StringInStrImpl {
// Hive returns positions starting from 1.
int loc = search.search(&lstr_ref);
if (loc > 0) {
loc = get_char_len(lstr_ref, loc);
size_t len = std::min(lstr_ref.size, (size_t)loc);
loc = simd::VStringFunctions::get_char_len(lstr_ref.data, len);
}
res[i] = loc + 1;
}
Expand Down Expand Up @@ -263,7 +264,8 @@ struct StringInStrImpl {
// Hive returns positions starting from 1.
int loc = search.search(&strl);
if (loc > 0) {
loc = get_char_len(strl, loc);
size_t len = std::min((size_t)loc, strl.size);
loc = simd::VStringFunctions::get_char_len(strl.data, len);
}

return loc + 1;
Expand Down
39 changes: 5 additions & 34 deletions be/src/vec/functions/function_string.h
Original file line number Diff line number Diff line change
Expand Up @@ -102,36 +102,6 @@

namespace doris::vectorized {

//TODO: these three functions could be merged.
inline size_t get_char_len(const std::string_view& str, std::vector<size_t>* str_index) {
size_t char_len = 0;
for (size_t i = 0, char_size = 0; i < str.length(); i += char_size) {
char_size = UTF8_BYTE_LENGTH[(unsigned char)str[i]];
str_index->push_back(i);
++char_len;
}
return char_len;
}

inline size_t get_char_len(const StringRef& str, std::vector<size_t>* str_index) {
size_t char_len = 0;
for (size_t i = 0, char_size = 0; i < str.size; i += char_size) {
char_size = UTF8_BYTE_LENGTH[(unsigned char)(str.data)[i]];
str_index->push_back(i);
++char_len;
}
return char_len;
}

inline size_t get_char_len(const StringRef& str, size_t end_pos) {
size_t char_len = 0;
for (size_t i = 0, char_size = 0; i < std::min(str.size, end_pos); i += char_size) {
char_size = UTF8_BYTE_LENGTH[(unsigned char)(str.data)[i]];
++char_len;
}
return char_len;
}

struct StringOP {
static void push_empty_string(int index, ColumnString::Chars& chars,
ColumnString::Offsets& offsets) {
Expand Down Expand Up @@ -1283,9 +1253,9 @@ class FunctionStringPad : public IFunction {
reinterpret_cast<const char*>(&padcol_chars[padcol_offsets[i - 1]]);

size_t str_char_size =
get_char_len(std::string_view(str_data, str_len), &str_index);
simd::VStringFunctions::get_char_len(str_data, str_len, str_index);
size_t pad_char_size =
get_char_len(std::string_view(pad_data, pad_len), &pad_index);
simd::VStringFunctions::get_char_len(pad_data, pad_len, pad_index);

if (col_len_data[i] <= str_char_size) {
// truncate the input string
Expand Down Expand Up @@ -2430,7 +2400,7 @@ class FunctionStringLocatePos : public IFunction {
// but throws an exception for *start_pos > str->len.
// Since returning 0 seems to be Hive's error condition, return 0.
std::vector<size_t> index;
size_t char_len = get_char_len(str, &index);
size_t char_len = simd::VStringFunctions::get_char_len(str.data, str.size, index);
if (start_pos <= 0 || start_pos > str.size || start_pos > char_len) {
return 0;
}
Expand All @@ -2442,7 +2412,8 @@ class FunctionStringLocatePos : public IFunction {
int32_t match_pos = search_ptr->search(&adjusted_str);
if (match_pos >= 0) {
// Hive returns the position in the original string starting from 1.
return start_pos + get_char_len(adjusted_str, match_pos);
size_t len = std::min(adjusted_str.size, (size_t)match_pos);
return start_pos + simd::VStringFunctions::get_char_len(adjusted_str.data, len);
} else {
return 0;
}
Expand Down

0 comments on commit 6626f26

Please sign in to comment.