Skip to content

Commit 2f1751d

Browse files
ilan-goldshoyerdcherianspencerkclarkkmuehlbauer
authored
Support extension array indexes (#9671)
Co-authored-by: Stephan Hoyer <shoyer@google.com> Co-authored-by: Deepak Cherian <dcherian@users.noreply.github.com> Co-authored-by: Spencer Clark <spencerkclark@gmail.com> Co-authored-by: Kai Mühlbauer <kmuehlbauer@wradlib.org> Co-authored-by: Benoit Bovy <benbovy@gmail.com> Co-authored-by: Kai Muehlbauer <kai.muehlbauer@uni-bonn.de> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Deepak Cherian <deepak@cherian.net>
1 parent 7426ce7 commit 2f1751d

File tree

13 files changed

+150
-81
lines changed

13 files changed

+150
-81
lines changed

doc/whats-new.rst

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ New Features
2929
- Improved compatibility with OPeNDAP DAP4 data model for backend engine ``pydap``. This
3030
includes ``datatree`` support, and removing slashes from dimension names. By
3131
`Miguel Jimenez-Urias <https://github.com/Mikejmnez>`_.
32+
- Improved support for pandas Extension Arrays. (:issue:`9661`, :pull:`9671`)
33+
By `Ilan Gold <https://github.com/ilan-gold>`_.
34+
3235

3336
Breaking changes
3437
~~~~~~~~~~~~~~~~
@@ -41,6 +44,12 @@ Breaking changes
4144
pydap 3.4 3.5.0
4245
===================== ========= =======
4346

47+
48+
- Reductions with ``groupby_bins`` or those that involve :py:class:`xarray.groupers.BinGrouper`
49+
now return objects indexed by :py:class:`pandas.arrays.IntervalArray` objects,
50+
instead of numpy object arrays containing tuples. This change enables interval-aware indexing of
51+
such Xarray objects. (:pull:`9671`). By `Ilan Gold <https://github.com/ilan-gold>`_.
52+
4453
Deprecations
4554
~~~~~~~~~~~~
4655

xarray/core/dataarray.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6917,7 +6917,7 @@ def groupby(
69176917
[[nan, nan, nan],
69186918
[ 3., 4., 5.]]])
69196919
Coordinates:
6920-
* x_bins (x_bins) object 16B (5, 15] (15, 25]
6920+
* x_bins (x_bins) interval[int64, right] 32B (5, 15] (15, 25]
69216921
* letters (letters) object 16B 'a' 'b'
69226922
Dimensions without coordinates: y
69236923

xarray/core/dataset.py

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7059,6 +7059,8 @@ def to_pandas(self) -> pd.Series | pd.DataFrame:
70597059
)
70607060

