Skip to content

Commit d77e272

Browse files
lidavidmpitrou
authored andcommitted
ARROW-12950: [C++] Add count_substring kernel
Depends on ARROW-12969. ignore_case is not included here; I'll include it with the regex variant in ARROW-12952. Closes #10454 from lidavidm/arrow-12950 Authored-by: David Li <li.davidm96@gmail.com> Signed-off-by: Antoine Pitrou <antoine@python.org>
1 parent 80fe83a commit d77e272

File tree

6 files changed

+139
-13
lines changed

6 files changed

+139
-13
lines changed

cpp/src/arrow/compute/kernels/scalar_string.cc

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -741,8 +741,12 @@ template <typename InputType>
741741
struct FindSubstringExec {
742742
using OffsetType = typename TypeTraits<InputType>::OffsetType;
743743
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
744+
const MatchSubstringOptions& options = MatchSubstringState::Get(ctx);
745+
if (options.ignore_case) {
746+
return Status::NotImplemented("find_substring with ignore_case");
747+
}
744748
applicator::ScalarUnaryNotNullStateful<OffsetType, InputType, FindSubstring> kernel{
745-
FindSubstring(PlainSubstringMatcher(MatchSubstringState::Get(ctx)))};
749+
FindSubstring(PlainSubstringMatcher(options))};
746750
return kernel.Exec(ctx, batch, out);
747751
}
748752
};
@@ -771,6 +775,69 @@ void AddFindSubstring(FunctionRegistry* registry) {
771775
DCHECK_OK(registry->AddFunction(std::move(func)));
772776
}
773777

778+
// Substring count
779+
780+
struct CountSubstring {
781+
const PlainSubstringMatcher matcher_;
782+
783+
explicit CountSubstring(PlainSubstringMatcher matcher) : matcher_(std::move(matcher)) {}
784+
785+
template <typename OutValue, typename... Ignored>
786+
OutValue Call(KernelContext*, util::string_view val, Status*) const {
787+
OutValue count = 0;
788+
uint64_t start = 0;
789+
const auto pattern_size = std::max<uint64_t>(1, matcher_.options_.pattern.size());
790+
while (start <= val.size()) {
791+
const int64_t index = matcher_.Find(val.substr(start));
792+
if (index >= 0) {
793+
count++;
794+
start += index + pattern_size;
795+
} else {
796+
break;
797+
}
798+
}
799+
return count;
800+
}
801+
};
802+
803+
template <typename InputType>
804+
struct CountSubstringExec {
805+
using OffsetType = typename TypeTraits<InputType>::OffsetType;
806+
static Status Exec(KernelContext* ctx, const ExecBatch& batch, Datum* out) {
807+
const MatchSubstringOptions& options = MatchSubstringState::Get(ctx);
808+
if (options.ignore_case) {
809+
return Status::NotImplemented("count_substring with ignore_case");
810+
}
811+
applicator::ScalarUnaryNotNullStateful<OffsetType, InputType, CountSubstring> kernel{
812+
CountSubstring(PlainSubstringMatcher(options))};
813+
return kernel.Exec(ctx, batch, out);
814+
}
815+
};
816+
817+
const FunctionDoc count_substring_doc(
818+
"Count occurrences of substring",
819+
("For each string in `strings`, emit the number of occurrences of the given "
820+
"pattern.\n"
821+
"Null inputs emit null. The pattern must be given in MatchSubstringOptions."),
822+
{"strings"}, "MatchSubstringOptions");
823+
824+
void AddCountSubstring(FunctionRegistry* registry) {
825+
auto func = std::make_shared<ScalarFunction>("count_substring", Arity::Unary(),
826+
&count_substring_doc);
827+
for (const auto& ty : BaseBinaryTypes()) {
828+
std::shared_ptr<DataType> offset_type;
829+
if (ty->id() == Type::type::LARGE_BINARY || ty->id() == Type::type::LARGE_STRING) {
830+
offset_type = int64();
831+
} else {
832+
offset_type = int32();
833+
}
834+
DCHECK_OK(func->AddKernel({ty}, offset_type,
835+
GenerateTypeAgnosticVarBinaryBase<CountSubstringExec>(ty),
836+
MatchSubstringState::Init));
837+
}
838+
DCHECK_OK(registry->AddFunction(std::move(func)));
839+
}
840+
774841
// Slicing
775842

