Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement Arrow String Array that is compatible with NumPy semantics #54533

Merged
merged 24 commits into from
Aug 23, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion pandas/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -1321,6 +1321,7 @@ def nullable_string_dtype(request):
params=[
"python",
pytest.param("pyarrow", marks=td.skip_if_no("pyarrow")),
pytest.param("pyarrow_numpy", marks=td.skip_if_no("pyarrow")),
]
)
def string_storage(request):
phofl marked this conversation as resolved.
Show resolved Hide resolved
Expand All @@ -1329,6 +1330,7 @@ def string_storage(request):

* 'python'
* 'pyarrow'
* 'pyarrow_numpy'
"""
return request.param

Expand Down Expand Up @@ -1380,6 +1382,7 @@ def object_dtype(request):
"object",
"string[python]",
pytest.param("string[pyarrow]", marks=td.skip_if_no("pyarrow")),
pytest.param("string[pyarrow_numpy]", marks=td.skip_if_no("pyarrow")),
]
)
def any_string_dtype(request):
Expand Down Expand Up @@ -2000,4 +2003,4 @@ def warsaw(request) -> str:

@pytest.fixture()
def arrow_string_storage():
return ("pyarrow",)
return ("pyarrow", "pyarrow_numpy")
5 changes: 4 additions & 1 deletion pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -506,7 +506,10 @@ def __getitem__(self, item: PositionalIndexer):
if isinstance(item, np.ndarray):
if not len(item):
# Removable once we migrate StringDtype[pyarrow] to ArrowDtype[string]
if self._dtype.name == "string" and self._dtype.storage == "pyarrow":
if self._dtype.name == "string" and self._dtype.storage in (
"pyarrow",
"pyarrow_numpy",
):
pa_dtype = pa.string()
else:
pa_dtype = self._dtype.pyarrow_dtype
Expand Down
21 changes: 16 additions & 5 deletions pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ class StringDtype(StorageExtensionDtype):

Parameters
----------
storage : {"python", "pyarrow"}, optional
storage : {"python", "pyarrow", "pyarrow_numpy"}, optional
If not given, the value of ``pd.options.mode.string_storage``.

Attributes
Expand Down Expand Up @@ -108,11 +108,11 @@ def na_value(self) -> libmissing.NAType:
def __init__(self, storage=None) -> None:
if storage is None:
storage = get_option("mode.string_storage")
if storage not in {"python", "pyarrow"}:
if storage not in {"python", "pyarrow", "pyarrow_numpy"}:
raise ValueError(
f"Storage must be 'python' or 'pyarrow'. Got {storage} instead."
)
if storage == "pyarrow" and pa_version_under7p0:
if storage in ("pyarrow", "pyarrow_numpy") and pa_version_under7p0:
raise ImportError(
"pyarrow>=7.0.0 is required for PyArrow backed StringArray."
)
Expand Down Expand Up @@ -160,6 +160,8 @@ def construct_from_string(cls, string):
return cls(storage="python")
elif string == "string[pyarrow]":
return cls(storage="pyarrow")
elif string == "string[pyarrow_numpy]":
return cls(storage="pyarrow_numpy")
else:
raise TypeError(f"Cannot construct a '{cls.__name__}' from '{string}'")

Expand All @@ -176,12 +178,17 @@ def construct_array_type( # type: ignore[override]
-------
type
"""
from pandas.core.arrays.string_arrow import ArrowStringArray
from pandas.core.arrays.string_arrow import (
ArrowStringArray,
ArrowStringArrayNumpySemantics,
)

if self.storage == "python":
return StringArray
else:
elif self.storage == "pyarrow":
return ArrowStringArray
else:
return ArrowStringArrayNumpySemantics

