Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GroupBy(multiple strings) #9414

Merged
merged 16 commits into from
Sep 4, 2024
6 changes: 6 additions & 0 deletions doc/user-guide/groupby.rst
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,12 @@ Use grouper objects to group by multiple dimensions:

from xarray.groupers import UniqueGrouper

da.groupby(["lat", "lon"]).sum()

The above is sugar for using ``UniqueGrouper`` objects directly:

.. ipython:: python

da.groupby(lat=UniqueGrouper(), lon=UniqueGrouper()).sum()


Expand Down
77 changes: 51 additions & 26 deletions xarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@
Dims,
ErrorOptions,
ErrorOptionsWithWarn,
GroupInput,
InterpOptions,
PadModeOptions,
PadReflectOptions,
Expand Down Expand Up @@ -6707,9 +6708,7 @@ def interp_calendar(
@_deprecate_positional_args("v2024.07.0")
def groupby(
self,
group: (
Hashable | DataArray | IndexVariable | Mapping[Any, Grouper] | None
) = None,
group: GroupInput = None,
dcherian marked this conversation as resolved.
Show resolved Hide resolved
*,
squeeze: Literal[False] = False,
restore_coord_dims: bool = False,
Expand All @@ -6719,7 +6718,7 @@ def groupby(

Parameters
----------
group : Hashable or DataArray or IndexVariable or mapping of Hashable to Grouper
group : str or DataArray or IndexVariable or sequence of hashable or mapping of hashable to Grouper
Array whose unique values should be used to group this array. If a
Hashable, must be the name of a coordinate contained in this dataarray. If a dictionary,
must map an existing variable name to a :py:class:`Grouper` instance.
Expand Down Expand Up @@ -6770,6 +6769,52 @@ def groupby(
Coordinates:
* dayofyear (dayofyear) int64 3kB 1 2 3 4 5 6 7 ... 361 362 363 364 365 366

>>> da = xr.DataArray(
... data=np.arange(12).reshape((4, 3)),
... dims=("x", "y"),
... coords={"x": [10, 20, 30, 40], "letters": ("x", list("abba"))},
... )

Grouping by a single variable is easy

>>> da.groupby("letters")
<DataArrayGroupBy, grouped over 1 grouper(s), 2 groups in total:
'letters': 2 groups with labels 'a', 'b'>

Execute a reduction

>>> da.groupby("letters").sum()
<xarray.DataArray (letters: 2, y: 3)> Size: 48B
array([[ 9., 11., 13.],
[ 9., 11., 13.]])
Coordinates:
* letters (letters) object 16B 'a' 'b'
Dimensions without coordinates: y

Grouping by multiple variables

>>> da.groupby(["letters", "x"])
<DataArrayGroupBy, grouped over 2 grouper(s), 8 groups in total:
'letters': 2 groups with labels 'a', 'b'
'x': 4 groups with labels 10, 20, 30, 40>

Use Grouper objects to express more complicated GroupBy operations

>>> from xarray.groupers import BinGrouper, UniqueGrouper
>>>
>>> da.groupby(x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper()).sum()
<xarray.DataArray (x_bins: 2, letters: 2, y: 3)> Size: 96B
array([[[ 0., 1., 2.],
[nan, nan, nan]],
<BLANKLINE>
[[nan, nan, nan],
[ 3., 4., 5.]]])
Coordinates:
* x_bins (x_bins) object 16B (5, 15] (15, 25]
* letters (letters) object 16B 'a' 'b'
Dimensions without coordinates: y


See Also
--------
:ref:`groupby`
Expand All @@ -6791,32 +6836,12 @@ def groupby(
"""
from xarray.core.groupby import (
DataArrayGroupBy,
ResolvedGrouper,
_parse_group_and_groupers,
_validate_groupby_squeeze,
)
from xarray.groupers import UniqueGrouper

_validate_groupby_squeeze(squeeze)

if isinstance(group, Mapping):
groupers = either_dict_or_kwargs(group, groupers, "groupby") # type: ignore
group = None

rgroupers: tuple[ResolvedGrouper, ...]
if group is not None:
if groupers:
raise ValueError(
"Providing a combination of `group` and **groupers is not supported."
)
rgroupers = (ResolvedGrouper(UniqueGrouper(), group, self),)
else:
if not groupers:
raise ValueError("Either `group` or `**groupers` must be provided.")
rgroupers = tuple(
ResolvedGrouper(grouper, group, self)
for group, grouper in groupers.items()
)

rgroupers = _parse_group_and_groupers(self, group, groupers)
return DataArrayGroupBy(self, rgroupers, restore_coord_dims=restore_coord_dims)

@_deprecate_positional_args("v2024.07.0")
Expand Down
75 changes: 50 additions & 25 deletions xarray/core/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@
DsCompatible,
ErrorOptions,
ErrorOptionsWithWarn,
GroupInput,
InterpOptions,
JoinOptions,
PadModeOptions,
Expand Down Expand Up @@ -10332,9 +10333,7 @@ def interp_calendar(
@_deprecate_positional_args("v2024.07.0")
def groupby(
self,
group: (
Hashable | DataArray | IndexVariable | Mapping[Any, Grouper] | None
) = None,
group: GroupInput = None,
dcherian marked this conversation as resolved.
Show resolved Hide resolved
*,
squeeze: Literal[False] = False,
restore_coord_dims: bool = False,
Expand All @@ -10344,7 +10343,7 @@ def groupby(

Parameters
----------
group : Hashable or DataArray or IndexVariable or mapping of Hashable to Grouper
group : str or DataArray or IndexVariable or sequence of hashable or mapping of hashable to Grouper
Array whose unique values should be used to group this array. If a
Hashable, must be the name of a coordinate contained in this dataarray. If a dictionary,
must map an existing variable name to a :py:class:`Grouper` instance.
Expand All @@ -10366,6 +10365,51 @@ def groupby(
A `DatasetGroupBy` object patterned after `pandas.GroupBy` that can be
iterated over in the form of `(unique_value, grouped_array)` pairs.

Examples
--------
>>> ds = xr.Dataset(
... {"foo": (("x", "y"), np.arange(12).reshape((4, 3)))},
... coords={"x": [10, 20, 30, 40], "letters": ("x", list("abba"))},
... )

Grouping by a single variable is easy

>>> ds.groupby("letters")
<DatasetGroupBy, grouped over 1 grouper(s), 2 groups in total:
'letters': 2 groups with labels 'a', 'b'>

Execute a reduction

>>> ds.groupby("letters").sum()
<xarray.Dataset> Size: 64B
Dimensions: (letters: 2, y: 3)
Coordinates:
* letters (letters) object 16B 'a' 'b'
Dimensions without coordinates: y
Data variables:
foo (letters, y) float64 48B 9.0 11.0 13.0 9.0 11.0 13.0

Grouping by multiple variables

>>> ds.groupby(["letters", "x"])
<DatasetGroupBy, grouped over 2 grouper(s), 8 groups in total:
'letters': 2 groups with labels 'a', 'b'
'x': 4 groups with labels 10, 20, 30, 40>

Use Grouper objects to express more complicated GroupBy operations

>>> from xarray.groupers import BinGrouper, UniqueGrouper
>>>
>>> ds.groupby(x=BinGrouper(bins=[5, 15, 25]), letters=UniqueGrouper()).sum()
<xarray.Dataset> Size: 128B
Dimensions: (y: 3, x_bins: 2, letters: 2)
Coordinates:
* x_bins (x_bins) object 16B (5, 15] (15, 25]
* letters (letters) object 16B 'a' 'b'
Dimensions without coordinates: y
Data variables:
foo (y, x_bins, letters) float64 96B 0.0 nan nan 3.0 ... nan nan 5.0

See Also
--------
:ref:`groupby`
Expand All @@ -10387,31 +10431,12 @@ def groupby(
"""
from xarray.core.groupby import (
DatasetGroupBy,
ResolvedGrouper,
_parse_group_and_groupers,
_validate_groupby_squeeze,
)
from xarray.groupers import UniqueGrouper

_validate_groupby_squeeze(squeeze)

if isinstance(group, Mapping):
groupers = either_dict_or_kwargs(group, groupers, "groupby") # type: ignore
group = None

rgroupers: tuple[ResolvedGrouper, ...]
if group is not None:
if groupers:
raise ValueError(
"Providing a combination of `group` and **groupers is not supported."
)
rgroupers = (ResolvedGrouper(UniqueGrouper(), group, self),)
else:
if not groupers:
raise ValueError("Either `group` or `**groupers` must be provided.")
rgroupers = tuple(
ResolvedGrouper(grouper, group, self)
for group, grouper in groupers.items()
)
rgroupers = _parse_group_and_groupers(self, group, groupers)

return DatasetGroupBy(self, rgroupers, restore_coord_dims=restore_coord_dims)

Expand Down
53 changes: 49 additions & 4 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import warnings
from collections.abc import Callable, Hashable, Iterator, Mapping, Sequence
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Generic, Literal, Union
from typing import TYPE_CHECKING, Any, Generic, Literal, Union, cast

import numpy as np
import pandas as pd
Expand Down Expand Up @@ -54,7 +54,7 @@

from xarray.core.dataarray import DataArray
from xarray.core.dataset import Dataset
from xarray.core.types import GroupIndex, GroupIndices, GroupKey
from xarray.core.types import GroupIndex, GroupIndices, GroupInput, GroupKey
from xarray.core.utils import Frozen
from xarray.groupers import EncodedGroups, Grouper

Expand Down Expand Up @@ -319,6 +319,51 @@ def __len__(self) -> int:
return len(self.encoded.full_index)


def _parse_group_and_groupers(
obj: T_Xarray, group: GroupInput, groupers: dict[str, Grouper]
) -> tuple[ResolvedGrouper, ...]:
from xarray.core.dataarray import DataArray
from xarray.core.variable import Variable
from xarray.groupers import UniqueGrouper

if group is not None and groupers:
raise ValueError(
"Providing a combination of `group` and **groupers is not supported."
)

if group is None and not groupers:
raise ValueError("Either `group` or `**groupers` must be provided.")

if isinstance(group, np.ndarray | pd.Index):
raise TypeError(
f"`group` must be a DataArray. Received {type(group).__name__!r} instead"
)

if isinstance(group, Mapping):
grouper_mapping = either_dict_or_kwargs(group, groupers, "groupby")
group = None

rgroupers: tuple[ResolvedGrouper, ...]
if isinstance(group, DataArray | Variable):
rgroupers = (ResolvedGrouper(UniqueGrouper(), group, obj),)
else:
if group is not None:
if TYPE_CHECKING:
assert isinstance(group, str | Sequence)
group_iter: Sequence[Hashable] = (
(group,) if isinstance(group, str) else group
)
grouper_mapping = {g: UniqueGrouper() for g in group_iter}
elif groupers:
grouper_mapping = cast("Mapping[Hashable, Grouper]", groupers)

rgroupers = tuple(
ResolvedGrouper(grouper, group, obj)
for group, grouper in grouper_mapping.items()
)
return rgroupers


def _validate_groupby_squeeze(squeeze: Literal[False]) -> None:
# While we don't generally check the type of every arg, passing
# multiple dimensions as multiple arguments is common enough, and the
Expand All @@ -327,7 +372,7 @@ def _validate_groupby_squeeze(squeeze: Literal[False]) -> None:
# A future version could make squeeze kwarg only, but would face
# backward-compat issues.
if squeeze is not False:
raise TypeError(f"`squeeze` must be False, but {squeeze} was supplied.")
raise TypeError(f"`squeeze` must be False, but {squeeze!r} was supplied.")


def _resolve_group(
Expand Down Expand Up @@ -626,7 +671,7 @@ def __repr__(self) -> str:
for grouper in self.groupers:
coord = grouper.unique_coord
labels = ", ".join(format_array_flat(coord, 30).split())
text += f"\n\t{grouper.name!r}: {coord.size} groups with labels {labels}"
text += f"\n {grouper.name!r}: {coord.size} groups with labels {labels}"
return text + ">"

def _iter_grouped(self) -> Iterator[T_Xarray]:
Expand Down
13 changes: 11 additions & 2 deletions xarray/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,17 @@
from xarray.core.dataset import Dataset
from xarray.core.indexes import Index, Indexes
from xarray.core.utils import Frozen
from xarray.core.variable import Variable
from xarray.groupers import TimeResampler
from xarray.core.variable import IndexVariable, Variable
from xarray.groupers import Grouper, TimeResampler

GroupInput: TypeAlias = (
str
| DataArray
| IndexVariable
| Sequence[Hashable]
| Mapping[Any, Grouper]
| None
)

try:
from dask.array import Array as DaskArray
Expand Down
Loading
Loading