776843
template <typename Type, typename Derived>
@@ -3213,6 +3280,7 @@ void RegisterScalarStringAscii(FunctionRegistry* registry) {
32133280
AddUtf8Length(registry);
32143281
AddMatchSubstring(registry);
32153282
AddFindSubstring(registry);
3283+
AddCountSubstring(registry);
32163284
MakeUnaryStringBatchKernelWithState<ReplaceSubStringPlain>(
32173285
"replace_substring", registry, &replace_substring_doc,
32183286
MemAllocation::NO_PREALLOCATE);

cpp/src/arrow/compute/kernels/scalar_string_test.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,25 @@ TYPED_TEST(TestBinaryKernels, FindSubstring) {
103103
"[0, 0, null]", &options_empty);
104104
}
105105

106+
TYPED_TEST(TestBinaryKernels, CountSubstring) {
107+
MatchSubstringOptions options{"aba"};
108+
this->CheckUnary("count_substring", "[]", this->offset_type(), "[]", &options);
109+
this->CheckUnary(
110+
"count_substring",
111+
R"(["", null, "ab", "aba", "baba", "ababa", "abaaba", "babacaba", "ABA"])",
112+
this->offset_type(), "[0, null, 0, 1, 1, 1, 2, 2, 0]", &options);
113+
114+
MatchSubstringOptions options_empty{""};
115+
this->CheckUnary("count_substring", R"(["", null, "abc"])", this->offset_type(),
116+
"[1, null, 4]", &options_empty);
117+
118+
MatchSubstringOptions options_repeated{"aaa"};
119+
this->CheckUnary("count_substring", R"(["", "aaaa", "aaaaa", "aaaaaa", "aaá"])",
120+
this->offset_type(), "[0, 1, 1, 2, 0]", &options_repeated);
121+
122+
// TODO: case-insensitive
123+
}
124+
106125
template <typename TestType>
107126
class TestStringKernels : public BaseTestStringKernels<TestType> {};
108127

docs/source/cpp/compute.rst

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -561,45 +561,51 @@ Containment tests
561561
+---------------------------+------------+------------------------------------+--------------------+----------------------------------------+
562562
| Function name | Arity | Input types | Output type | Options class |
563563
+===========================+============+====================================+====================+========================================+
564-
| find_substring | Unary | String-like | Int32 or Int64 (1) | :struct:`MatchSubstringOptions` |
564+
| count_substring | Unary | String-like | Int32 or Int64 (1) | :struct:`MatchSubstringOptions` |
565565
+---------------------------+------------+------------------------------------+--------------------+----------------------------------------+
566-
| match_like | Unary | String-like | Boolean (2) | :struct:`MatchSubstringOptions` |
566+
| find_substring | Unary | String-like | Int32 or Int64 (2) | :struct:`MatchSubstringOptions` |
567567
+---------------------------+------------+------------------------------------+--------------------+----------------------------------------+
568-
| match_substring | Unary | String-like | Boolean (3) | :struct:`MatchSubstringOptions` |
568+
| match_like | Unary | String-like | Boolean (3) | :struct:`MatchSubstringOptions` |
569569
+---------------------------+------------+------------------------------------+--------------------+----------------------------------------+
570-
| match_substring_regex | Unary | String-like | Boolean (4) | :struct:`MatchSubstringOptions` |
570+
| match_substring | Unary | String-like | Boolean (4) | :struct:`MatchSubstringOptions` |
571571
+---------------------------+------------+------------------------------------+--------------------+----------------------------------------+
572-
| index_in | Unary | Boolean, Null, Numeric, Temporal, | Int32 (5) | :struct:`SetLookupOptions` |
572+
| match_substring_regex | Unary | String-like | Boolean (5) | :struct:`MatchSubstringOptions` |
573+
+---------------------------+------------+------------------------------------+--------------------+----------------------------------------+
574+
| index_in | Unary | Boolean, Null, Numeric, Temporal, | Int32 (6) | :struct:`SetLookupOptions` |
573575
| | | Binary- and String-like | | |
574576
+---------------------------+------------+------------------------------------+--------------------+----------------------------------------+
575-
| is_in | Unary | Boolean, Null, Numeric, Temporal, | Boolean (6) | :struct:`SetLookupOptions` |
577+
| is_in | Unary | Boolean, Null, Numeric, Temporal, | Boolean (7) | :struct:`SetLookupOptions` |
576578
| | | Binary- and String-like | | |
577579
+---------------------------+------------+------------------------------------+--------------------+----------------------------------------+
578580