70617061
def _to_dataframe(self, ordered_dims: Mapping[Any, int]):
7062+
from xarray.core.extension_array import PandasExtensionArray
7063+
70627064
columns_in_order = [k for k in self.variables if k not in self.dims]
70637065
non_extension_array_columns = [
70647066
k
@@ -7070,20 +7072,41 @@ def _to_dataframe(self, ordered_dims: Mapping[Any, int]):
70707072
for k in columns_in_order
70717073
if is_extension_array_dtype(self.variables[k].data)
70727074
]
7075+
extension_array_columns_different_index = [
7076+
k
7077+
for k in extension_array_columns
7078+
if set(self.variables[k].dims) != set(ordered_dims.keys())
7079+
]
7080+
extension_array_columns_same_index = [
7081+
k
7082+
for k in extension_array_columns
7083+
if k not in extension_array_columns_different_index
7084+
]
70737085
data = [
70747086
self._variables[k].set_dims(ordered_dims).values.reshape(-1)
70757087
for k in non_extension_array_columns
70767088
]
70777089
index = self.coords.to_index([*ordered_dims])
70787090
broadcasted_df = pd.DataFrame(
7079-
dict(zip(non_extension_array_columns, data, strict=True)), index=index
7091+
{
7092+
**dict(zip(non_extension_array_columns, data, strict=True)),
7093+
**{
7094+
c: self.variables[c].data.array
7095+
for c in extension_array_columns_same_index
7096+
},
7097+
},
7098+
index=index,
70807099
)
7081-
for extension_array_column in extension_array_columns:
7100+
for extension_array_column in extension_array_columns_different_index:
70827101
extension_array = self.variables[extension_array_column].data.array
7083-
index = self[self.variables[extension_array_column].dims[0]].data
7102+
index = self[
7103+
self.variables[extension_array_column].dims[0]
7104+
].coords.to_index()
70847105
extension_array_df = pd.DataFrame(
70857106
{extension_array_column: extension_array},
7086-
index=self[self.variables[extension_array_column].dims[0]].data,
7107+
index=pd.Index(index.array)
7108+
if isinstance(index, PandasExtensionArray)
7109+
else index,
70877110
)
70887111
extension_array_df.index.name = self.variables[extension_array_column].dims[
70897112
0
@@ -9892,10 +9915,10 @@ def groupby(
98929915
>>> from xarray.groupers import BinGrouper, UniqueGrouper
98939916
>>>
98949917
>>> ds.groupby(x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper()).sum()
9895-
<xarray.Dataset> Size: 128B
9918+
<xarray.Dataset> Size: 144B
98969919
Dimensions: (y: 3, x_bins: 2, letters: 2)
98979920
Coordinates:
9898-
* x_bins (x_bins) object 16B (5, 15] (15, 25]
9921+
* x_bins (x_bins) interval[int64, right] 32B (5, 15] (15, 25]
98999922
* letters (letters) object 16B 'a' 'b'
99009923
Dimensions without coordinates: y
99019924
Data variables:

xarray/core/extension_array.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ def replace_duck_with_extension_array(args) -> list:
102102
return type(self)[type(res)](res)
103103
return res
104104

105-
def __array_ufunc__(ufunc, method, *inputs, **kwargs):
105+
def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
106106
return ufunc(*inputs, **kwargs)
107107

108108
def __repr__(self):

xarray/core/indexes.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111

1212
from xarray.core import formatting, nputils, utils
1313
from xarray.core.coordinate_transform import CoordinateTransform
14+
from xarray.core.extension_array import PandasExtensionArray
1415
from xarray.core.indexing import (
1516
CoordinateTransformIndexingAdapter,
1617
IndexSelResult,
@@ -444,6 +445,8 @@ def safe_cast_to_index(array: Any) -> pd.Index:
444445
from xarray.core.variable import Variable
445446
from xarray.namedarray.pycompat import to_numpy
446447

448+
if isinstance(array, PandasExtensionArray):
449+
array = pd.Index(array.array)
447450
if isinstance(array, pd.Index):
448451
index = array
449452
elif isinstance(array, DataArray | Variable):
@@ -602,7 +605,11 @@ def __init__(
602605
self.dim = dim
603606

604607
if coord_dtype is None:
605-
coord_dtype = get_valid_numpy_dtype(index)
608+
if pd.api.types.is_extension_array_dtype(index.dtype):
609+
cast(pd.api.extensions.ExtensionDtype, index.dtype)
610+
coord_dtype = index.dtype
611+
else:
612+
coord_dtype = get_valid_numpy_dtype(index)
606613
self.coord_dtype = coord_dtype
607614

608615
def _replace(self, index, dim=None, coord_dtype=None):
@@ -698,6 +705,8 @@ def concat(
698705

699706
if not indexes:
700707
coord_dtype = None
708+
elif len(set(idx.coord_dtype for idx in indexes)) == 1:
709+
coord_dtype = indexes[0].coord_dtype
701710
else:
702711
coord_dtype = np.result_type(*[idx.coord_dtype for idx in indexes])
703712

xarray/core/indexing.py

Lines changed: 52 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,14 +10,16 @@
1010
from dataclasses import dataclass, field
1111
from datetime import timedelta
1212
from html import escape
13-
from typing import TYPE_CHECKING, Any, overload
13+
from typing import TYPE_CHECKING, Any, cast, overload
1414

1515
import numpy as np
1616
import pandas as pd
17+
from numpy.typing import DTypeLike
1718
from packaging.version import Version
1819

1920
from xarray.core import duck_array_ops
2021
from xarray.core.coordinate_transform import CoordinateTransform
22+
from xarray.core.extension_array import PandasExtensionArray
2123
from xarray.core.nputils import NumpyVIndexAdapter
2224
from xarray.core.options import OPTIONS
2325
from xarray.core.types import T_Xarray
@@ -28,14 +30,13 @@
2830
is_duck_array,
2931
is_duck_dask_array,
3032
is_scalar,
33+
is_valid_numpy_dtype,
3134
to_0d_array,
3235
)
3336
from xarray.namedarray.parallelcompat import get_chunked_array_type
3437
from xarray.namedarray.pycompat import array_type, integer_types, is_chunked_array
3538

3639
if TYPE_CHECKING:
37-
from numpy.typing import DTypeLike
38-
3940
from xarray.core.indexes import Index
4041
from xarray.core.types import Self
4142
from xarray.core.variable import Variable
@@ -1744,27 +1745,43 @@ class PandasIndexingAdapter(ExplicitlyIndexedNDArrayMixin):
17441745
__slots__ = ("_dtype", "array")
17451746

17461747
array: pd.Index
1747-
_dtype: np.dtype
1748+
_dtype: np.dtype | pd.api.extensions.ExtensionDtype
17481749

1749-
def __init__(self, array: pd.Index, dtype: DTypeLike = None):
1750+
def __init__(
1751+
self,
1752+
array: pd.Index,
1753+
dtype: DTypeLike | pd.api.extensions.ExtensionDtype | None = None,
1754+
):
17501755
from xarray.core.indexes import safe_cast_to_index
17511756

17521757
self.array = safe_cast_to_index(array)
17531758

17541759
if dtype is None:
1755-
self._dtype = get_valid_numpy_dtype(array)
1760+
if pd.api.types.is_extension_array_dtype(array.dtype):
1761+
cast(pd.api.extensions.ExtensionDtype, array.dtype)
1762+
self._dtype = array.dtype
1763+
else:
1764+
self._dtype = get_valid_numpy_dtype(array)
1765+
elif pd.api.types.is_extension_array_dtype(dtype):
1766+
self._dtype = cast(pd.api.extensions.ExtensionDtype, dtype)
17561767
else:
1757-
self._dtype = np.dtype(dtype)
1768+
self._dtype = np.dtype(cast(DTypeLike, dtype))
17581769

17591770
@property
1760-
def dtype(self) -> np.dtype:
1771+
def dtype(self) -> np.dtype | pd.api.extensions.ExtensionDtype: # type: ignore[override]
17611772
return self._dtype
17621773

17631774
def __array__(
1764-
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
1775+
self,
1776+
dtype: np.typing.DTypeLike | None = None,
1777+
/,
1778+
*,
1779+
copy: bool | None = None,
17651780
) -> np.ndarray:
1766-
if dtype is None:
1767-
dtype = self.dtype
1781+
if dtype is None and is_valid_numpy_dtype(self.dtype):
1782+
dtype = cast(np.dtype, self.dtype)
1783+
else:
1784+
dtype = get_valid_numpy_dtype(self.array)
17681785
array = self.array
17691786
if isinstance(array, pd.PeriodIndex):
17701787
with suppress(AttributeError):
@@ -1776,14 +1793,18 @@ def __array__(
17761793
else:
17771794
return np.asarray(array.values, dtype=dtype)
17781795

1779-
def get_duck_array(self) -> np.ndarray:
1796+
def get_duck_array(self) -> np.ndarray | PandasExtensionArray:
1797+
# We return a PandasExtensionArray wrapper type that satisfies
1798+
# duck array protocols. This is what's needed for tests to pass.
1799+
if pd.api.types.is_extension_array_dtype(self.array):
1800+
return PandasExtensionArray(self.array.array)
17801801
return np.asarray(self)
17811802

17821803
@property
17831804
def shape(self) -> _Shape:
17841805
return (len(self.array),)
17851806

1786-
def _convert_scalar(self, item):
1807+
def _convert_scalar(self, item) -> np.ndarray:
17871808
if item is pd.NaT:
17881809
# work around the impossibility of casting NaT with asarray
17891810
# note: it probably would be better in general to return
@@ -1799,7 +1820,10 @@ def _convert_scalar(self, item):
17991820
# numpy fails to convert pd.Timestamp to np.datetime64[ns]
18001821
item = np.asarray(item.to_datetime64())
18011822
elif self.dtype != object:
1802-
item = np.asarray(item, dtype=self.dtype)
1823+
dtype = self.dtype
1824+
if pd.api.types.is_extension_array_dtype(dtype):
1825+
dtype = get_valid_numpy_dtype(self.array)
1826+
item = np.asarray(item, dtype=cast(np.dtype, dtype))
18031827

18041828
# as for numpy.ndarray indexing, we always want the result to be
18051829
# a NumPy array.
@@ -1902,6 +1926,12 @@ def copy(self, deep: bool = True) -> Self:
19021926
array = self.array.copy(deep=True) if deep else self.array
19031927
return type(self)(array, self._dtype)
19041928

1929+
@property
1930+
def nbytes(self) -> int:
1931+
if pd.api.types.is_extension_array_dtype(self.dtype):
1932+
return self.array.nbytes
1933+
return cast(np.dtype, self.dtype).itemsize * len(self.array)
1934+
19051935

19061936
class PandasMultiIndexingAdapter(PandasIndexingAdapter):
19071937
"""Handles explicit indexing for a pandas.MultiIndex.
@@ -1914,23 +1944,27 @@ class PandasMultiIndexingAdapter(PandasIndexingAdapter):
19141944
__slots__ = ("_dtype", "adapter", "array", "level")
19151945

19161946
array: pd.MultiIndex
1917-
_dtype: np.dtype
1947+
_dtype: np.dtype | pd.api.extensions.ExtensionDtype
19181948
level: str | None
19191949

19201950
def __init__(
19211951
self,
19221952
array: pd.MultiIndex,
1923-
dtype: DTypeLike = None,
1953+
dtype: DTypeLike | pd.api.extensions.ExtensionDtype | None = None,
19241954
level: str | None = None,
19251955
):
19261956
super().__init__(array, dtype)
19271957
self.level = level
19281958

19291959
def __array__(
1930-
self, dtype: np.typing.DTypeLike = None, /, *, copy: bool | None = None
1960+
self,
1961+
dtype: DTypeLike | None = None,
1962+
/,
1963+
*,
1964+
copy: bool | None = None,
19311965
) -> np.ndarray:
19321966
if dtype is None:
1933-
dtype = self.dtype
1967+
dtype = cast(np.dtype, self.dtype)
19341968
if self.level is not None:
19351969
return np.asarray(
19361970
self.array.get_level_values(self.level).values, dtype=dtype

xarray/core/variable.py

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import numpy as np
1414
import pandas as pd
1515
from numpy.typing import ArrayLike
16-
from pandas.api.types import is_extension_array_dtype
1716

1817
import xarray as xr # only for Dataset and DataArray
1918
from xarray.compat.array_api_compat import to_like_array
@@ -60,6 +59,7 @@
6059
indexing.ExplicitlyIndexed,
6160
pd.Index,
6261
pd.api.extensions.ExtensionArray,
62+
PandasExtensionArray,
6363
)
6464
# https://github.com/python/mypy/issues/224
6565
BASIC_INDEXING_TYPES = integer_types + (slice,)
@@ -192,7 +192,7 @@ def _maybe_wrap_data(data):
192192
if isinstance(data, pd.Index):
193193
return PandasIndexingAdapter(data)
194194
if isinstance(data, pd.api.extensions.ExtensionArray):
195-
return PandasExtensionArray[type(data)](data)
195+
return PandasExtensionArray(data)
196196
return data
197197

198198

@@ -2593,11 +2593,6 @@ def chunk( # type: ignore[override]
25932593
dask.array.from_array
25942594
"""
25952595

2596-
if is_extension_array_dtype(self):
2597-
raise ValueError(
2598-
f"{self} was found to be a Pandas ExtensionArray. Please convert to numpy first."
2599-
)
2600-
26012596
if from_array_kwargs is None:
26022597
from_array_kwargs = {}
26032598

xarray/namedarray/core.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -834,6 +834,7 @@ def chunk(
834834
if chunkmanager.is_chunked_array(data_old):
835835
data_chunked = chunkmanager.rechunk(data_old, chunks) # type: ignore[arg-type]
836836
else:
837+
ndata: duckarray[Any, Any]
837838
if not isinstance(data_old, ExplicitlyIndexed):
838839
ndata = data_old
839840
else:

xarray/plot/dataarray_plot.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,7 @@ def line(
524524
assert hueplt is not None
525525
ax.legend(handles=primitive, labels=list(hueplt.to_numpy()), title=hue_label)
526526

527-
if np.issubdtype(xplt.dtype, np.datetime64):
527+
if isinstance(xplt.dtype, np.dtype) and np.issubdtype(xplt.dtype, np.datetime64): # type: ignore[redundant-expr]
528528
_set_concise_date(ax, axis="x")
529529

530530
_update_axes(ax, xincrease, yincrease, xscale, yscale, xticks, yticks, xlim, ylim)

xarray/tests/test_dataset.py

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4381,7 +4381,6 @@ def test_setitem_pandas(self) -> None:
43814381
ds["x"] = np.arange(3)
43824382
ds_copy = ds.copy()
43834383
ds_copy["bar"] = ds["bar"].to_pandas()
4384-
43854384
assert_equal(ds, ds_copy)
43864385

43874386
def test_setitem_auto_align(self) -> None:
@@ -4972,6 +4971,16 @@ def test_to_and_from_dataframe(self) -> None:
49724971
expected = pd.DataFrame([[]], index=idx)
49734972
assert expected.equals(actual), (expected, actual)
49744973

4974+
def test_from_dataframe_categorical_dtype_index(self) -> None:
4975+
cat = pd.CategoricalIndex(list("abcd"))
4976+
df = pd.DataFrame({"f": [0, 1, 2, 3]}, index=cat)
4977+
ds = df.to_xarray()
4978+
restored = ds.to_dataframe()
4979+
df.index.name = (
4980+
"index" # restored gets the name because it has the coord with the name
4981+
)
4982+
pd.testing.assert_frame_equal(df, restored)
4983+
49754984
def test_from_dataframe_categorical_index(self) -> None:
49764985
cat = pd.CategoricalDtype(
49774986
categories=["foo", "bar", "baz", "qux", "quux", "corge"]
@@ -4996,7 +5005,7 @@ def test_from_dataframe_categorical_index_string_categories(self) -> None:
49965005
)
49975006
ser = pd.Series(1, index=cat)
49985007
ds = ser.to_xarray()
4999-
assert ds.coords.dtypes["index"] == np.dtype("O")
5008+
assert ds.coords.dtypes["index"] == ser.index.dtype
50005009

50015010
@requires_sparse
50025011
def test_from_dataframe_sparse(self) -> None:

0 commit comments

Comments
 (0)