Skip to content

Commit

Permalink
REF (string): de-duplicate _str_contains (#59709)
Browse files Browse the repository at this point in the history
* REF: de-duplicate _str_contains

* pyright ignore
  • Loading branch information
jbrockmendel authored and jorisvandenbossche committed Oct 10, 2024
1 parent 5b571c0 commit 40d81db
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 25 deletions.
15 changes: 15 additions & 0 deletions pandas/core/arrays/_arrow_string_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,3 +186,18 @@ def _str_istitle(self):
def _str_isupper(self):
result = pc.utf8_is_upper(self._pa_array)
return self._convert_bool_result(result)

def _str_contains(
self, pat, case: bool = True, flags: int = 0, na=None, regex: bool = True
):
if flags:
raise NotImplementedError(f"contains not implemented with {flags=}")

if regex:
pa_contains = pc.match_substring_regex
else:
pa_contains = pc.match_substring
result = pa_contains(self._pa_array, pat, ignore_case=not case)
if not isna(na): # pyright: ignore [reportGeneralTypeIssues]
result = result.fill_null(na)
return self._convert_bool_result(result)
15 changes: 0 additions & 15 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2296,21 +2296,6 @@ def _str_count(self, pat: str, flags: int = 0):
raise NotImplementedError(f"count not implemented with {flags=}")
return type(self)(pc.count_substring_regex(self._pa_array, pat))

def _str_contains(
self, pat, case: bool = True, flags: int = 0, na=None, regex: bool = True
):
if flags:
raise NotImplementedError(f"contains not implemented with {flags=}")

if regex:
pa_contains = pc.match_substring_regex
else:
pa_contains = pc.match_substring
result = pa_contains(self._pa_array, pat, ignore_case=not case)
if not isna(na):
result = result.fill_null(na)
return type(self)(result)

def _result_converter(self, result):
return type(self)(result)

Expand Down
14 changes: 4 additions & 10 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,10 +214,8 @@ def insert(self, loc: int, item) -> ArrowStringArray:
raise TypeError("Scalar must be NA or str")
return super().insert(loc, item)

def _convert_bool_result(self, values, na=None):
def _convert_bool_result(self, values):
if self.dtype.na_value is np.nan:
if not isna(na):
values = values.fill_null(bool(na))
return ArrowExtensionArray(values).to_numpy(na_value=np.nan)
return BooleanDtype().__from_arrow__(values)

Expand Down Expand Up @@ -305,11 +303,6 @@ def _str_contains(
fallback_performancewarning()
return super()._str_contains(pat, case, flags, na, regex)

if regex:
result = pc.match_substring_regex(self._pa_array, pat, ignore_case=not case)
else:
result = pc.match_substring(self._pa_array, pat, ignore_case=not case)
result = self._convert_bool_result(result, na=na)
if not isna(na):
if not isinstance(na, bool):
# GH#59561
Expand All @@ -319,8 +312,9 @@ def _str_contains(
FutureWarning,
stacklevel=find_stack_level(),
)
result[isna(result)] = bool(na)
return result
na = bool(na)

return ArrowStringArrayMixin._str_contains(self, pat, case, flags, na, regex)

def _str_replace(
self,
Expand Down

0 comments on commit 40d81db

Please sign in to comment.