def __from_arrow__(
self, array: pyarrow.Array | pyarrow.ChunkedArray
Expand All @@ -193,6 +200,10 @@ def __from_arrow__(
from pandas.core.arrays.string_arrow import ArrowStringArray

return ArrowStringArray(array)
elif self.storage == "pyarrow_numpy":
from pandas.core.arrays.string_arrow import ArrowStringArrayNumpySemantics

return ArrowStringArrayNumpySemantics(array)
else:
import pyarrow

Expand Down
149 changes: 135 additions & 14 deletions pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from functools import partial
import re
from typing import (
TYPE_CHECKING,
Expand Down Expand Up @@ -27,6 +28,7 @@
)
from pandas.core.dtypes.missing import isna

from pandas.core.arrays._arrow_string_mixins import ArrowStringArrayMixin
from pandas.core.arrays.arrow import ArrowExtensionArray
from pandas.core.arrays.boolean import BooleanDtype
from pandas.core.arrays.integer import Int64Dtype
Expand Down Expand Up @@ -113,10 +115,11 @@ class ArrowStringArray(ObjectStringArrayMixin, ArrowExtensionArray, BaseStringAr
# error: Incompatible types in assignment (expression has type "StringDtype",
# base class "ArrowExtensionArray" defined the type as "ArrowDtype")
_dtype: StringDtype # type: ignore[assignment]
_storage = "pyarrow"

def __init__(self, values) -> None:
super().__init__(values)
self._dtype = StringDtype(storage="pyarrow")
self._dtype = StringDtype(storage=self._storage)

if not pa.types.is_string(self._pa_array.type) and not (
pa.types.is_dictionary(self._pa_array.type)
Expand Down Expand Up @@ -144,7 +147,10 @@ def _from_sequence(cls, scalars, dtype: Dtype | None = None, copy: bool = False)

if dtype and not (isinstance(dtype, str) and dtype == "string"):
dtype = pandas_dtype(dtype)
assert isinstance(dtype, StringDtype) and dtype.storage == "pyarrow"
assert isinstance(dtype, StringDtype) and dtype.storage in (
"pyarrow",
"pyarrow_numpy",
)

if isinstance(scalars, BaseMaskedArray):
# avoid costly conversion to object dtype in ensure_string_array and
Expand Down Expand Up @@ -178,6 +184,10 @@ def insert(self, loc: int, item) -> ArrowStringArray:
raise TypeError("Scalar must be NA or str")
return super().insert(loc, item)

@classmethod
def _result_converter(cls, values, na=None):
return BooleanDtype().__from_arrow__(values)

def _maybe_convert_setitem_value(self, value):
"""Maybe convert value to be pyarrow compatible."""
if is_scalar(value):
Expand Down Expand Up @@ -313,7 +323,7 @@ def _str_contains(
result = pc.match_substring_regex(self._pa_array, pat, ignore_case=not case)
else:
result = pc.match_substring(self._pa_array, pat, ignore_case=not case)
result = BooleanDtype().__from_arrow__(result)
result = self._result_converter(result, na=na)
if not isna(na):
result[isna(result)] = bool(na)
return result
Expand All @@ -322,7 +332,7 @@ def _str_startswith(self, pat: str, na=None):
result = pc.starts_with(self._pa_array, pattern=pat)
if not isna(na):
result = result.fill_null(na)
result = BooleanDtype().__from_arrow__(result)
result = self._result_converter(result)
if not isna(na):
result[isna(result)] = bool(na)
return result
Expand All @@ -331,7 +341,7 @@ def _str_endswith(self, pat: str, na=None):
result = pc.ends_with(self._pa_array, pattern=pat)
if not isna(na):
result = result.fill_null(na)
result = BooleanDtype().__from_arrow__(result)
result = self._result_converter(result)
if not isna(na):
result[isna(result)] = bool(na)
return result
Expand Down Expand Up @@ -369,39 +379,39 @@ def _str_fullmatch(

def _str_isalnum(self):
result = pc.utf8_is_alnum(self._pa_array)
return BooleanDtype().__from_arrow__(result)
return self._result_converter(result)

def _str_isalpha(self):
result = pc.utf8_is_alpha(self._pa_array)
return BooleanDtype().__from_arrow__(result)
return self._result_converter(result)

def _str_isdecimal(self):
result = pc.utf8_is_decimal(self._pa_array)
return BooleanDtype().__from_arrow__(result)
return self._result_converter(result)

def _str_isdigit(self):
result = pc.utf8_is_digit(self._pa_array)
return BooleanDtype().__from_arrow__(result)
return self._result_converter(result)

def _str_islower(self):
result = pc.utf8_is_lower(self._pa_array)
return BooleanDtype().__from_arrow__(result)
return self._result_converter(result)

def _str_isnumeric(self):
result = pc.utf8_is_numeric(self._pa_array)
return BooleanDtype().__from_arrow__(result)
return self._result_converter(result)

def _str_isspace(self):
result = pc.utf8_is_space(self._pa_array)
return BooleanDtype().__from_arrow__(result)
return self._result_converter(result)

def _str_istitle(self):
result = pc.utf8_is_title(self._pa_array)
return BooleanDtype().__from_arrow__(result)
return self._result_converter(result)

def _str_isupper(self):
result = pc.utf8_is_upper(self._pa_array)
return BooleanDtype().__from_arrow__(result)
return self._result_converter(result)

def _str_len(self):
result = pc.utf8_length(self._pa_array)
Expand Down Expand Up @@ -433,3 +443,114 @@ def _str_rstrip(self, to_strip=None):
else:
result = pc.utf8_rtrim(self._pa_array, characters=to_strip)
return type(self)(result)


class ArrowStringArrayNumpySemantics(ArrowStringArray):
_storage = "pyarrow_numpy"

@classmethod
def _result_converter(cls, values, na=None):
if not isna(na):
values = values.fill_null(bool(na))
return ArrowExtensionArray(values).to_numpy(na_value=np.nan)

def __getattribute__(self, item):
# ArrowStringArray and we both inherit from ArrowExtensionArray, which
# creates inheritance problems (Diamond inheritance)
if item in ArrowStringArrayMixin.__dict__ and item != "_pa_array":
return partial(getattr(ArrowStringArrayMixin, item), self)
return super().__getattribute__(item)

def _str_map(
self, f, na_value=None, dtype: Dtype | None = None, convert: bool = True
):
if dtype is None:
dtype = self.dtype
if na_value is None:
na_value = self.dtype.na_value

mask = isna(self)
arr = np.asarray(self)

if is_integer_dtype(dtype) or is_bool_dtype(dtype):
if is_integer_dtype(dtype):
na_value = np.nan
else:
na_value = False
try:
result = lib.map_infer_mask(
arr,
f,
mask.view("uint8"),
convert=False,
na_value=na_value,
dtype=np.dtype(dtype), # type: ignore[arg-type]
)
return result

except ValueError:
result = lib.map_infer_mask(
arr,
f,
mask.view("uint8"),
convert=False,
na_value=na_value,
)
if convert and result.dtype == object:
result = lib.maybe_convert_objects(result)
return result

elif is_string_dtype(dtype) and not is_object_dtype(dtype):
# i.e. StringDtype
result = lib.map_infer_mask(
arr, f, mask.view("uint8"), convert=False, na_value=na_value
)
result = pa.array(result, mask=mask, type=pa.string(), from_pandas=True)
return type(self)(result)
else:
# This is when the result type is object. We reach this when
# -> We know the result type is truly object (e.g. .encode returns bytes
# or .findall returns a list).
# -> We don't know the result type. E.g. `.get` can return anything.
return lib.map_infer_mask(arr, f, mask.view("uint8"))

def _convert_int_dtype(self, result):
if result.dtype == np.int32:
result = result.astype(np.int64)
return result

def _str_count(self, pat: str, flags: int = 0):
if flags:
return super()._str_count(pat, flags)
result = pc.count_substring_regex(self._pa_array, pat).to_numpy()
return self._convert_int_dtype(result)

def _str_len(self):
result = pc.utf8_length(self._pa_array).to_numpy()
return self._convert_int_dtype(result)

def _str_find(self, sub: str, start: int = 0, end: int | None = None):
if start != 0 and end is not None:
slices = pc.utf8_slice_codeunits(self._pa_array, start, stop=end)
result = pc.find_substring(slices, sub)
not_found = pc.equal(result, -1)
offset_result = pc.add(result, end - start)
result = pc.if_else(not_found, result, offset_result)
elif start == 0 and end is None:
slices = self._pa_array
result = pc.find_substring(slices, sub)
else:
return super()._str_find(sub, start, end)
return self._convert_int_dtype(result.to_numpy())

def _cmp_method(self, other, op):
result = super()._cmp_method(other, op)
return result.to_numpy(np.bool_, na_value=False)

def value_counts(self, dropna: bool = True):
from pandas import Series

result = super().value_counts(dropna)
return Series(
result._values.to_numpy(), index=result.index, name=result.name, copy=False
)
2 changes: 1 addition & 1 deletion pandas/core/config_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,7 +500,7 @@ def use_inf_as_na_cb(key) -> None:
"string_storage",
"python",
string_storage_doc,
validator=is_one_of_factory(["python", "pyarrow"]),
validator=is_one_of_factory(["python", "pyarrow", "pyarrow_numpy"]),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we want the user to allow setting this for the new option as well? Because we also have the new pd.options.future.infer_strings option to enable the future string dtype?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd say yes, you might want to astype columns after operations for example

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably something to further discuss in a follow up issue, but I would expect that if you opt-in for the future string dtype with pd.options.future.infer_strings, that this would also automatically set the "pyarrow_numpy" string storage as default for operations that depend on that (like doing astype("string"))

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah agree, let's do this as a follow up

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Opened #54793 for this interaction between pd.options.future.infer_strings and the default string_storage

)


Expand Down
4 changes: 3 additions & 1 deletion pandas/core/strings/accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,9 @@ def _map_and_wrap(name: str | None, docstring: str | None):
@forbid_nonstring_types(["bytes"], name=name)
def wrapper(self):
result = getattr(self._data.array, f"_str_{name}")()
return self._wrap_result(result)
return self._wrap_result(
result, returns_string=name not in ("isnumeric", "isdecimal")
)

wrapper.__doc__ = docstring
return wrapper
Expand Down
Loading