Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Return a dataclass from Grouper.factorize #8777

Merged
merged 3 commits into from
Mar 15, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
112 changes: 73 additions & 39 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,6 @@
GroupKey = Any
GroupIndex = Union[int, slice, list[int]]
T_GroupIndices = list[GroupIndex]
T_FactorizeOut = tuple[
DataArray, T_GroupIndices, Union[pd.Index, "_DummyGroup"], pd.Index
]


def check_reduce_dims(reduce_dims, dimensions):
Expand Down Expand Up @@ -98,7 +95,7 @@ def _maybe_squeeze_indices(

def unique_value_groups(
ar, sort: bool = True
) -> tuple[np.ndarray | pd.Index, T_GroupIndices, np.ndarray]:
) -> tuple[np.ndarray | pd.Index, np.ndarray]:
"""Group an array by its unique values.

Parameters
Expand All @@ -119,11 +116,11 @@ def unique_value_groups(
inverse, values = pd.factorize(ar, sort=sort)
if isinstance(values, pd.MultiIndex):
values.names = ar.names
groups = _codes_to_groups(inverse, len(values))
return values, groups, inverse
return values, inverse


def _codes_to_groups(inverse: np.ndarray, N: int) -> T_GroupIndices:
def _codes_to_group_indices(inverse: np.ndarray, N: int) -> T_GroupIndices:
assert inverse.ndim == 1
groups: T_GroupIndices = [[] for _ in range(N)]
for n, g in enumerate(inverse):
if g >= 0:
Expand Down Expand Up @@ -356,7 +353,7 @@ def can_squeeze(self) -> bool:
return False

@abstractmethod
def factorize(self, group) -> T_FactorizeOut:
def factorize(self, group) -> EncodedGroups:
"""
Takes the group, and creates intermediates necessary for GroupBy.
These intermediates are
Expand All @@ -378,6 +375,27 @@ class Resampler(Grouper):
pass


@dataclass
class EncodedGroups:
    """
    Container for the intermediates a Grouper returns from ``factorize``.

    Only ``codes`` and ``full_index`` are required; the other two fields
    default to ``None``.

    Parameters
    ----------
    codes : DataArray
        Integer codes for each group.
    full_index : pd.Index
        pandas Index for the group coordinate.
    group_indices : list of int, slice or list of int, optional
        Indices of the array elements belonging to each group.
        Inferred if not provided.
    unique_coord : IndexVariable or _DummyGroup, optional
        Unique group values present in the dataset. Inferred if not provided.
    """

    # Required fields.
    codes: DataArray
    full_index: pd.Index
    # Optional fields; ``None`` means "not supplied".
    group_indices: T_GroupIndices | None = None
    unique_coord: IndexVariable | _DummyGroup | None = None


