Skip to content

Commit

Permalink
String dtype: avoid surfacing pyarrow exception in binary operations (#…
Browse files (browse the repository at this point in the history)
  • Loading branch information
jorisvandenbossche committed Oct 10, 2024
1 parent 60175cc commit daa46c1
Show file tree
Hide file tree
Showing 16 changed files with 129 additions and 273 deletions.
40 changes: 33 additions & 7 deletions pandas/core/arrays/arrow/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,12 @@ def __invert__(self) -> Self:
return type(self)(pc.invert(self._pa_array))

def __neg__(self) -> Self:
return type(self)(pc.negate_checked(self._pa_array))
try:
return type(self)(pc.negate_checked(self._pa_array))
except pa.ArrowNotImplementedError as err:
raise TypeError(
f"unary '-' not supported for dtype '{self.dtype}'"
) from err

def __pos__(self) -> Self:
return type(self)(self._pa_array)
Expand Down Expand Up @@ -731,8 +736,19 @@ def _cmp_method(self, other, op):
)
return ArrowExtensionArray(result)

def _evaluate_op_method(self, other, op, arrow_funcs):
def _op_method_error_message(self, other, op) -> str:
if hasattr(other, "dtype"):
other_type = f"dtype '{other.dtype}'"
else:
other_type = f"object of type {type(other)}"
return (
f"operation '{op.__name__}' not supported for "
f"dtype '{self.dtype}' with {other_type}"
)

def _evaluate_op_method(self, other, op, arrow_funcs) -> Self:
pa_type = self._pa_array.type
other_original = other
other = self._box_pa(other)

if (
Expand All @@ -742,10 +758,15 @@ def _evaluate_op_method(self, other, op, arrow_funcs):
):
if op in [operator.add, roperator.radd]:
sep = pa.scalar("", type=pa_type)
if op is operator.add:
result = pc.binary_join_element_wise(self._pa_array, other, sep)
elif op is roperator.radd:
result = pc.binary_join_element_wise(other, self._pa_array, sep)
try:
if op is operator.add:
result = pc.binary_join_element_wise(self._pa_array, other, sep)
elif op is roperator.radd:
result = pc.binary_join_element_wise(other, self._pa_array, sep)
except pa.ArrowNotImplementedError as err:
raise TypeError(
self._op_method_error_message(other_original, op)
) from err
return type(self)(result)
elif op in [operator.mul, roperator.rmul]:
binary = self._pa_array
Expand Down Expand Up @@ -777,9 +798,14 @@ def _evaluate_op_method(self, other, op, arrow_funcs):

pc_func = arrow_funcs[op.__name__]
if pc_func is NotImplemented:
if pa.types.is_string(pa_type) or pa.types.is_large_string(pa_type):
raise TypeError(self._op_method_error_message(other_original, op))
raise NotImplementedError(f"{op.__name__} not implemented.")

result = pc_func(self._pa_array, other)
try:
result = pc_func(self._pa_array, other)
except pa.ArrowNotImplementedError as err:
raise TypeError(self._op_method_error_message(other_original, op)) from err
return type(self)(result)

def _logical_method(self, other, op):
Expand Down
5 changes: 4 additions & 1 deletion pandas/core/arrays/string_.py
Original file line number Diff line number Diff line change
Expand Up @@ -823,8 +823,11 @@ def _cmp_method(self, other, op):
f"Lengths of operands do not match: {len(self)} != {len(other)}"
)

other = np.asarray(other)
# for array-likes, first filter out NAs before converting to numpy
if not is_array_like(other):
other = np.asarray(other)
other = other[valid]
other = np.asarray(other)

if op.__name__ in ops.ARITHMETIC_BINOPS:
result = np.empty_like(self._ndarray, dtype="object")
Expand Down
25 changes: 6 additions & 19 deletions pandas/tests/arithmetic/test_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,6 @@
import numpy as np
import pytest

from pandas._config import using_string_dtype

from pandas.compat import HAS_PYARROW
import pandas.util._test_decorators as td

import pandas as pd
Expand Down Expand Up @@ -318,27 +315,17 @@ def test_add(self):
expected = pd.Index(["1a", "1b", "1c"])
tm.assert_index_equal("1" + index, expected)

@pytest.mark.xfail(
using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
)
def test_sub_fail(self, using_infer_string):
def test_sub_fail(self):
index = pd.Index([str(i) for i in range(10)])

if using_infer_string:
import pyarrow as pa

