Skip to content

Commit

Permalink
[Enhancement] Support split_part function negative index (StarRocks#2…
Browse files Browse the repository at this point in the history
…3931)

Fixes StarRocks#5992

---------

Signed-off-by: leoyy0316 <571684903@qq.com>
  • Loading branch information
leoyy0316 authored May 31, 2023
1 parent 18218ac commit 6b4cb05
Show file tree
Hide file tree
Showing 5 changed files with 271 additions and 57 deletions.
154 changes: 98 additions & 56 deletions be/src/exprs/split_part.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,97 @@

namespace starrocks {

// Returns a pointer to the right-most occurrence of `delimiter` that lies entirely inside
// haystack[0, end), or nullptr when there is none. `delimiter` must not be empty.
static char* rfind_delimiter(const Slice& haystack, size_t end, const Slice& delimiter) {
    if (end < delimiter.size) {
        return nullptr;
    }
    // `start` visits every candidate start offset, from end - delimiter.size down to 0.
    for (size_t start = end - delimiter.size + 1; start-- > 0;) {
        char* candidate = haystack.data + start;
        // Cheap first-byte test before paying for the full comparison.
        if (*candidate == delimiter.data[0] && memcmp(candidate, delimiter.data, delimiter.size) == 0) {
            return candidate;
        }
    }
    return nullptr;
}

// Extracts one part of `haystack` split on `delimiter`.
//
// part_number > 0 counts parts from the left (1 is the first part); the string is scanned
// forward with memchr (single-byte delimiter) or memmem (multi-byte delimiter).
// part_number < 0 counts parts from the right (-1 is the last part); the string is scanned
// backward. Because the scan directions differ, overlapping delimiter occurrences can yield
// different parts: for "abc##567###234" split on "##", part 3 is "#234" but part -1 is "234".
//
// On success `res` is set to a view into `haystack` (no bytes are copied) and true is returned.
// false is returned, and `res` is left untouched, when:
//   - part_number is 0,
//   - delimiter is empty (split_part handles that case itself before calling this function),
//   - delimiter does not occur in haystack at all,
//   - haystack has fewer parts than requested.
static bool split_index(const Slice& haystack, const Slice& delimiter, int32_t part_number, Slice& res) {
    // Part 0 does not exist; without this guard it used to fall into the backward scan and
    // return the whole haystack whenever the delimiter was absent.
    if (part_number == 0 || delimiter.size == 0 || haystack.size < delimiter.size) {
        return false;
    }

    if (part_number > 0) {
        size_t part_begin = 0;    // offset of the first byte of the part being scanned
        int32_t parts_closed = 0; // delimiters consumed so far == parts fully skipped or found
        for (;;) {
            const size_t remaining = haystack.size - part_begin;
            char* hit = (delimiter.size == 1)
                                ? static_cast<char*>(memchr(haystack.data + part_begin, delimiter.data[0], remaining))
                                : static_cast<char*>(memmem(haystack.data + part_begin, remaining, delimiter.data,
                                                            delimiter.size));
            if (hit == nullptr) {
                // The tail after the last delimiter is a part only if at least one delimiter
                // exists, and it is the answer only if it is exactly the requested part.
                if (parts_closed == 0 || parts_closed + 1 != part_number) {
                    return false;
                }
                res.data = haystack.data + part_begin;
                res.size = remaining;
                return true;
            }

            const size_t hit_offset = static_cast<size_t>(hit - haystack.data);
            if (++parts_closed == part_number) {
                res.data = haystack.data + part_begin;
                res.size = hit_offset - part_begin;
                return true;
            }
            part_begin = hit_offset + delimiter.size;
        }
    }

    // Widen before negating: -INT32_MIN is not representable in int32_t.
    const int64_t wanted = -static_cast<int64_t>(part_number);
    size_t part_end = haystack.size; // one past the last byte of the part being scanned
    int64_t parts_closed = 0;        // delimiters consumed so far, counted from the right
    for (;;) {
        char* hit = rfind_delimiter(haystack, part_end, delimiter);
        if (hit == nullptr) {
            // The head before the first delimiter is a part only if at least one delimiter
            // exists, and it is the answer only if it is exactly the requested part.
            if (parts_closed == 0 || parts_closed + 1 != wanted) {
                return false;
            }
            res.data = haystack.data;
            res.size = part_end;
            return true;
        }

        const size_t hit_offset = static_cast<size_t>(hit - haystack.data);
        if (++parts_closed == wanted) {
            res.data = hit + delimiter.size;
            res.size = part_end - hit_offset - delimiter.size;
            return true;
        }
        // Continue left of this delimiter; occurrences overlapping it are not considered.
        part_end = hit_offset;
    }
}

/**
* @param: [haystack, delimiter, part_number]
* @paramType: [BinaryColumn, BinaryColumn, IntColumn]
Expand All @@ -33,10 +124,11 @@ StatusOr<ColumnPtr> StringFunctions::split_part(FunctionContext* context, const
DCHECK_EQ(columns.size(), 3);
RETURN_IF_COLUMNS_ONLY_NULL(columns);

// TODO use SIMD algorithm to optimize
if (columns[2]->is_constant()) {
// if part_number is a negative int, return NULL.
// if part_number is 0, return NULL.
int32_t part_number = ColumnHelper::get_const_value<TYPE_INT>(columns[2]);
if (part_number <= 0) {
if (part_number == 0) {
return ColumnHelper::create_const_null_column(columns[0]->size());
}
}
Expand All @@ -47,18 +139,14 @@ StatusOr<ColumnPtr> StringFunctions::split_part(FunctionContext* context, const

size_t size = columns[0]->size();
ColumnBuilder<TYPE_VARCHAR> res(size);
Slice slice;
for (int i = 0; i < size; ++i) {
if (haystack_viewer.is_null(i) || delimiter_viewer.is_null(i) || part_number_viewer.is_null(i)) {
res.append_null();
continue;
}

int32_t part_number = part_number_viewer.value(i);
if (part_number <= 0) {
res.append_null();
continue;
}

Slice haystack = haystack_viewer.value(i);
Slice delimiter = delimiter_viewer.value(i);
if (delimiter.size == 0) {
Expand All @@ -78,55 +166,9 @@ StatusOr<ColumnPtr> StringFunctions::split_part(FunctionContext* context, const
res.append(Slice(haystack.data + h, char_size));
}
}
} else if (delimiter.size == 1) {
// if delimiter is a char, use memchr to split
// Record the two adjacent offsets when matching delimiter.
// If no matching, return NULL.
// Else return the string between two adjacent offsets.
int32_t pre_offset = -1;
int32_t offset = -1;
int32_t num = 0;
while (num < part_number) {
pre_offset = offset;
size_t n = haystack.size - offset - 1;
char* pos = reinterpret_cast<char*>(memchr(haystack.data + offset + 1, delimiter.data[0], n));
if (pos != nullptr) {
offset = pos - haystack.data;
num++;
} else {
offset = haystack.size;
num = (num == 0) ? 0 : num + 1;
break;
}
}

if (num == part_number) {
res.append(Slice(haystack.data + pre_offset + 1, offset - pre_offset - 1));
} else {
res.append_null();
}
} else {
// if delimiter is a string, use memmem to split
int32_t pre_offset = -static_cast<int32_t>(delimiter.size);
int32_t offset = -static_cast<int32_t>(delimiter.size);
int32_t num = 0;
while (num < part_number) {
pre_offset = offset;
size_t n = haystack.size - offset - delimiter.size;
char* pos = reinterpret_cast<char*>(
memmem(haystack.data + offset + delimiter.size, n, delimiter.data, delimiter.size));
if (pos != nullptr) {
offset = pos - haystack.data;
num++;
} else {
offset = haystack.size;
num = (num == 0) ? 0 : num + 1;
break;
}
}

if (num == part_number) {
res.append(Slice(haystack.data + pre_offset + delimiter.size, offset - pre_offset - delimiter.size));
if (split_index(haystack, delimiter, part_number, slice)) {
res.append(slice);
} else {
res.append_null();
}
Expand All @@ -135,4 +177,4 @@ StatusOr<ColumnPtr> StringFunctions::split_part(FunctionContext* context, const
return res.build(ColumnHelper::is_all_const(columns));
}

} // namespace starrocks
} // namespace starrocks
24 changes: 24 additions & 0 deletions be/test/exprs/string_fn_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -934,6 +934,26 @@ PARALLEL_TEST(VecStringFunctionsTest, splitPart) {
delim->append("");
field->append(10);

// 25
str->append("hello word");
delim->append(" ");
field->append(-1);

// 26
str->append("hello word");
delim->append(" ");
field->append(-2);

// 27
str->append("hello word");
delim->append(" ");
field->append(-3);

// 28
str->append("2019年9月8日");
delim->append("");
field->append(-1);

columns.emplace_back(str);
columns.emplace_back(delim);
columns.emplace_back(field);
Expand Down Expand Up @@ -966,6 +986,10 @@ PARALLEL_TEST(VecStringFunctionsTest, splitPart) {
ASSERT_EQ("9", v->get(22).get<Slice>().to_string());
ASSERT_EQ("", v->get(23).get<Slice>().to_string());
ASSERT_TRUE(v->get(24).is_null());
ASSERT_EQ("word", v->get(25).get<Slice>().to_string());
ASSERT_EQ("hello", v->get(26).get<Slice>().to_string());
ASSERT_TRUE(v->get(27).is_null());
ASSERT_EQ("8日", v->get(28).get<Slice>().to_string());
}

PARALLEL_TEST(VecStringFunctionsTest, leftTest) {
Expand Down
28 changes: 28 additions & 0 deletions docs/sql-reference/sql-functions/string-functions/split_part.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,40 @@ MySQL > select split_part("hello world", " ", 2);
| world |
+-----------------------------------+
MySQL > select split_part("hello world", " ", -1);
+------------------------------------+
| split_part('hello world', ' ', -1) |
+------------------------------------+
| world                              |
+------------------------------------+
MySQL > select split_part("hello world", " ", -2);
+-----------------------------------+
| split_part('hello world', ' ', -2) |
+-----------------------------------+
| hello |
+-----------------------------------+
MySQL > select split_part("abca", "a", 1);
+----------------------------+
| split_part('abca', 'a', 1) |
+----------------------------+
| |
+----------------------------+
MySQL > select split_part("abca", "a", -1);
+-----------------------------+
| split_part('abca', 'a', -1) |
+-----------------------------+
| |
+-----------------------------+
MySQL > select split_part("abca", "a", -2);
+-----------------------------+
| split_part('abca', 'a', -2) |
+-----------------------------+
| bc |
+-----------------------------+
```

## keyword
Expand Down
90 changes: 90 additions & 0 deletions test/sql/test_string_functions/R/test_string_functions
Original file line number Diff line number Diff line change
Expand Up @@ -144,4 +144,94 @@ select rpad(c0, c1) from t0;
test
te

-- !result
select split_part("hello world", " ", 1);
-- result:
hello
-- !result
select split_part("hello world", " ", 2);
-- result:
world
-- !result
select split_part("hello world", " ", -1);
-- result:
world
-- !result
select split_part("hello world", " ", -2);
-- result:
hello
-- !result
select split_part("2023年5月23号", "月", 1);
-- result:
2023年5
-- !result
select split_part("2023年5月23号", "月", -1);
-- result:
23号
-- !result
select split_part("abc##567###234", "##", 1);
-- result:
abc
-- !result
select split_part("abc##567###234", "##", 2);
-- result:
567
-- !result
select split_part("abc##567###234", "##", -1);
-- result:
234
-- !result
select split_part("abc##567###234", "##", -2);
-- result:
567#
-- !result
create table t1(c0 varchar(20), c1 varchar(20))
DUPLICATE KEY(c0)
DISTRIBUTED BY HASH(c0)
BUCKETS 1
PROPERTIES('replication_num'='1');
-- result:
-- !result
insert into t1 values ('hello world', 'abc##567###234');
-- result:
-- !result
select split_part(c0, " ", 1) from t1;
-- result:
hello
-- !result
select split_part(c0, " ", 2) from t1;
-- result:
world
-- !result
select split_part(c0, " ", -1) from t1;
-- result:
world
-- !result
select split_part(c0, " ", -2) from t1;
-- result:
hello
-- !result
select split_part(c1, "##", 1) from t1;
-- result:
abc
-- !result
select split_part(c1, "##", 2) from t1;
-- result:
567
-- !result
select split_part(c1, "##", 3) from t1;
-- result:
#234
-- !result
select split_part(c1, "##", -1) from t1;
-- result:
234
-- !result
select split_part(c1, "##", -2) from t1;
-- result:
567#
-- !result
select split_part(c1, "##", -3) from t1;
-- result:
abc
-- !result
Loading

0 comments on commit 6b4cb05

Please sign in to comment.