Skip to content

Commit

Permalink
BUG (string): ArrowStringArray.find corner cases (pandas-dev#59562)
Browse files: browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored and jorisvandenbossche committed Oct 10, 2024
1 parent 553780a commit 44325c1
Show file tree
Hide file tree
Showing 4 changed files with 99 additions and 32 deletions.
44 changes: 43 additions & 1 deletion pandas/core/arrays/_arrow_string_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Literal,
)

import numpy as np

from pandas.compat import (
pa_version_under10p1,
pa_version_under13p0,
pa_version_under17p0,
)

Expand All @@ -20,7 +22,10 @@
import pyarrow.compute as pc

if TYPE_CHECKING:
from collections.abc import Sized
from collections.abc import (
Callable,
Sized,
)

from pandas._typing import Scalar

Expand All @@ -39,6 +44,9 @@ def _convert_int_result(self, result):
# Convert an integer-dtype result to the appropriate result type
raise NotImplementedError

def _apply_elementwise(self, func: Callable) -> list[list[Any]]:
raise NotImplementedError

def _str_pad(
self,
width: int,
Expand Down Expand Up @@ -201,3 +209,37 @@ def _str_contains(
if not isna(na): # pyright: ignore [reportGeneralTypeIssues]
result = result.fill_null(na)
return self._convert_bool_result(result)

def _str_find(self, sub: str, start: int = 0, end: int | None = None):
    """
    Return, for each string, the index of ``sub`` or -1 when it is absent.

    The search is limited to the ``[start, end)`` window, and a reported
    index is shifted back so that it refers to the full string rather
    than to the sliced window.

    Parameters
    ----------
    sub : str
        Substring to search for.
    start : int, default 0
        Left edge of the search window. ``None`` is handled like 0, and a
        negative value counts from the end of each string.
    end : int or None, default None
        Right edge of the search window; ``None`` means no right bound.

    Returns
    -------
    Whatever ``self._convert_int_result`` produces for the integer result.
    """

    def _find_with_python():
        # Element-wise fallback that delegates to Python's own ``str.find``.
        chunks = self._apply_elementwise(lambda val: val.find(sub, start, end))
        return self._convert_int_result(pa.chunked_array(chunks))

    both_given = start != 0 and end is not None
    both_default = start == 0 and end is None
    if pa_version_under13p0 and not both_given and not both_default:
        # GH#59562
        return _find_with_python()

    if end is None and (start is None or start == 0):
        # No window at all: search the whole string directly.
        return self._convert_int_result(pc.find_substring(self._pa_array, sub))

    if sub == "":
        # GH#56792
        return _find_with_python()

    if start is None:
        slice_start = 0
        start_offset = 0
    elif start < 0:
        slice_start = start
        # A negative start is relative to each string's length, so the
        # offset is computed per row and clamped so it never drops below 0.
        start_offset = pc.add(start, pc.utf8_length(self._pa_array))
        start_offset = pc.if_else(pc.less(start_offset, 0), 0, start_offset)
    else:
        slice_start = start
        start_offset = start

    window = pc.utf8_slice_codeunits(self._pa_array, slice_start, stop=end)
    positions = pc.find_substring(window, sub)
    missing = pa.scalar(-1, type=positions.type)
    was_found = pc.not_equal(positions, missing)
    shifted = pc.add(positions, start_offset)
    # Only shift real matches; "not found" stays -1.
    return self._convert_int_result(pc.if_else(was_found, shifted, -1))
17 changes: 0 additions & 17 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2348,23 +2348,6 @@ def _str_fullmatch(
pat = f"{pat}$"
return self._str_match(pat, case, flags, na)

def _str_find(self, sub: str, start: int = 0, end: int | None = None):
    """
    Return, for each string, the index of ``sub`` or -1 when it is absent.

    Only two argument combinations are supported: the defaults
    (``start == 0`` and ``end is None``), or a non-zero ``start`` together
    with an explicit ``end``.

    Parameters
    ----------
    sub : str
        Substring to search for.
    start : int, default 0
        Left edge of the search window; a negative value counts from the
        end of each string.
    end : int or None, default None
        Right edge of the search window.

    Returns
    -------
    A new instance of ``type(self)`` wrapping the integer result.

    Raises
    ------
    NotImplementedError
        For any other combination of ``start`` and ``end``.
    """
    if start == 0 and end is None:
        result = pc.find_substring(self._pa_array, sub)
    elif start != 0 and end is not None:
        if start < 0:
            # A negative start is relative to each string's length, so the
            # amount to add back differs per row. ``max(0, start)`` would
            # add 0 and report indices relative to the slice instead of
            # the full string (GH 56791). Clamp at 0 for starts that fall
            # before the beginning of the string.
            start_offset = pc.add(start, pc.utf8_length(self._pa_array))
            start_offset = pc.if_else(pc.less(start_offset, 0), 0, start_offset)
        else:
            start_offset = start
        slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
        result = pc.find_substring(slices, sub)
        found = pc.not_equal(result, pa.scalar(-1, type=result.type))
        # Shift real matches back into full-string coordinates; keep -1.
        result = pc.if_else(found, pc.add(result, start_offset), -1)
    else:
        raise NotImplementedError(
            f"find not implemented with {sub=}, {start=}, {end=}"
        )
    return type(self)(result)

def _str_join(self, sep: str):
if pa.types.is_string(self._pa_array.type) or pa.types.is_large_string(
self._pa_array.type
Expand Down
18 changes: 7 additions & 11 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,18 +416,14 @@ def _str_count(self, pat: str, flags: int = 0):
return self._convert_int_result(result)

def _str_find(self, sub: str, start: int = 0, end: int | None = None):
if start != 0 and end is not None:
slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
result = pc.find_substring(slices, sub)
not_found = pc.equal(result, -1)
offset_result = pc.add(result, end - start)
result = pc.if_else(not_found, result, offset_result)
elif start == 0 and end is None:
slices = self._pa_array
result = pc.find_substring(slices, sub)
else:
if (
pa_version_under13p0
and not (start != 0 and end is not None)
and not (start == 0 and end is None)
):
# GH#59562
return super()._str_find(sub, start, end)
return self._convert_int_result(result)
return ArrowStringArrayMixin._str_find(self, sub, start, end)

def _str_get_dummies(self, sep: str = "|"):
dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies(sep)
Expand Down
52 changes: 49 additions & 3 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -1925,10 +1925,56 @@ def test_str_find_negative_start():
tm.assert_series_equal(result, expected)


def test_str_find_notimplemented():
def test_str_find_no_end():
ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
with pytest.raises(NotImplementedError, match="find not implemented"):
ser.str.find("ab", start=1)
result = ser.str.find("ab", start=1)
expected = pd.Series([-1, None], dtype="int64[pyarrow]")
tm.assert_series_equal(result, expected)


def test_str_find_negative_start_negative_end():
    """With negative ``start`` and ``end`` the index refers to the full string."""
    # GH 56791
    data = pd.Series(["abcdefg", None], dtype=ArrowDtype(pa.string()))
    expected = pd.Series([3, None], dtype=ArrowDtype(pa.int64()))
    tm.assert_series_equal(data.str.find(sub="d", start=-6, end=-3), expected)


def test_str_find_large_start():
    """A ``start`` beyond the end of the string reports no match (-1)."""
    # GH 56791
    data = pd.Series(["abcdefg", None], dtype=ArrowDtype(pa.string()))
    expected = pd.Series([-1, None], dtype=ArrowDtype(pa.int64()))
    tm.assert_series_equal(data.str.find(sub="d", start=16), expected)


@pytest.mark.skipif(
    pa_version_under13p0, reason="https://github.com/apache/arrow/issues/36311"
)
@pytest.mark.parametrize("start", [-15, -3, 0, 1, 15, None])
@pytest.mark.parametrize("end", [-15, -1, 0, 3, 15, None])
@pytest.mark.parametrize("sub", ["", "az", "abce", "a", "caa"])
def test_str_find_e2e(start, end, sub):
    """Check Arrow-backed ``str.find`` against the python-storage string result."""
    values = ["abcaadef", "abc", "abcdeddefgj8292", "ab", "a", ""]
    arrow_ser = pd.Series(values, dtype=ArrowDtype(pa.string()))

    # The python-storage string series provides the reference answer.
    arrow_result = arrow_ser.str.find(sub, start, end)
    reference = arrow_ser.astype(pd.StringDtype(storage="python"))
    expected = reference.str.find(sub, start, end).astype(arrow_result.dtype)
    tm.assert_series_equal(arrow_result, expected)

    # The pyarrow-storage string dtype must agree as well.
    pyarrow_string_ser = arrow_ser.astype(pd.StringDtype(storage="pyarrow"))
    string_result = pyarrow_string_ser.str.find(sub, start, end)
    tm.assert_series_equal(string_result.astype(arrow_result.dtype), expected)


def test_str_find_negative_start_negative_end_no_match():
    """A window whose ``end`` precedes its ``start`` reports no match (-1)."""
    # GH 56791
    data = pd.Series(["abcdefg", None], dtype=ArrowDtype(pa.string()))
    expected = pd.Series([-1, None], dtype=ArrowDtype(pa.int64()))
    tm.assert_series_equal(data.str.find(sub="d", start=-3, end=-6), expected)


@pytest.mark.parametrize(
Expand Down

0 comments on commit 44325c1

Please sign in to comment.