@dataclass
class ResolvedGrouper(Generic[T_DataWithCoords]):
"""
Expand All @@ -397,11 +415,11 @@ class ResolvedGrouper(Generic[T_DataWithCoords]):
group: T_Group
obj: T_DataWithCoords

# Defined by factorize:
# returned by factorize:
codes: DataArray = field(init=False)
full_index: pd.Index = field(init=False)
group_indices: T_GroupIndices = field(init=False)
unique_coord: IndexVariable | _DummyGroup = field(init=False)
full_index: pd.Index = field(init=False)

# _ensure_1d:
group1d: T_Group = field(init=False)
Expand Down Expand Up @@ -445,12 +463,26 @@ def dims(self):
return self.group1d.dims

def factorize(self) -> None:
    """
    Run the grouper's ``factorize`` on ``self.group1d`` and store the result.

    Sets ``codes``, ``full_index``, ``group_indices`` and ``unique_coord`` on
    this object. ``group_indices`` and ``unique_coord`` are taken from the
    returned ``EncodedGroups`` when it provides them, and are otherwise
    derived from the codes and the full index.
    """
    encoded = self.grouper.factorize(self.group1d)

    self.codes = encoded.codes
    self.full_index = encoded.full_index

    if encoded.group_indices is not None:
        self.group_indices = encoded.group_indices
    else:
        # One candidate list per entry of the full index; entries that no
        # element maps to come back empty and are dropped.
        self.group_indices = [
            g
            for g in _codes_to_group_indices(self.codes.data, len(self.full_index))
            if g
        ]

    if encoded.unique_coord is not None:
        self.unique_coord = encoded.unique_coord
    else:
        present = np.unique(encoded.codes)
        # -1 is the "belongs to no group" code; used as a position it would
        # wrap around and select the last entry of the full index.
        present = present[present >= 0]
        self.unique_coord = IndexVariable(
            self.group.name, self.full_index[present], attrs=self.group.attrs
        )


@dataclass
Expand All @@ -477,34 +509,33 @@ def can_squeeze(self) -> bool:
is_dimension = self.group.dims == (self.group.name,)
return is_dimension and self.is_unique_and_monotonic

def factorize(self, group1d) -> EncodedGroups:
    """
    Store ``group1d`` as ``self.group`` and encode it.

    Returns the result of ``_factorize_dummy`` when ``self.can_squeeze`` is
    true, and of ``_factorize_unique`` otherwise.
    """
    self.group = group1d

    if self.can_squeeze:
        return self._factorize_dummy()
    return self._factorize_unique()

def _factorize_unique(self) -> EncodedGroups:
    """
    Encode ``self.group`` by its unique values.

    Returns
    -------
    EncodedGroups
        ``codes`` is a copy of the group holding the integer codes, and
        ``unique_coord`` is an IndexVariable of the unique values, which is
        also used as ``full_index``. ``group_indices`` is left unset.

    Raises
    ------
    ValueError
        If every code is -1.
    """
    # look through group to find the unique values;
    # a MultiIndex is factorized without sorting
    sort = not isinstance(self.group_as_index, pd.MultiIndex)
    # unique_value_groups returns (values, codes): exactly two items
    unique_values, codes_ = unique_value_groups(self.group_as_index, sort=sort)
    if (codes_ == -1).all():
        raise ValueError(
            "Failed to group data. Are you grouping by a variable that is all NaN?"
        )
    codes = self.group.copy(data=codes_)
    unique_coord = IndexVariable(
        self.group.name, unique_values, attrs=self.group.attrs
    )
    full_index = unique_coord

    return EncodedGroups(
        codes=codes, full_index=full_index, unique_coord=unique_coord
    )

def _factorize_dummy(self) -> T_FactorizeOut:
def _factorize_dummy(self) -> EncodedGroups:
size = self.group.size
# no need to factorize
# use slices to do views instead of fancy indexing
Expand All @@ -519,8 +550,12 @@ def _factorize_dummy(self) -> T_FactorizeOut:
full_index = IndexVariable(
self.group.name, unique_coord.values, self.group.attrs
)

return codes, group_indices, unique_coord, full_index
return EncodedGroups(
codes=codes,
group_indices=group_indices,
full_index=full_index,
unique_coord=unique_coord,
)


@dataclass
Expand All @@ -536,7 +571,7 @@ def __post_init__(self) -> None:
if duck_array_ops.isnull(self.bins).all():
raise ValueError("All bin edges are NaN.")

def factorize(self, group) -> T_FactorizeOut:
def factorize(self, group) -> EncodedGroups:
from xarray.core.dataarray import DataArray

data = group.data
Expand All @@ -554,20 +589,14 @@ def factorize(self, group) -> T_FactorizeOut:
full_index = binned.categories
uniques = np.sort(pd.unique(binned_codes))
unique_values = full_index[uniques[uniques != -1]]
group_indices = [
g for g in _codes_to_groups(binned_codes, len(full_index)) if g
]

if len(group_indices) == 0:
raise ValueError(
f"None of the data falls within bins with edges {self.bins!r}"
)

codes = DataArray(
binned_codes, getattr(group, "coords", None), name=new_dim_name
)
unique_coord = IndexVariable(new_dim_name, pd.Index(unique_values), group.attrs)
return codes, group_indices, unique_coord, full_index
return EncodedGroups(
codes=codes, full_index=full_index, unique_coord=unique_coord
)


@dataclass
Expand Down Expand Up @@ -672,7 +701,7 @@ def first_items(self) -> tuple[pd.Series, np.ndarray]:
_apply_loffset(self.loffset, first_items)
return first_items, codes

def factorize(self, group) -> T_FactorizeOut:
def factorize(self, group) -> EncodedGroups:
self._init_properties(group)
full_index, first_items, codes_ = self._get_index_and_items()
sbins = first_items.values.astype(np.int64)
Expand All @@ -684,7 +713,12 @@ def factorize(self, group) -> T_FactorizeOut:
unique_coord = IndexVariable(group.name, first_items.index, group.attrs)
codes = group.copy(data=codes_)

return codes, group_indices, unique_coord, full_index
return EncodedGroups(
codes=codes,
group_indices=group_indices,
full_index=full_index,
unique_coord=unique_coord,
)


def _validate_groupby_squeeze(squeeze: bool | None) -> None:
Expand Down
Loading