581+
* \(1) Output is the number of occurrences of
582+
:member:`MatchSubstringOptions::pattern` in the corresponding input
583+
string. Output type is Int32 for Binary/String, Int64
584+
for LargeBinary/LargeString.
579585

580-
* \(1) Output is the index of the first occurrence of
586+
* \(2) Output is the index of the first occurrence of
581587
:member:`MatchSubstringOptions::pattern` in the corresponding input
582588
string, otherwise -1. Output type is Int32 for Binary/String, Int64
583589
for LargeBinary/LargeString.
584590

585-
* \(2) Output is true iff the SQL-style LIKE pattern
591+
* \(3) Output is true iff the SQL-style LIKE pattern
586592
:member:`MatchSubstringOptions::pattern` fully matches the
587593
corresponding input element. That is, ``%`` will match any number of
588594
characters, ``_`` will match exactly one character, and any other
589595
character matches itself. To match a literal percent sign or
590596
underscore, precede the character with a backslash.
591597

592-
* \(3) Output is true iff :member:`MatchSubstringOptions::pattern`
598+
* \(4) Output is true iff :member:`MatchSubstringOptions::pattern`
593599
is a substring of the corresponding input element.
594600

595-
* \(4) Output is true iff :member:`MatchSubstringOptions::pattern`
601+
* \(5) Output is true iff :member:`MatchSubstringOptions::pattern`
596602
matches the corresponding input element at any position.
597603

598-
* \(5) Output is the index of the corresponding input element in
604+
* \(6) Output is the index of the corresponding input element in
599605
:member:`SetLookupOptions::value_set`, if found there. Otherwise,
600606
output is null.
601607

602-
* \(6) Output is true iff the corresponding input element is equal to one
608+
* \(7) Output is true iff the corresponding input element is equal to one
603609
of the elements in :member:`SetLookupOptions::value_set`.
604610

605611

docs/source/python/api/compute.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ Containment tests
178178
.. autosummary::
179179
:toctree: ../generated/
180180

181+
count_substring
181182
find_substring
182183
index_in
183184
is_in

python/pyarrow/compute.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -291,6 +291,25 @@ def cast(arr, target_type, safe=True):
291291
return call_function("cast", [arr], options)
292292

293293

294+
def count_substring(array, pattern):
295+
"""
296+
Count the occurrences of substring *pattern* in each value of a
297+
string array.
298+
299+
Parameters
300+
----------
301+
array : pyarrow.Array or pyarrow.ChunkedArray
302+
pattern : str
303+
pattern to search for exact matches
304+
305+
Returns
306+
-------
307+
result : pyarrow.Array or pyarrow.ChunkedArray
308+
"""
309+
return call_function("count_substring", [array],
310+
MatchSubstringOptions(pattern))
311+
312+
294313
def find_substring(array, pattern):
295314
"""
296315
Find the index of the first occurrence of substring *pattern* in each

python/pyarrow/tests/test_compute.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,19 @@ def test_variance():
285285
assert pc.variance(data, ddof=1).as_py() == 6.0
286286

287287

288+
def test_count_substring():
289+
arr = pa.array(["ab", "cab", "abcab", "ba", "AB", None])
290+
result = pc.count_substring(arr, "ab")
291+
expected = pa.array([1, 1, 2, 0, 0, None], type=pa.int32())
292+
assert expected.equals(result)
293+
294+
arr = pa.array(["ab", "cab", "abcab", "ba", "AB", None],
295+
type=pa.large_string())
296+
result = pc.count_substring(arr, "ab")
297+
expected = pa.array([1, 1, 2, 0, 0, None], type=pa.int64())
298+
assert expected.equals(result)
299+
300+
288301
def test_find_substring():
289302
arr = pa.array(["ab", "cab", "ba", None])
290303
result = pc.find_substring(arr, "ab")

0 commit comments

Comments
 (0)