err = pa.lib.ArrowNotImplementedError
msg = "has no kernel"
else:
err = TypeError
msg = "unsupported operand type|Cannot broadcast"
with pytest.raises(err, match=msg):
msg = "unsupported operand type|Cannot broadcast|sub' not supported"
with pytest.raises(TypeError, match=msg):
index - "a"
with pytest.raises(err, match=msg):
with pytest.raises(TypeError, match=msg):
index - index
with pytest.raises(err, match=msg):
with pytest.raises(TypeError, match=msg):
index - index.tolist()
with pytest.raises(err, match=msg):
with pytest.raises(TypeError, match=msg):
index.tolist() - index

def test_sub_object(self):
Expand Down
26 changes: 7 additions & 19 deletions pandas/tests/arrays/boolean/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,6 @@
import numpy as np
import pytest

from pandas._config import using_string_dtype

from pandas.compat import HAS_PYARROW

import pandas as pd
import pandas._testing as tm

Expand Down Expand Up @@ -94,19 +90,8 @@ def test_op_int8(left_array, right_array, opname):
# -----------------------------------------------------------------------------


@pytest.mark.xfail(
using_string_dtype() and not HAS_PYARROW, reason="TODO(infer_string)"
)
def test_error_invalid_values(data, all_arithmetic_operators, using_infer_string):
def test_error_invalid_values(data, all_arithmetic_operators):
# invalid ops

if using_infer_string:
import pyarrow as pa

err = (TypeError, pa.lib.ArrowNotImplementedError, NotImplementedError)
else:
err = TypeError

op = all_arithmetic_operators
s = pd.Series(data)
ops = getattr(s, op)
Expand All @@ -116,7 +101,8 @@ def test_error_invalid_values(data, all_arithmetic_operators, using_infer_string
"did not contain a loop with signature matching types|"
"BooleanArray cannot perform the operation|"
"not supported for the input types, and the inputs could not be safely coerced "
"to any supported types according to the casting rule ''safe''"
"to any supported types according to the casting rule ''safe''|"
"not supported for dtype"
)
with pytest.raises(TypeError, match=msg):
ops("foo")
Expand All @@ -125,9 +111,10 @@ def test_error_invalid_values(data, all_arithmetic_operators, using_infer_string
r"unsupported operand type\(s\) for",
"Concatenation operation is not implemented for NumPy arrays",
"has no kernel",
"not supported for dtype",
]
)
with pytest.raises(err, match=msg):
with pytest.raises(TypeError, match=msg):
ops(pd.Timestamp("20180101"))

# invalid array-likes
Expand All @@ -140,7 +127,8 @@ def test_error_invalid_values(data, all_arithmetic_operators, using_infer_string
"not all arguments converted during string formatting",
"has no kernel",
"not implemented",
"not supported for dtype",
]
)
with pytest.raises(err, match=msg):
with pytest.raises(TypeError, match=msg):
ops(pd.Series("foo", index=s.index))
23 changes: 8 additions & 15 deletions pandas/tests/arrays/floating/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import numpy as np
import pytest

from pandas._config import using_string_dtype

import pandas as pd
import pandas._testing as tm
from pandas.core.arrays import FloatingArray
Expand Down Expand Up @@ -124,19 +122,11 @@ def test_arith_zero_dim_ndarray(other):
# -----------------------------------------------------------------------------


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
def test_error_invalid_values(data, all_arithmetic_operators, using_infer_string):
def test_error_invalid_values(data, all_arithmetic_operators):
op = all_arithmetic_operators
s = pd.Series(data)
ops = getattr(s, op)

if using_infer_string:
import pyarrow as pa

errs = (TypeError, pa.lib.ArrowNotImplementedError, NotImplementedError)
else:
errs = TypeError

