From 7c56e051e43104818143ed1cabc077755e70559d Mon Sep 17 00:00:00 2001 From: MeeseeksMachine <39504233+meeseeksmachine@users.noreply.github.com> Date: Sun, 2 Apr 2023 19:04:36 +0200 Subject: [PATCH] Backport PR #51966 on branch 2.0.x (CoW: __array__ not recognizing ea dtypes) (#52358) Backport PR #51966: CoW: __array__ not recognizing ea dtypes Co-authored-by: Patrick Hoefler <61934744+phofl@users.noreply.github.com> --- pandas/core/generic.py | 16 +++++-- pandas/core/series.py | 3 +- pandas/tests/copy_view/test_array.py | 64 ++++++++++++++++++++++++++++ 3 files changed, 77 insertions(+), 6 deletions(-) diff --git a/pandas/core/generic.py b/pandas/core/generic.py index 586d7d653599b..9e082b7883e74 100644 --- a/pandas/core/generic.py +++ b/pandas/core/generic.py @@ -101,6 +101,7 @@ validate_inclusive, ) +from pandas.core.dtypes.astype import astype_is_view from pandas.core.dtypes.common import ( ensure_object, ensure_platform_int, @@ -1995,10 +1996,17 @@ def empty(self) -> bool_t: def __array__(self, dtype: npt.DTypeLike | None = None) -> np.ndarray: values = self._values arr = np.asarray(values, dtype=dtype) - if arr is values and using_copy_on_write(): - # TODO(CoW) also properly handle extension dtypes - arr = arr.view() - arr.flags.writeable = False + if ( + astype_is_view(values.dtype, arr.dtype) + and using_copy_on_write() + and self._mgr.is_single_block + ): + # Check if both conversions can be done without a copy + if astype_is_view(self.dtypes.iloc[0], values.dtype) and astype_is_view( + values.dtype, arr.dtype + ): + arr = arr.view() + arr.flags.writeable = False return arr @final diff --git a/pandas/core/series.py b/pandas/core/series.py index 475920da20e88..22036abcfd497 100644 --- a/pandas/core/series.py +++ b/pandas/core/series.py @@ -920,8 +920,7 @@ def __array__(self, dtype: npt.DTypeLike | None = None) -> np.ndarray: """ values = self._values arr = np.asarray(values, dtype=dtype) - if arr is values and using_copy_on_write(): - # TODO(CoW) also properly handle extension dtypes + if using_copy_on_write() and astype_is_view(values.dtype, arr.dtype): arr = arr.view() arr.flags.writeable = False return arr diff --git a/pandas/tests/copy_view/test_array.py b/pandas/tests/copy_view/test_array.py index 519479881948b..62a6a3374e612 100644 --- a/pandas/tests/copy_view/test_array.py +++ b/pandas/tests/copy_view/test_array.py @@ -4,6 +4,7 @@ from pandas import ( DataFrame, Series, + date_range, ) import pandas._testing as tm from pandas.tests.copy_view.util import get_array @@ -119,3 +120,66 @@ def test_ravel_read_only(using_copy_on_write, order): if using_copy_on_write: assert arr.flags.writeable is False assert np.shares_memory(get_array(ser), arr) + + +def test_series_array_ea_dtypes(using_copy_on_write): + ser = Series([1, 2, 3], dtype="Int64") + arr = np.asarray(ser, dtype="int64") + assert np.shares_memory(arr, get_array(ser)) + if using_copy_on_write: + assert arr.flags.writeable is False + else: + assert arr.flags.writeable is True + + arr = np.asarray(ser) + assert not np.shares_memory(arr, get_array(ser)) + assert arr.flags.writeable is True + + +def test_dataframe_array_ea_dtypes(using_copy_on_write): + df = DataFrame({"a": [1, 2, 3]}, dtype="Int64") + arr = np.asarray(df, dtype="int64") + # TODO: This should be able to share memory, but we are roundtripping + # through object + assert not np.shares_memory(arr, get_array(df, "a")) + assert arr.flags.writeable is True + + arr = np.asarray(df) + if using_copy_on_write: + # TODO(CoW): This should be True + assert arr.flags.writeable is False + else: + assert arr.flags.writeable is True + + +def test_dataframe_array_string_dtype(using_copy_on_write, using_array_manager): + df = DataFrame({"a": ["a", "b"]}, dtype="string") + arr = np.asarray(df) + if not using_array_manager: + assert np.shares_memory(arr, get_array(df, "a")) + if using_copy_on_write: + assert arr.flags.writeable is False + else: + assert arr.flags.writeable is True + + +def test_dataframe_multiple_numpy_dtypes(): + df = DataFrame({"a": [1, 2, 3], "b": 1.5}) + arr = np.asarray(df) + assert not np.shares_memory(arr, get_array(df, "a")) + assert arr.flags.writeable is True + + +def test_values_is_ea(using_copy_on_write): + df = DataFrame({"a": date_range("2012-01-01", periods=3)}) + arr = np.asarray(df) + if using_copy_on_write: + assert arr.flags.writeable is False + else: + assert arr.flags.writeable is True + + +def test_empty_dataframe(): + df = DataFrame() + arr = np.asarray(df) + assert arr.flags.writeable is True