Skip to content

Commit

Permalink
BUG (string): ArrowStringArray.find corner cases (#59562)
Browse files Browse the repository at this point in the history
  • Loading branch information
jbrockmendel authored Sep 6, 2024
1 parent 08431f1 commit 3f8d3e4
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 55 deletions.
44 changes: 43 additions & 1 deletion pandas/core/arrays/_arrow_string_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@
from functools import partial
from typing import (
TYPE_CHECKING,
Any,
Literal,
)

import numpy as np

from pandas.compat import (
pa_version_under10p1,
pa_version_under13p0,
pa_version_under17p0,
)

Expand All @@ -20,7 +22,10 @@
import pyarrow.compute as pc

if TYPE_CHECKING:
from collections.abc import Sized
from collections.abc import (
Callable,
Sized,
)

from pandas._typing import (
Scalar,
Expand All @@ -42,6 +47,9 @@ def _convert_int_result(self, result):
# Convert an integer-dtype result to the appropriate result type
raise NotImplementedError

def _apply_elementwise(self, func: Callable) -> list[list[Any]]:
raise NotImplementedError

def _str_pad(
self,
width: int,
Expand Down Expand Up @@ -205,3 +213,37 @@ def _str_contains(
if not isna(na): # pyright: ignore [reportGeneralTypeIssues]
result = result.fill_null(na)
return self._convert_bool_result(result)

def _str_find(self, sub: str, start: int = 0, end: int | None = None):
if (
pa_version_under13p0
and not (start != 0 and end is not None)
and not (start == 0 and end is None)
):
# GH#59562
res_list = self._apply_elementwise(lambda val: val.find(sub, start, end))
return self._convert_int_result(pa.chunked_array(res_list))

if (start == 0 or start is None) and end is None:
result = pc.find_substring(self._pa_array, sub)
else:
if sub == "":
# GH#56792
res_list = self._apply_elementwise(
lambda val: val.find(sub, start, end)
)
return self._convert_int_result(pa.chunked_array(res_list))
if start is None:
start_offset = 0
start = 0
elif start < 0:
start_offset = pc.add(start, pc.utf8_length(self._pa_array))
start_offset = pc.if_else(pc.less(start_offset, 0), 0, start_offset)
else:
start_offset = start
slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
result = pc.find_substring(slices, sub)
found = pc.not_equal(result, pa.scalar(-1, type=result.type))
offset_result = pc.add(result, start_offset)
result = pc.if_else(found, offset_result, -1)
return self._convert_int_result(result)
23 changes: 0 additions & 23 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2373,29 +2373,6 @@ def _str_fullmatch(
pat = f"{pat}$"
return self._str_match(pat, case, flags, na)

def _str_find(self, sub: str, start: int = 0, end: int | None = None) -> Self:
if (start == 0 or start is None) and end is None:
result = pc.find_substring(self._pa_array, sub)
else:
if sub == "":
# GH 56792
result = self._apply_elementwise(lambda val: val.find(sub, start, end))
return type(self)(pa.chunked_array(result))
if start is None:
start_offset = 0
start = 0
elif start < 0:
start_offset = pc.add(start, pc.utf8_length(self._pa_array))
start_offset = pc.if_else(pc.less(start_offset, 0), 0, start_offset)
else:
start_offset = start
slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
result = pc.find_substring(slices, sub)
found = pc.not_equal(result, pa.scalar(-1, type=result.type))
offset_result = pc.add(result, start_offset)
result = pc.if_else(found, offset_result, -1)
return type(self)(result)

def _str_join(self, sep: str) -> Self:
if pa.types.is_string(self._pa_array.type) or pa.types.is_large_string(
self._pa_array.type
Expand Down
18 changes: 7 additions & 11 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,18 +416,14 @@ def _str_count(self, pat: str, flags: int = 0):
return self._convert_int_result(result)

def _str_find(self, sub: str, start: int = 0, end: int | None = None):
if start != 0 and end is not None:
slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
result = pc.find_substring(slices, sub)
not_found = pc.equal(result, -1)
offset_result = pc.add(result, end - start)
result = pc.if_else(not_found, result, offset_result)
elif start == 0 and end is None:
slices = self._pa_array
result = pc.find_substring(slices, sub)
else:
if (
pa_version_under13p0
and not (start != 0 and end is not None)
and not (start == 0 and end is None)
):
# GH#59562
return super()._str_find(sub, start, end)
return self._convert_int_result(result)
return ArrowStringArrayMixin._str_find(self, sub, start, end)

def _str_get_dummies(self, sep: str = "|"):
dummies_pa, labels = ArrowExtensionArray(self._pa_array)._str_get_dummies(sep)
Expand Down
31 changes: 11 additions & 20 deletions pandas/tests/extension/test_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,6 @@
import numpy as np
import pytest

from pandas._config import using_string_dtype

from pandas._libs import lib
from pandas._libs.tslibs import timezones
from pandas.compat import (
Expand Down Expand Up @@ -1947,14 +1945,9 @@ def test_str_find_negative_start():

def test_str_find_no_end():
ser = pd.Series(["abc", None], dtype=ArrowDtype(pa.string()))
if pa_version_under13p0:
# https://github.com/apache/arrow/issues/36311
with pytest.raises(pa.lib.ArrowInvalid, match="Negative buffer resize"):
ser.str.find("ab", start=1)
else:
result = ser.str.find("ab", start=1)
expected = pd.Series([-1, None], dtype="int64[pyarrow]")
tm.assert_series_equal(result, expected)
result = ser.str.find("ab", start=1)
expected = pd.Series([-1, None], dtype="int64[pyarrow]")
tm.assert_series_equal(result, expected)


def test_str_find_negative_start_negative_end():
Expand All @@ -1968,17 +1961,11 @@ def test_str_find_negative_start_negative_end():
def test_str_find_large_start():
# GH 56791
ser = pd.Series(["abcdefg", None], dtype=ArrowDtype(pa.string()))
if pa_version_under13p0:
# https://github.com/apache/arrow/issues/36311
with pytest.raises(pa.lib.ArrowInvalid, match="Negative buffer resize"):
ser.str.find(sub="d", start=16)
else:
result = ser.str.find(sub="d", start=16)
expected = pd.Series([-1, None], dtype=ArrowDtype(pa.int64()))
tm.assert_series_equal(result, expected)
result = ser.str.find(sub="d", start=16)
expected = pd.Series([-1, None], dtype=ArrowDtype(pa.int64()))
tm.assert_series_equal(result, expected)


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
@pytest.mark.skipif(
pa_version_under13p0, reason="https://github.com/apache/arrow/issues/36311"
)
Expand All @@ -1990,11 +1977,15 @@ def test_str_find_e2e(start, end, sub):
["abcaadef", "abc", "abcdeddefgj8292", "ab", "a", ""],
dtype=ArrowDtype(pa.string()),
)
object_series = s.astype(pd.StringDtype())
object_series = s.astype(pd.StringDtype(storage="python"))
result = s.str.find(sub, start, end)
expected = object_series.str.find(sub, start, end).astype(result.dtype)
tm.assert_series_equal(result, expected)

arrow_str_series = s.astype(pd.StringDtype(storage="pyarrow"))
result2 = arrow_str_series.str.find(sub, start, end).astype(result.dtype)
tm.assert_series_equal(result2, expected)


def test_str_find_negative_start_negative_end_no_match():
# GH 56791
Expand Down

0 comments on commit 3f8d3e4

Please sign in to comment.