# invalid scalars
msg = "|".join(
[
Expand All @@ -152,15 +142,17 @@ def test_error_invalid_values(data, all_arithmetic_operators, using_infer_string
"Concatenation operation is not implemented for NumPy arrays",
"has no kernel",
"not implemented",
"not supported for dtype",
"Can only string multiply by an integer",
]
)
with pytest.raises(errs, match=msg):
with pytest.raises(TypeError, match=msg):
ops("foo")
with pytest.raises(errs, match=msg):
with pytest.raises(TypeError, match=msg):
ops(pd.Timestamp("20180101"))

# invalid array-likes
with pytest.raises(errs, match=msg):
with pytest.raises(TypeError, match=msg):
ops(pd.Series("foo", index=s.index))

msg = "|".join(
Expand All @@ -181,9 +173,10 @@ def test_error_invalid_values(data, all_arithmetic_operators, using_infer_string
"cannot subtract DatetimeArray from ndarray",
"has no kernel",
"not implemented",
"not supported for dtype",
]
)
with pytest.raises(errs, match=msg):
with pytest.raises(TypeError, match=msg):
ops(pd.Series(pd.date_range("20180101", periods=len(s))))


Expand Down
34 changes: 11 additions & 23 deletions pandas/tests/arrays/integer/test_arithmetic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import numpy as np
import pytest

from pandas._config import using_string_dtype

import pandas as pd
import pandas._testing as tm
from pandas.core import ops
Expand Down Expand Up @@ -174,19 +172,11 @@ def test_numpy_zero_dim_ndarray(other):
# -----------------------------------------------------------------------------


@pytest.mark.xfail(using_string_dtype(), reason="TODO(infer_string)", strict=False)
def test_error_invalid_values(data, all_arithmetic_operators, using_infer_string):
def test_error_invalid_values(data, all_arithmetic_operators):
op = all_arithmetic_operators
s = pd.Series(data)
ops = getattr(s, op)

if using_infer_string:
import pyarrow as pa

errs = (TypeError, pa.lib.ArrowNotImplementedError, NotImplementedError)
else:
errs = TypeError

# invalid scalars
msg = "|".join(
[
Expand All @@ -201,24 +191,21 @@ def test_error_invalid_values(data, all_arithmetic_operators, using_infer_string
"has no kernel",
"not implemented",
"The 'out' kwarg is necessary. Use numpy.strings.multiply without it.",
"not supported for dtype",
]
)
with pytest.raises(errs, match=msg):
with pytest.raises(TypeError, match=msg):
ops("foo")
with pytest.raises(errs, match=msg):
with pytest.raises(TypeError, match=msg):
ops(pd.Timestamp("20180101"))

# invalid array-likes
str_ser = pd.Series("foo", index=s.index)
# with pytest.raises(TypeError, match=msg):
if (
all_arithmetic_operators
in [
"__mul__",
"__rmul__",
]
and not using_infer_string
): # (data[~data.isna()] >= 0).all():
if all_arithmetic_operators in [
"__mul__",
"__rmul__",
]: # (data[~data.isna()] >= 0).all():
res = ops(str_ser)
expected = pd.Series(["foo" * x for x in data], index=s.index)
expected = expected.fillna(np.nan)
Expand All @@ -227,7 +214,7 @@ def test_error_invalid_values(data, all_arithmetic_operators, using_infer_string
# more-correct than np.nan here.
tm.assert_series_equal(res, expected)
else:
with pytest.raises(errs, match=msg):
with pytest.raises(TypeError, match=msg):
ops(str_ser)

msg = "|".join(
Expand All @@ -242,9 +229,10 @@ def test_error_invalid_values(data, all_arithmetic_operators, using_infer_string
"cannot subtract DatetimeArray from ndarray",
"has no kernel",
"not implemented",
"not supported for dtype",
]
)
with pytest.raises(errs, match=msg):
with pytest.raises(TypeError, match=msg):
ops(pd.Series(pd.date_range("20180101", periods=len(s))))


Expand Down
10 changes: 1 addition & 9 deletions pandas/tests/extension/base/ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class BaseOpsUtil:

def _get_expected_exception(
self, op_name: str, obj, other
) -> type[Exception] | None:
) -> type[Exception] | tuple[type[Exception], ...] | None:
# Find the Exception, if any we expect to raise calling
# obj.__op_name__(other)

Expand All @@ -39,14 +39,6 @@ def _get_expected_exception(
else:
result = self.frame_scalar_exc

if using_string_dtype() and result is not None:
import pyarrow as pa

result = ( # type: ignore[assignment]
result,
pa.lib.ArrowNotImplementedError,
NotImplementedError,
)
return result

def _cast_pointwise_result(self, op_name: str, obj, other, pointwise_result):
Expand Down
2 changes: 1 addition & 1 deletion pandas/tests/extension/decimal/test_decimal.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def data_for_grouping():
class TestDecimalArray(base.ExtensionTests):
def _get_expected_exception(
self, op_name: str, obj, other
) -> type[Exception] | None:
) -> type[Exception] | tuple[type[Exception], ...] | None:
return None

def _supports_reduction(self, ser: pd.Series, op_name: str) -> bool:
Expand Down
Loading

0 comments on commit daa46c1

Please sign in to comment.