Skip to content
Closed
52 changes: 51 additions & 1 deletion pandas/core/arrays/numpy_.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
from pandas.compat.numpy import function as nv

from pandas.core.dtypes.astype import astype_array
from pandas.core.dtypes.cast import construct_1d_object_array_from_listlike
from pandas.core.dtypes.cast import (
construct_1d_object_array_from_listlike,
np_can_hold_element,
)
from pandas.core.dtypes.common import pandas_dtype
from pandas.core.dtypes.dtypes import NumpyEADtype
from pandas.core.dtypes.missing import isna
Expand Down Expand Up @@ -507,6 +510,53 @@ def to_numpy(

return result

def _validate_setitem_value(self, value):
if type(value) == int:
try:
np_can_hold_element(self.dtype, value)
except Exception:
pass
return value
elif type(value) == float:
if (
self.dtype
in [
NumpyEADtype("float32"),
NumpyEADtype("float64"),
NumpyEADtype("object"),
]
or self.dtype is None
):
return value
elif type(value) not in [int, float] and (
self.dtype
not in [
NumpyEADtype("int64"),
NumpyEADtype("float64"),
NumpyEADtype("uint16"),
NumpyEADtype("object"),
]
or lib.is_list_like(value)
):
return value
if self.dtype is None:
return value
if not isinstance(self.dtype, NumpyEADtype):
return value
if (
NumpyEADtype(type(value)) == NumpyEADtype(self.dtype)
or NumpyEADtype(type(value)) == self.dtype
):
return value
if self.dtype == NumpyEADtype("object"):
return value

raise TypeError(
"value cannot be inserted without changing the dtype. value:"
f"{value}, type(value): {type(value)}, NumpyEADtype(type(value)):"
f" {NumpyEADtype(type(value))}, self.dtype: {self.dtype}"
)

# ------------------------------------------------------------------------
# Ops

Expand Down
13 changes: 11 additions & 2 deletions pandas/tests/arrays/numpy_/test_numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def test_setitem_no_coercion():

# With a value that we do coerce, check that we coerce the value
# and not the underlying array.
arr[0] = 2.5
arr[0] = 2
assert isinstance(arr[0], (int, np.integer)), type(arr[0])


Expand All @@ -295,7 +295,7 @@ def test_setitem_preserves_views():
assert view2[0] == 9
assert view3[0] == 9

arr[-1] = 2.5
arr[-1] = 2
view1[-1] = 5
assert arr[-1] == 5

Expand All @@ -322,3 +322,12 @@ def test_factorize_unsigned():
tm.assert_numpy_array_equal(res_codes, exp_codes)

tm.assert_extension_array_equal(res_unique, NumpyExtensionArray(exp_unique))


def test_array_validate_setitem_value():
# Issue# 51044
arr = pd.Series(range(5)).array
with pytest.raises(TypeError, match="str"):
arr._validate_setitem_value("foo")
with pytest.raises(TypeError, match="float"):
arr._validate_setitem_value(1.5)