Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

REF (string): de-duplicate ArrowStringArray methods #59555

Merged
merged 9 commits into from
Sep 11, 2024
83 changes: 83 additions & 0 deletions pandas/core/arrays/_arrow_string_mixins.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from functools import partial
import re
from typing import (
TYPE_CHECKING,
Any,
Expand Down Expand Up @@ -48,6 +49,37 @@ def _convert_int_result(self, result):
def _apply_elementwise(self, func: Callable) -> list[list[Any]]:
    """
    Apply ``func`` to the array element by element.

    This version provides no implementation and always raises
    ``NotImplementedError``.
    """
    raise NotImplementedError

def _str_len(self):
    """
    Compute ``pc.utf8_length`` of the backing array.

    The raw result is passed through ``self._convert_int_result`` before
    being returned.
    """
    return self._convert_int_result(pc.utf8_length(self._pa_array))

def _str_lower(self) -> Self:
    """Return a new array of the same type with ``pc.utf8_lower`` applied."""
    lowered = pc.utf8_lower(self._pa_array)
    return type(self)(lowered)

def _str_upper(self) -> Self:
    """Return a new array of the same type with ``pc.utf8_upper`` applied."""
    uppered = pc.utf8_upper(self._pa_array)
    return type(self)(uppered)

def _str_strip(self, to_strip=None) -> Self:
    """
    Trim every string in the array.

    Parameters
    ----------
    to_strip : optional
        Forwarded as ``characters`` to ``pc.utf8_trim``. When ``None``,
        ``pc.utf8_trim_whitespace`` is used instead.

    Returns
    -------
    A new array of the same type as ``self``.
    """
    if to_strip is not None:
        trimmed = pc.utf8_trim(self._pa_array, characters=to_strip)
    else:
        trimmed = pc.utf8_trim_whitespace(self._pa_array)
    return type(self)(trimmed)

def _str_lstrip(self, to_strip=None) -> Self:
    """
    Left-trim every string in the array.

    Parameters
    ----------
    to_strip : optional
        Forwarded as ``characters`` to ``pc.utf8_ltrim``. When ``None``,
        ``pc.utf8_ltrim_whitespace`` is used instead.

    Returns
    -------
    A new array of the same type as ``self``.
    """
    if to_strip is not None:
        trimmed = pc.utf8_ltrim(self._pa_array, characters=to_strip)
    else:
        trimmed = pc.utf8_ltrim_whitespace(self._pa_array)
    return type(self)(trimmed)

def _str_rstrip(self, to_strip=None) -> Self:
    """
    Right-trim every string in the array.

    Parameters
    ----------
    to_strip : optional
        Forwarded as ``characters`` to ``pc.utf8_rtrim``. When ``None``,
        ``pc.utf8_rtrim_whitespace`` is used instead.

    Returns
    -------
    A new array of the same type as ``self``.
    """
    if to_strip is not None:
        trimmed = pc.utf8_rtrim(self._pa_array, characters=to_strip)
    else:
        trimmed = pc.utf8_rtrim_whitespace(self._pa_array)
    return type(self)(trimmed)

def _str_pad(
self,
width: int,
Expand Down Expand Up @@ -128,6 +160,33 @@ def _str_slice_replace(
stop = np.iinfo(np.int64).max
return type(self)(pc.utf8_replace_slice(self._pa_array, start, stop, repl))

def _str_replace(
    self,
    pat: str | re.Pattern,
    repl: str | Callable,
    n: int = -1,
    case: bool = True,
    flags: int = 0,
    regex: bool = True,
) -> Self:
    """
    Replace occurrences of ``pat`` with ``repl`` using pyarrow compute.

    Parameters
    ----------
    pat : str
        Pattern to replace; compiled ``re.Pattern`` objects are rejected.
    repl : str
        Replacement text; callables are rejected.
    n : int, default -1
        Maximum number of replacements. Negative values are passed to
        pyarrow as ``None``.
    case : bool, default True
        Must be True.
    flags : int, default 0
        Must be 0.
    regex : bool, default True
        Selects ``pc.replace_substring_regex`` when True, otherwise
        ``pc.replace_substring``.

    Raises
    ------
    NotImplementedError
        If ``pat`` is a ``re.Pattern``, ``repl`` is callable, ``case`` is
        False or ``flags`` is non-zero.
    """
    unsupported = isinstance(pat, re.Pattern) or callable(repl) or not case or flags
    if unsupported:
        raise NotImplementedError(
            "replace is not supported with a re.Pattern, callable repl, "
            "case=False, or flags!=0"
        )

    # https://github.com/apache/arrow/issues/39149
    # GH 56404, unexpected behavior with negative max_replacements with pyarrow.
    max_replacements = None if n < 0 else n
    if regex:
        replacer = pc.replace_substring_regex
    else:
        replacer = pc.replace_substring
    replaced = replacer(
        self._pa_array,
        pattern=pat,
        replacement=repl,
        max_replacements=max_replacements,
    )
    return type(self)(replaced)

def _str_capitalize(self) -> Self:
    """Return a new array of the same type with ``pc.utf8_capitalize`` applied."""
    capitalized = pc.utf8_capitalize(self._pa_array)
    return type(self)(capitalized)

Expand All @@ -137,6 +196,16 @@ def _str_title(self) -> Self:
def _str_swapcase(self) -> Self:
    """Return a new array of the same type with ``pc.utf8_swapcase`` applied."""
    swapped = pc.utf8_swapcase(self._pa_array)
    return type(self)(swapped)

def _str_removeprefix(self, prefix: str):
    """
    Remove ``prefix`` from the start of each string that begins with it.

    When ``pa_version_under13p0`` is true the work is done per element with
    ``str.removeprefix``; otherwise pyarrow compute kernels are used.
    """
    if pa_version_under13p0:
        # Element-wise fallback through Python's own str.removeprefix.
        stripped = self._apply_elementwise(lambda val: val.removeprefix(prefix))
        return type(self)(pa.chunked_array(stripped))

    has_prefix = pc.starts_with(self._pa_array, pattern=prefix)
    without_prefix = pc.utf8_slice_codeunits(self._pa_array, len(prefix))
    # Keep the original value wherever the prefix is not present.
    return type(self)(pc.if_else(has_prefix, without_prefix, self._pa_array))

def _str_removesuffix(self, suffix: str):
ends_with = pc.ends_with(self._pa_array, pattern=suffix)
removed = pc.utf8_slice_codeunits(self._pa_array, 0, stop=-len(suffix))
Expand Down Expand Up @@ -228,6 +297,20 @@ def _str_contains(
result = result.fill_null(na)
return self._convert_bool_result(result)

def _str_match(
    self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None
):
    """
    Test whether each string matches ``pat`` starting at its beginning.

    Implemented by anchoring ``pat`` at the start and delegating to
    ``self._str_contains`` with ``regex=True``.

    Parameters
    ----------
    pat : str
        Regular expression.
    case, flags, na
        Forwarded unchanged to ``self._str_contains``.
    """
    # Group the whole pattern before anchoring: a bare "^" prefix binds only
    # to the first alternative, so "a|b" would become "^a|b" and match "xb".
    # Grouping unconditionally also covers patterns such as "^a|b".
    pat = f"^(?:{pat})"
    return self._str_contains(pat, case, flags, na, regex=True)

def _str_fullmatch(
    self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None
):
    """
    Test whether each string matches ``pat`` in its entirety.

    Implemented by anchoring ``pat`` at the end and delegating to
    ``self._str_match``.

    Parameters
    ----------
    pat : str
        Regular expression.
    case, flags, na
        Forwarded unchanged to ``self._str_match``.
    """
    # Group the whole pattern before appending "$": a bare "$" suffix binds
    # only to the last alternative, so "a|b" would become "a|b$" and let "ax"
    # through. A pattern that already ends in "$" or in an escaped "\$" stays
    # correct, since the extra end anchor is redundant or required respectively.
    pat = f"(?:{pat})$"
    return self._str_match(pat, case, flags, na)

def _str_find(self, sub: str, start: int = 0, end: int | None = None):
if (
pa_version_under13p0
Expand Down
86 changes: 1 addition & 85 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1999,7 +1999,7 @@ def _rank(
"""
See Series.rank.__doc__.
"""
return type(self)(
return self._convert_int_result(
self._rank_calc(
axis=axis,
method=method,
Expand Down Expand Up @@ -2323,57 +2323,13 @@ def _str_count(self, pat: str, flags: int = 0) -> Self:
raise NotImplementedError(f"count not implemented with {flags=}")
return type(self)(pc.count_substring_regex(self._pa_array, pat))

def _result_converter(self, result):
    """Wrap ``result`` in a new instance of this array's own type."""
    array_type = type(self)
    return array_type(result)

def _str_replace(
    self,
    pat: str | re.Pattern,
    repl: str | Callable,
    n: int = -1,
    case: bool = True,
    flags: int = 0,
    regex: bool = True,
) -> Self:
    """
    Replace occurrences of ``pat`` with ``repl`` using pyarrow compute.

    Parameters
    ----------
    pat : str
        Pattern to replace; compiled ``re.Pattern`` objects are rejected.
    repl : str
        Replacement text; callables are rejected.
    n : int, default -1
        Maximum number of replacements. Negative values are passed to
        pyarrow as ``None``.
    case : bool, default True
        Must be True.
    flags : int, default 0
        Must be 0.
    regex : bool, default True
        Selects ``pc.replace_substring_regex`` when True, otherwise
        ``pc.replace_substring``.

    Raises
    ------
    NotImplementedError
        If ``pat`` is a ``re.Pattern``, ``repl`` is callable, ``case`` is
        False or ``flags`` is non-zero.
    """
    unsupported = isinstance(pat, re.Pattern) or callable(repl) or not case or flags
    if unsupported:
        raise NotImplementedError(
            "replace is not supported with a re.Pattern, callable repl, "
            "case=False, or flags!=0"
        )

    # https://github.com/apache/arrow/issues/39149
    # GH 56404, unexpected behavior with negative max_replacements with pyarrow.
    max_replacements = None if n < 0 else n
    if regex:
        replacer = pc.replace_substring_regex
    else:
        replacer = pc.replace_substring
    replaced = replacer(
        self._pa_array,
        pattern=pat,
        replacement=repl,
        max_replacements=max_replacements,
    )
    return type(self)(replaced)

def _str_repeat(self, repeats: int | Sequence[int]) -> Self:
    """
    Repeat each element ``repeats`` times via ``pc.binary_repeat``.

    Raises
    ------
    NotImplementedError
        If ``repeats`` is not an ``int``.
    """
    if isinstance(repeats, int):
        return type(self)(pc.binary_repeat(self._pa_array, repeats))
    raise NotImplementedError(
        f"repeat is not implemented when repeats is {type(repeats).__name__}"
    )

def _str_match(
    self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None
) -> Self:
    """
    Test whether each string matches ``pat`` starting at its beginning.

    Implemented by anchoring ``pat`` at the start and delegating to
    ``self._str_contains`` with ``regex=True``.

    Parameters
    ----------
    pat : str
        Regular expression.
    case, flags, na
        Forwarded unchanged to ``self._str_contains``.
    """
    # Group the whole pattern before anchoring: a bare "^" prefix binds only
    # to the first alternative, so "a|b" would become "^a|b" and match "xb".
    # Grouping unconditionally also covers patterns such as "^a|b".
    pat = f"^(?:{pat})"
    return self._str_contains(pat, case, flags, na, regex=True)

def _str_fullmatch(
    self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None
) -> Self:
    """
    Test whether each string matches ``pat`` in its entirety.

    Implemented by anchoring ``pat`` at the end and delegating to
    ``self._str_match``.

    Parameters
    ----------
    pat : str
        Regular expression.
    case, flags, na
        Forwarded unchanged to ``self._str_match``.
    """
    # Group the whole pattern before appending "$": a bare "$" suffix binds
    # only to the last alternative, so "a|b" would become "a|b$" and let "ax"
    # through. A pattern that already ends in "$" or in an escaped "\$" stays
    # correct, since the extra end anchor is redundant or required respectively.
    pat = f"(?:{pat})$"
    return self._str_match(pat, case, flags, na)

def _str_join(self, sep: str) -> Self:
if pa.types.is_string(self._pa_array.type) or pa.types.is_large_string(
self._pa_array.type
Expand All @@ -2394,46 +2350,6 @@ def _str_rpartition(self, sep: str, expand: bool) -> Self:
result = self._apply_elementwise(predicate)
return type(self)(pa.chunked_array(result))

def _str_len(self) -> Self:
    """Return a new array of the same type holding ``pc.utf8_length`` results."""
    lengths = pc.utf8_length(self._pa_array)
    return type(self)(lengths)

def _str_lower(self) -> Self:
    """Return a new array of the same type with ``pc.utf8_lower`` applied."""
    lowered = pc.utf8_lower(self._pa_array)
    return type(self)(lowered)

def _str_upper(self) -> Self:
    """Return a new array of the same type with ``pc.utf8_upper`` applied."""
    uppered = pc.utf8_upper(self._pa_array)
    return type(self)(uppered)

def _str_strip(self, to_strip=None) -> Self:
    """
    Trim every string in the array.

    Parameters
    ----------
    to_strip : optional
        Forwarded as ``characters`` to ``pc.utf8_trim``. When ``None``,
        ``pc.utf8_trim_whitespace`` is used instead.

    Returns
    -------
    A new array of the same type as ``self``.
    """
    if to_strip is not None:
        trimmed = pc.utf8_trim(self._pa_array, characters=to_strip)
    else:
        trimmed = pc.utf8_trim_whitespace(self._pa_array)
    return type(self)(trimmed)

def _str_lstrip(self, to_strip=None) -> Self:
    """
    Left-trim every string in the array.

    Parameters
    ----------
    to_strip : optional
        Forwarded as ``characters`` to ``pc.utf8_ltrim``. When ``None``,
        ``pc.utf8_ltrim_whitespace`` is used instead.

    Returns
    -------
    A new array of the same type as ``self``.
    """
    if to_strip is not None:
        trimmed = pc.utf8_ltrim(self._pa_array, characters=to_strip)
    else:
        trimmed = pc.utf8_ltrim_whitespace(self._pa_array)
    return type(self)(trimmed)

def _str_rstrip(self, to_strip=None) -> Self:
    """
    Right-trim every string in the array.

    Parameters
    ----------
    to_strip : optional
        Forwarded as ``characters`` to ``pc.utf8_rtrim``. When ``None``,
        ``pc.utf8_rtrim_whitespace`` is used instead.

    Returns
    -------
    A new array of the same type as ``self``.
    """
    if to_strip is not None:
        trimmed = pc.utf8_rtrim(self._pa_array, characters=to_strip)
    else:
        trimmed = pc.utf8_rtrim_whitespace(self._pa_array)
    return type(self)(trimmed)

def _str_removeprefix(self, prefix: str):
    """
    Remove ``prefix`` from the start of each string that begins with it.

    When ``pa_version_under13p0`` is true the work is done per element with
    ``str.removeprefix``; otherwise pyarrow compute kernels are used.
    """
    if pa_version_under13p0:
        # Element-wise fallback through Python's own str.removeprefix.
        stripped = self._apply_elementwise(lambda val: val.removeprefix(prefix))
        return type(self)(pa.chunked_array(stripped))

    has_prefix = pc.starts_with(self._pa_array, pattern=prefix)
    without_prefix = pc.utf8_slice_codeunits(self._pa_array, len(prefix))
    # Keep the original value wherever the prefix is not present.
    return type(self)(pc.if_else(has_prefix, without_prefix, self._pa_array))

def _str_casefold(self) -> Self:
predicate = lambda val: val.casefold()
result = self._apply_elementwise(predicate)
Expand Down
106 changes: 19 additions & 87 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,8 @@

from pandas._typing import (
ArrayLike,
AxisInt,
Dtype,
NpDtype,
Scalar,
Self,
npt,
)
Expand Down Expand Up @@ -294,6 +292,20 @@ def astype(self, dtype, copy: bool = True):
# Bind the ArrowStringArrayMixin implementations of these string methods
# explicitly by name, one alias per method.
_str_startswith = ArrowStringArrayMixin._str_startswith
_str_endswith = ArrowStringArrayMixin._str_endswith
_str_pad = ArrowStringArrayMixin._str_pad
_str_match = ArrowStringArrayMixin._str_match
_str_fullmatch = ArrowStringArrayMixin._str_fullmatch
_str_lower = ArrowStringArrayMixin._str_lower
_str_upper = ArrowStringArrayMixin._str_upper
_str_strip = ArrowStringArrayMixin._str_strip
_str_lstrip = ArrowStringArrayMixin._str_lstrip
_str_rstrip = ArrowStringArrayMixin._str_rstrip
_str_removesuffix = ArrowStringArrayMixin._str_removesuffix
_str_get = ArrowStringArrayMixin._str_get
_str_capitalize = ArrowStringArrayMixin._str_capitalize
_str_title = ArrowStringArrayMixin._str_title
_str_swapcase = ArrowStringArrayMixin._str_swapcase
_str_slice_replace = ArrowStringArrayMixin._str_slice_replace
_str_len = ArrowStringArrayMixin._str_len
_str_slice = ArrowStringArrayMixin._str_slice

def _str_contains(
Expand Down Expand Up @@ -331,73 +343,21 @@ def _str_replace(
fallback_performancewarning()
return super()._str_replace(pat, repl, n, case, flags, regex)

return ArrowExtensionArray._str_replace(self, pat, repl, n, case, flags, regex)
return ArrowStringArrayMixin._str_replace(
self, pat, repl, n, case, flags, regex
)

def _str_repeat(self, repeats: int | Sequence[int]):
if not isinstance(repeats, int):
return super()._str_repeat(repeats)
else:
return type(self)(pc.binary_repeat(self._pa_array, repeats))

def _str_match(
    self, pat: str, case: bool = True, flags: int = 0, na: Scalar | None = None
):
    """
    Test whether each string matches ``pat`` starting at its beginning.

    Implemented by anchoring ``pat`` at the start and delegating to
    ``self._str_contains`` with ``regex=True``.

    Parameters
    ----------
    pat : str
        Regular expression.
    case, flags, na
        Forwarded unchanged to ``self._str_contains``.
    """
    # Group the whole pattern before anchoring: a bare "^" prefix binds only
    # to the first alternative, so "a|b" would become "^a|b" and match "xb".
    # Grouping unconditionally also covers patterns such as "^a|b".
    pat = f"^(?:{pat})"
    return self._str_contains(pat, case, flags, na, regex=True)

def _str_fullmatch(
    self, pat, case: bool = True, flags: int = 0, na: Scalar | None = None
):
    """
    Test whether each string matches ``pat`` in its entirety.

    Implemented by anchoring ``pat`` at the end and delegating to
    ``self._str_match``.

    Parameters
    ----------
    pat : str
        Regular expression.
    case, flags, na
        Forwarded unchanged to ``self._str_match``.
    """
    # Group the whole pattern before appending "$": a bare "$" suffix binds
    # only to the last alternative, so "a|b" would become "a|b$" and let "ax"
    # through. A pattern that already ends in "$" or in an escaped "\$" stays
    # correct, since the extra end anchor is redundant or required respectively.
    pat = f"(?:{pat})$"
    return self._str_match(pat, case, flags, na)

def _str_len(self):
    """
    Compute ``pc.utf8_length`` of the backing array.

    The raw result is passed through ``self._convert_int_result`` before
    being returned.
    """
    return self._convert_int_result(pc.utf8_length(self._pa_array))

def _str_lower(self) -> Self:
    """Return a new array of the same type with ``pc.utf8_lower`` applied."""
    lowered = pc.utf8_lower(self._pa_array)
    return type(self)(lowered)

def _str_upper(self) -> Self:
    """Return a new array of the same type with ``pc.utf8_upper`` applied."""
    uppered = pc.utf8_upper(self._pa_array)
    return type(self)(uppered)

def _str_strip(self, to_strip=None) -> Self:
    """
    Trim every string in the array.

    Parameters
    ----------
    to_strip : optional
        Forwarded as ``characters`` to ``pc.utf8_trim``. When ``None``,
        ``pc.utf8_trim_whitespace`` is used instead.

    Returns
    -------
    A new array of the same type as ``self``.
    """
    if to_strip is not None:
        trimmed = pc.utf8_trim(self._pa_array, characters=to_strip)
    else:
        trimmed = pc.utf8_trim_whitespace(self._pa_array)
    return type(self)(trimmed)

def _str_lstrip(self, to_strip=None) -> Self:
    """
    Left-trim every string in the array.

    Parameters
    ----------
    to_strip : optional
        Forwarded as ``characters`` to ``pc.utf8_ltrim``. When ``None``,
        ``pc.utf8_ltrim_whitespace`` is used instead.

    Returns
    -------
    A new array of the same type as ``self``.
    """
    if to_strip is not None:
        trimmed = pc.utf8_ltrim(self._pa_array, characters=to_strip)
    else:
        trimmed = pc.utf8_ltrim_whitespace(self._pa_array)
    return type(self)(trimmed)

def _str_rstrip(self, to_strip=None) -> Self:
    """
    Right-trim every string in the array.

    Parameters
    ----------
    to_strip : optional
        Forwarded as ``characters`` to ``pc.utf8_rtrim``. When ``None``,
        ``pc.utf8_rtrim_whitespace`` is used instead.

    Returns
    -------
    A new array of the same type as ``self``.
    """
    if to_strip is not None:
        trimmed = pc.utf8_rtrim(self._pa_array, characters=to_strip)
    else:
        trimmed = pc.utf8_rtrim_whitespace(self._pa_array)
    return type(self)(trimmed)
return ArrowExtensionArray._str_repeat(self, repeats=repeats)

def _str_removeprefix(self, prefix: str):
if not pa_version_under13p0:
starts_with = pc.starts_with(self._pa_array, pattern=prefix)
removed = pc.utf8_slice_codeunits(self._pa_array, len(prefix))
result = pc.if_else(starts_with, removed, self._pa_array)
return type(self)(result)
return ArrowStringArrayMixin._str_removeprefix(self, prefix)
return super()._str_removeprefix(prefix)

def _str_removesuffix(self, suffix: str):
    """
    Remove ``suffix`` from the end of each string that ends with it.

    Parameters
    ----------
    suffix : str
        Text to remove. An empty suffix leaves every value unchanged.

    Returns
    -------
    A new array of the same type as ``self``.
    """
    if not suffix:
        # ``stop=-len(suffix)`` would be ``stop=0`` here, which slices every
        # string down to ""; removing an empty suffix must be a no-op.
        return type(self)(self._pa_array)
    ends_with = pc.ends_with(self._pa_array, pattern=suffix)
    removed = pc.utf8_slice_codeunits(self._pa_array, 0, stop=-len(suffix))
    # Keep the original value wherever the suffix is not present.
    result = pc.if_else(ends_with, removed, self._pa_array)
    return type(self)(result)

def _str_count(self, pat: str, flags: int = 0):
if flags:
return super()._str_count(pat, flags)
Expand Down Expand Up @@ -464,28 +424,6 @@ def _reduce(
else:
return result

def _rank(
    self,
    *,
    axis: AxisInt = 0,
    method: str = "average",
    na_option: str = "keep",
    ascending: bool = True,
    pct: bool = False,
):
    """
    See Series.rank.__doc__.

    All arguments are forwarded by keyword to ``self._rank_calc`` and the
    outcome is passed through ``self._convert_int_result``.
    """
    ranked = self._rank_calc(
        axis=axis,
        method=method,
        na_option=na_option,
        ascending=ascending,
        pct=pct,
    )
    return self._convert_int_result(ranked)

def value_counts(self, dropna: bool = True) -> Series:
result = super().value_counts(dropna=dropna)
if self.dtype.na_value is np.nan:
Expand All @@ -507,9 +445,3 @@ def _cmp_method(self, other, op):

class ArrowStringArrayNumpySemantics(ArrowStringArray):
_na_value = np.nan
_str_get = ArrowStringArrayMixin._str_get
_str_removesuffix = ArrowStringArrayMixin._str_removesuffix
_str_capitalize = ArrowStringArrayMixin._str_capitalize
_str_title = ArrowStringArrayMixin._str_title
_str_swapcase = ArrowStringArrayMixin._str_swapcase
_str_slice_replace = ArrowStringArrayMixin._str_slice_replace
Loading