From 236d89b8565df844da0badfbc9f2db7084883933 Mon Sep 17 00:00:00 2001 From: mutricyl <118692416+mutricyl@users.noreply.github.com> Date: Sat, 6 Jul 2024 18:14:37 +0200 Subject: [PATCH] update algo.take to solve #59177 (#59181) * update algo.take to solve #59177 * forgot to update TestExtensionTake::test_take_coerces_list * fixing pandas/tests/dtypes/test_generic.py::TestABCClasses::test_abc_hierarchy * ABCExtensionArray set formatting --------- Co-authored-by: Laurent Mutricy --- pandas/core/algorithms.py | 10 +++++++--- pandas/tests/test_take.py | 10 +++++++++- 2 files changed, 16 insertions(+), 4 deletions(-) diff --git a/pandas/core/algorithms.py b/pandas/core/algorithms.py index 0d97f8a298fdb..92bd55cac9c5e 100644 --- a/pandas/core/algorithms.py +++ b/pandas/core/algorithms.py @@ -68,6 +68,7 @@ ABCExtensionArray, ABCIndex, ABCMultiIndex, + ABCNumpyExtensionArray, ABCSeries, ABCTimedeltaArray, ) @@ -1161,11 +1162,14 @@ def take( ... ) array([ 10, 10, -10]) """ - if not isinstance(arr, (np.ndarray, ABCExtensionArray, ABCIndex, ABCSeries)): + if not isinstance( + arr, + (np.ndarray, ABCExtensionArray, ABCIndex, ABCSeries, ABCNumpyExtensionArray), + ): # GH#52981 raise TypeError( - "pd.api.extensions.take requires a numpy.ndarray, " - f"ExtensionArray, Index, or Series, got {type(arr).__name__}." + "pd.api.extensions.take requires a numpy.ndarray, ExtensionArray, " + f"Index, Series, or NumpyExtensionArray got {type(arr).__name__}." ) indices = ensure_platform_int(indices) diff --git a/pandas/tests/test_take.py b/pandas/tests/test_take.py index ce2e4e0f6cec5..451ef42fff3d1 100644 --- a/pandas/tests/test_take.py +++ b/pandas/tests/test_take.py @@ -5,6 +5,7 @@ from pandas._libs import iNaT +from pandas import array import pandas._testing as tm import pandas.core.algorithms as algos @@ -303,7 +304,14 @@ def test_take_coerces_list(self): arr = [1, 2, 3] msg = ( "pd.api.extensions.take requires a numpy.ndarray, ExtensionArray, " - "Index, or Series, got list" + "Index, Series, or NumpyExtensionArray got list" ) with pytest.raises(TypeError, match=msg): algos.take(arr, [0, 0]) + + def test_take_NumpyExtensionArray(self): + # GH#59177 + arr = array([1 + 1j, 2, 3]) # NumpyEADtype('complex128') (NumpyExtensionArray) + assert algos.take(arr, [2]) == 2 + arr = array([1, 2, 3]) # Int64Dtype() (ExtensionArray) + assert algos.take(arr, [2]) == 2