diff --git a/doc/user-guide/groupby.rst b/doc/user-guide/groupby.rst index c10ee6a659d..98bd7b4833b 100644 --- a/doc/user-guide/groupby.rst +++ b/doc/user-guide/groupby.rst @@ -305,6 +305,12 @@ Use grouper objects to group by multiple dimensions: from xarray.groupers import UniqueGrouper + da.groupby(["lat", "lon"]).sum() + +The above is sugar for using ``UniqueGrouper`` objects directly: + +.. ipython:: python + da.groupby(lat=UniqueGrouper(), lon=UniqueGrouper()).sum() diff --git a/xarray/core/dataarray.py b/xarray/core/dataarray.py index f17bd057c03..e3ba92c21cd 100644 --- a/xarray/core/dataarray.py +++ b/xarray/core/dataarray.py @@ -103,6 +103,7 @@ Dims, ErrorOptions, ErrorOptionsWithWarn, + GroupInput, InterpOptions, PadModeOptions, PadReflectOptions, @@ -6707,9 +6708,7 @@ def interp_calendar( @_deprecate_positional_args("v2024.07.0") def groupby( self, - group: ( - Hashable | DataArray | IndexVariable | Mapping[Any, Grouper] | None - ) = None, + group: GroupInput = None, *, squeeze: Literal[False] = False, restore_coord_dims: bool = False, @@ -6719,7 +6718,7 @@ def groupby( Parameters ---------- - group : Hashable or DataArray or IndexVariable or mapping of Hashable to Grouper + group : str or DataArray or IndexVariable or sequence of hashable or mapping of hashable to Grouper Array whose unique values should be used to group this array. If a Hashable, must be the name of a coordinate contained in this dataarray. If a dictionary, must map an existing variable name to a :py:class:`Grouper` instance. @@ -6770,6 +6769,52 @@ def groupby( Coordinates: * dayofyear (dayofyear) int64 3kB 1 2 3 4 5 6 7 ... 361 362 363 364 365 366 + >>> da = xr.DataArray( + ... data=np.arange(12).reshape((4, 3)), + ... dims=("x", "y"), + ... coords={"x": [10, 20, 30, 40], "letters": ("x", list("abba"))}, + ... 
) + + Grouping by a single variable is easy + + >>> da.groupby("letters") + + + Execute a reduction + + >>> da.groupby("letters").sum() + Size: 48B + array([[ 9., 11., 13.], + [ 9., 11., 13.]]) + Coordinates: + * letters (letters) object 16B 'a' 'b' + Dimensions without coordinates: y + + Grouping by multiple variables + + >>> da.groupby(["letters", "x"]) + + + Use Grouper objects to express more complicated GroupBy operations + + >>> from xarray.groupers import BinGrouper, UniqueGrouper + >>> + >>> da.groupby(x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper()).sum() + Size: 96B + array([[[ 0., 1., 2.], + [nan, nan, nan]], + + [[nan, nan, nan], + [ 3., 4., 5.]]]) + Coordinates: + * x_bins (x_bins) object 16B (5, 15] (15, 25] + * letters (letters) object 16B 'a' 'b' + Dimensions without coordinates: y + + See Also -------- :ref:`groupby` @@ -6791,32 +6836,12 @@ def groupby( """ from xarray.core.groupby import ( DataArrayGroupBy, - ResolvedGrouper, + _parse_group_and_groupers, _validate_groupby_squeeze, ) - from xarray.groupers import UniqueGrouper _validate_groupby_squeeze(squeeze) - - if isinstance(group, Mapping): - groupers = either_dict_or_kwargs(group, groupers, "groupby") # type: ignore - group = None - - rgroupers: tuple[ResolvedGrouper, ...] - if group is not None: - if groupers: - raise ValueError( - "Providing a combination of `group` and **groupers is not supported." 
- ) - rgroupers = (ResolvedGrouper(UniqueGrouper(), group, self),) - else: - if not groupers: - raise ValueError("Either `group` or `**groupers` must be provided.") - rgroupers = tuple( - ResolvedGrouper(grouper, group, self) - for group, grouper in groupers.items() - ) - + rgroupers = _parse_group_and_groupers(self, group, groupers) return DataArrayGroupBy(self, rgroupers, restore_coord_dims=restore_coord_dims) @_deprecate_positional_args("v2024.07.0") diff --git a/xarray/core/dataset.py b/xarray/core/dataset.py index ae387da7e8e..a7b52dc0185 100644 --- a/xarray/core/dataset.py +++ b/xarray/core/dataset.py @@ -155,6 +155,7 @@ DsCompatible, ErrorOptions, ErrorOptionsWithWarn, + GroupInput, InterpOptions, JoinOptions, PadModeOptions, @@ -10332,9 +10333,7 @@ def interp_calendar( @_deprecate_positional_args("v2024.07.0") def groupby( self, - group: ( - Hashable | DataArray | IndexVariable | Mapping[Any, Grouper] | None - ) = None, + group: GroupInput = None, *, squeeze: Literal[False] = False, restore_coord_dims: bool = False, @@ -10344,7 +10343,7 @@ def groupby( Parameters ---------- - group : Hashable or DataArray or IndexVariable or mapping of Hashable to Grouper + group : str or DataArray or IndexVariable or sequence of hashable or mapping of hashable to Grouper Array whose unique values should be used to group this array. If a Hashable, must be the name of a coordinate contained in this dataarray. If a dictionary, must map an existing variable name to a :py:class:`Grouper` instance. @@ -10366,6 +10365,51 @@ def groupby( A `DatasetGroupBy` object patterned after `pandas.GroupBy` that can be iterated over in the form of `(unique_value, grouped_array)` pairs. + Examples + -------- + >>> ds = xr.Dataset( + ... {"foo": (("x", "y"), np.arange(12).reshape((4, 3)))}, + ... coords={"x": [10, 20, 30, 40], "letters": ("x", list("abba"))}, + ... 
) + + Grouping by a single variable is easy + + >>> ds.groupby("letters") + + + Execute a reduction + + >>> ds.groupby("letters").sum() + Size: 64B + Dimensions: (letters: 2, y: 3) + Coordinates: + * letters (letters) object 16B 'a' 'b' + Dimensions without coordinates: y + Data variables: + foo (letters, y) float64 48B 9.0 11.0 13.0 9.0 11.0 13.0 + + Grouping by multiple variables + + >>> ds.groupby(["letters", "x"]) + + + Use Grouper objects to express more complicated GroupBy operations + + >>> from xarray.groupers import BinGrouper, UniqueGrouper + >>> + >>> ds.groupby(x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper()).sum() + Size: 128B + Dimensions: (y: 3, x_bins: 2, letters: 2) + Coordinates: + * x_bins (x_bins) object 16B (5, 15] (15, 25] + * letters (letters) object 16B 'a' 'b' + Dimensions without coordinates: y + Data variables: + foo (y, x_bins, letters) float64 96B 0.0 nan nan 3.0 ... nan nan 5.0 + See Also -------- :ref:`groupby` @@ -10387,31 +10431,12 @@ def groupby( """ from xarray.core.groupby import ( DatasetGroupBy, - ResolvedGrouper, + _parse_group_and_groupers, _validate_groupby_squeeze, ) - from xarray.groupers import UniqueGrouper _validate_groupby_squeeze(squeeze) - - if isinstance(group, Mapping): - groupers = either_dict_or_kwargs(group, groupers, "groupby") # type: ignore - group = None - - rgroupers: tuple[ResolvedGrouper, ...] - if group is not None: - if groupers: - raise ValueError( - "Providing a combination of `group` and **groupers is not supported." 
def _parse_group_and_groupers(
    obj: T_Xarray, group: GroupInput, groupers: dict[str, Grouper]
) -> tuple[ResolvedGrouper, ...]:
    """Turn the ``group`` / ``**groupers`` arguments of ``groupby`` into resolved groupers.

    Parameters
    ----------
    obj : DataArray or Dataset
        Object being grouped. It is handed unchanged to every ``ResolvedGrouper``.
    group : str, DataArray, Variable, sequence of hashable, mapping or None
        - ``DataArray`` / ``Variable``: grouped by its unique values.
        - a single name, or a sequence of names: every name is paired with a
          ``UniqueGrouper``. A string is always one name, never a sequence of
          characters. A non-iterable hashable (e.g. an integer name) is also
          treated as one name.
        - mapping: used as a name -> ``Grouper`` mapping.
    groupers : dict
        Keyword form of the name -> ``Grouper`` mapping. Mutually exclusive
        with ``group``.

    Returns
    -------
    tuple of ResolvedGrouper
        One entry per variable to group by.

    Raises
    ------
    ValueError
        If both ``group`` and ``groupers`` are given, if neither is given, or
        if the input names no variable at all (empty sequence or mapping).
    TypeError
        If ``group`` is a ``numpy.ndarray`` or a ``pandas.Index``.
    """
    from xarray.core.dataarray import DataArray
    from xarray.core.variable import Variable
    from xarray.groupers import UniqueGrouper

    if group is not None and groupers:
        raise ValueError(
            "Providing a combination of `group` and **groupers is not supported."
        )

    if group is None and not groupers:
        raise ValueError("Either `group` or `**groupers` must be provided.")

    if isinstance(group, np.ndarray | pd.Index):
        raise TypeError(
            f"`group` must be a DataArray. Received {type(group).__name__!r} instead"
        )

    grouper_mapping: Mapping[Hashable, Grouper]
    if isinstance(group, Mapping):
        # `groupers` is necessarily empty here (checked above), so this only
        # normalizes the positional mapping.
        grouper_mapping = either_dict_or_kwargs(group, groupers, "groupby")
    elif isinstance(group, DataArray | Variable):
        return (ResolvedGrouper(UniqueGrouper(), group, obj),)
    elif group is not None:
        if isinstance(group, str):
            # A string is iterable, but it is a single variable name.
            names = (group,)
        else:
            try:
                iter(group)
            except TypeError:
                # Non-iterable hashable (e.g. an int used as a variable name):
                # a single name, not a collection of names.
                names = (group,)
            else:
                names = group
        grouper_mapping = {name: UniqueGrouper() for name in names}
    else:
        grouper_mapping = cast("Mapping[Hashable, Grouper]", groupers)

    if not grouper_mapping:
        # An empty sequence/mapping would otherwise yield zero groupers.
        raise ValueError(
            "`group` must name at least one variable to group by; "
            "received an empty collection."
        )

    return tuple(
        ResolvedGrouper(grouper, name, obj)
        for name, grouper in grouper_mapping.items()
    )
if squeeze is not False: - raise TypeError(f"`squeeze` must be False, but {squeeze} was supplied.") + raise TypeError(f"`squeeze` must be False, but {squeeze!r} was supplied.") def _resolve_group( @@ -626,7 +671,7 @@ def __repr__(self) -> str: for grouper in self.groupers: coord = grouper.unique_coord labels = ", ".join(format_array_flat(coord, 30).split()) - text += f"\n\t{grouper.name!r}: {coord.size} groups with labels {labels}" + text += f"\n {grouper.name!r}: {coord.size} groups with labels {labels}" return text + ">" def _iter_grouped(self) -> Iterator[T_Xarray]: diff --git a/xarray/core/types.py b/xarray/core/types.py index d3a8e7a9f4c..a9c2771cb9f 100644 --- a/xarray/core/types.py +++ b/xarray/core/types.py @@ -43,8 +43,17 @@ from xarray.core.dataset import Dataset from xarray.core.indexes import Index, Indexes from xarray.core.utils import Frozen - from xarray.core.variable import Variable - from xarray.groupers import TimeResampler + from xarray.core.variable import IndexVariable, Variable + from xarray.groupers import Grouper, TimeResampler + + GroupInput: TypeAlias = ( + str + | DataArray + | IndexVariable + | Sequence[Hashable] + | Mapping[Any, Grouper] + | None + ) try: from dask.array import Array as DaskArray diff --git a/xarray/tests/test_groupby.py b/xarray/tests/test_groupby.py index fc04b49fabc..b5d7312d9bb 100644 --- a/xarray/tests/test_groupby.py +++ b/xarray/tests/test_groupby.py @@ -583,7 +583,7 @@ def test_groupby_repr(obj, dim) -> None: N = len(np.unique(obj[dim])) expected = f"<{obj.__class__.__name__}GroupBy" expected += f", grouped over 1 grouper(s), {N} groups in total:" - expected += f"\n\t{dim!r}: {N} groups with labels " + expected += f"\n {dim!r}: {N} groups with labels " if dim == "x": expected += "1, 2, 3, 4, 5>" elif dim == "y": @@ -600,7 +600,7 @@ def test_groupby_repr_datetime(obj) -> None: actual = repr(obj.groupby("t.month")) expected = f"<{obj.__class__.__name__}GroupBy" expected += ", grouped over 1 grouper(s), 12 groups 
@pytest.mark.parametrize("as_dataset", [True, False])
def test_multiple_groupers_string(as_dataset) -> None:
    """Grouping by a sequence of names must equal explicit ``UniqueGrouper`` kwargs."""
    label_coords = {
        "labels1": ("d", np.array(["a", "b", "c", "c", "b", "a"])),
        "labels2": ("d", np.array(["x", "y", "z", "z", "y", "x"])),
    }
    obj = DataArray(
        np.array([1, 2, 3, 0, 2, np.nan]),
        dims="d",
        coords=label_coords,
        name="foo",
    )
    if as_dataset:
        obj = obj.to_dataset()  # type: ignore

    expected = obj.groupby(labels1=UniqueGrouper(), labels2=UniqueGrouper()).mean()
    actual = obj.groupby(("labels1", "labels2")).mean()
    assert_identical(expected, actual)

    # A second positional argument lands in `squeeze`: today that produces a
    # FutureWarning plus a TypeError; later it is expected to become a plain
    # keyword-argument error.
    with pytest.warns(FutureWarning), pytest.raises(TypeError):
        obj.groupby("labels1", "labels2")  # type: ignore

    # `group` combined with any keyword grouper is rejected, whether or not the
    # keyword value is a real Grouper.
    for extra in ("bar", UniqueGrouper()):
        with pytest.raises(ValueError):
            obj.groupby("labels1", foo=extra)  # type: ignore