Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

BUG (string): ArrowStringArray.find corner cases #59562

Merged
merged 10 commits into from
Sep 6, 2024
44 changes: 43 additions & 1 deletion pandas/core/arrays/_arrow_string_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Literal,
)

import numpy as np

from pandas.compat import (
pa_version_under10p1,
pa_version_under13p0,
pa_version_under17p0,
)

Expand All @@ -20,7 +22,10 @@
import pyarrow.compute as pc

if TYPE_CHECKING:
from collections.abc import Sized
from collections.abc import (
Callable,
Sized,
)

from pandas._typing import (
Scalar,
Expand All @@ -42,6 +47,9 @@ def _convert_int_result(self, result):
# Convert an integer-dtype result to the appropriate result type
raise NotImplementedError

def _apply_elementwise(self, func: Callable) -> list[list[Any]]:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This needs to be implemented for ArrowStringArray as well then?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is only used for the ArrowEA version. The ArrowStringArray goes through _str_map, which ArrowEA doesn't have. eventually id like to align the names, but there are too many branches/PRs as it is.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure what I am missing, but _apply_elementwise is called from the now-shared _str_find method just below, and so I would think that you can also get there from ArrowStringArray._str_find ?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep my bad. ArrowStringArray inherits ArrowEA so gets its apply_elementwise from there. putting it here just prevents mypy from complaining

raise NotImplementedError

def _str_pad(
self,
width: int,
Expand Down Expand Up @@ -205,3 +213,37 @@ def _str_contains(
if not isna(na): # pyright: ignore [reportGeneralTypeIssues]
result = result.fill_null(na)
return self._convert_bool_result(result)

def _str_find(self, sub: str, start: int = 0, end: int | None = None):
if (
pa_version_under13p0
and not (start != 0 and end is not None)
and not (start == 0 and end is None)
):
# GH#59562
res_list = self._apply_elementwise(lambda val: val.find(sub, start, end))
return self._convert_int_result(pa.chunked_array(res_list))

if (start == 0 or start is None) and end is None:
result = pc.find_substring(self._pa_array, sub)
else:
if sub == "":
# GH#56792
res_list = self._apply_elementwise(
lambda val: val.find(sub, start, end)
)
return self._convert_int_result(pa.chunked_array(res_list))
if start is None:
start_offset = 0
start = 0
elif start < 0:
start_offset = pc.add(start, pc.utf8_length(self._pa_array))
start_offset = pc.if_else(pc.less(start_offset, 0), 0, start_offset)
else:
start_offset = start
slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
result = pc.find_substring(slices, sub)
found = pc.not_equal(result, pa.scalar(-1, type=result.type))
offset_result = pc.add(result, start_offset)
result = pc.if_else(found, offset_result, -1)
return self._convert_int_result(result)
23 changes: 0 additions & 23 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2373,29 +2373,6 @@ def _str_fullmatch(
pat = f"{pat}$"
return self._str_match(pat, case, flags, na)

def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self:
if (start == 0 or start is None) and end is None:
result = pc.find_substring(self._pa_array, sub)
else:
if sub == "":
# GH 56792
result = self._apply_elementwise(lambda val: val.find(sub, start, end))
return type(self)(pa.chunked_array(result))
if start is None:
start_offset = 0
start = 0
elif start < 0:
start_offset = pc.add(start, pc.utf8_length(self._pa_array))
start_offset = pc.if_else(pc.less(start_offset, 0), 0, start_offset)
else:
start_offset = start
slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
result = pc.find_substring(slices, sub)
found = pc.not_equal(result, pa.scalar(-1, type=result.type))
offset_result = pc.add(result, start_offset)
result = pc.if_else(found, offset_result, -1)
return type(self)(result)

def _str_join(self, sep: str) -> Self:
if pa.types.is_string(self._pa_array.type) or pa.types.is_large_string(
self._pa_array.type
Expand Down
18 changes: 7 additions & 11 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,18 +416,14 @@ def _str_count(self, pat: str, flags: int = 0):
return self._convert_int_result(result)

def _str_find(self, sub: str, start: int = 0, end: int | None = None):
if start != 0 and end is not None:
slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
result = pc.find_substring(slices, sub)
not_found = pc.equal(result, -1)
offset_result = pc.add(result, end - start)
result = pc.if_else(not_found, result, offset_result)
elif start == 0 and end is None:
slices = self._pa_array
result = pc.find_substring(slices, sub)
else:
if (
pa_version_under13p0
and not (start != 0 and end is not None)
and not (start == 0 and end is None)
):
# GH#59562
return super()._str_find(sub, start, end)
return self._convert_int_result(result)
return ArrowStringArrayMixin._str_find(self, sub, start, end)
Comment on lines +419 to +426
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that this special case is moved in the mixin method, I would expect this can be removed entirely? (and replaced with a _str_find = ArrowStringArrayMixin._str_find)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this goes through a cython path instead of iterating in python

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah, through _str_map using lib.map_infer_mask, I suppose. But if there is a cython implementation that is presumably faster, shouldn't we use that for the ArrowDtype as well?
I saw that in the center PR at https://github.com/pandas-dev/pandas/pull/59624/files#diff-ca6e5560b2fc1721e129b85f10882df8a1f20b5f1ef4dff547170fa35898dfa6R62 you didn't use _apply_elementwise but also explicitly went through object dtype. That's for the same reason? Can we use the same pattern?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, changed.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

looks like doing this broke the min_versions build, so reverted


def _str_get_dummies(self, sep: str = "|"):
dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies(sep)
Expand Down
31 changes: 11 additions & 20 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@
import numpy as np
import pytest

from pandas._config import using_string_dtype

from pandas._libs import lib
from pandas._libs.tslibs import timezones
from pandas.compat import (
Expand Down Expand Up @@ -1947,14 +1945,9 @@ def test_str_find_negative_start():

def test_str_find_no_end():
ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
if pa_version_under13p0:
# https://github.com/apache/arrow/issues/36311
with pytest.raises(pa.lib.ArrowInvalid, match="Negative buffer resize"):
ser.str.find("ab", start=1)
else:
result = ser.str.find("ab", start=1)
expected = pd.Series([-1, None], dtype="int64[pyarrow]")
tm.assert_series_equal(result, expected)
result = ser.str.find("ab", start=1)
expected = pd.Series([-1, None], dtype="int64[pyarrow]")
tm.assert_series_equal(result, expected)


def test_str_find_negative_start_negative_end():
Expand All @@ -1968,17 +1961,11 @@ def test_str_find_negative_start_negative_end():
def test_str_find_large_start():
# GH 56791
ser = pd.Series(["abcdefg", None], dtype=ArrowDtype(pa.string()))
if pa_version_under13p0:
# https://github.com/apache/arrow/issues/36311
with pytest.raises(pa.lib.ArrowInvalid, match="Negative buffer resize"):
ser.str.find(sub="d", start=16)
else:
result = ser.str.find(sub="d", start=16)
expected = pd.Series([-1, None], dtype=ArrowDtype(pa.int64()))
tm.assert_series_equal(result, expected)
result = ser.str.find(sub="d", start=16)
expected = pd.Series([-1, None], dtype=ArrowDtype(pa.int64()))
tm.assert_series_equal(result, expected)


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
@pytest.mark.skipif(
pa_version_under13p0, reason="https://github.com/apache/arrow/issues/36311"
)
Expand All @@ -1990,11 +1977,15 @@ def test_str_find_e2e(start, end, sub):
["abcaadef", "abc", "abcdeddefgj8292", "ab", "a", ""],
dtype=ArrowDtype(pa.string()),
)
object_series = s.astype(pd.StringDtype())
object_series = s.astype(pd.StringDtype(storage="python"))
result = s.str.find(sub, start, end)
expected = object_series.str.find(sub, start, end).astype(result.dtype)
tm.assert_series_equal(result, expected)

arrow_str_series = s.astype(pd.StringDtype(storage="pyarrow"))
result2 = arrow_str_series.str.find(sub, start, end).astype(result.dtype)
tm.assert_series_equal(result2, expected)
Comment on lines +1985 to +1987
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For future PRs, we should add such tests to pandas/tests/strings, I think (because now it is testing StringDtype in tests specifically for ArrowDtype ..)



def test_str_find_negative_start_negative_end_no_match():
# GH 56791
Expand Down
Loading