Skip to content

Commit

Permalink
Return a dataclass from Grouper.factorize (#8777)
Browse files Browse the repository at this point in the history
* Return dataclass from factorize

* cleanup
  • Loading branch information
dcherian authored Mar 15, 2024
1 parent cb051d8 commit 3dcfa31
Showing 1 changed file with 73 additions and 39 deletions.
112 changes: 73 additions & 39 deletions xarray/core/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,6 @@
GroupKey = Any
GroupIndex = Union[int, slice, list[int]]
T_GroupIndices = list[GroupIndex]
T_FactorizeOut = tuple[
DataArray, T_GroupIndices, Union[pd.Index, "_DummyGroup"], pd.Index
]


def check_reduce_dims(reduce_dims, dimensions):
Expand Down Expand Up @@ -98,7 +95,7 @@ def _maybe_squeeze_indices(

def unique_value_groups(
ar, sort: bool = True
) -> tuple[np.ndarray | pd.Index, T_GroupIndices, np.ndarray]:
) -> tuple[np.ndarray | pd.Index, np.ndarray]:
"""Group an array by its unique values.
Parameters
Expand All @@ -119,11 +116,11 @@ def unique_value_groups(
inverse, values = pd.factorize(ar, sort=sort)
if isinstance(values, pd.MultiIndex):
values.names = ar.names
groups = _codes_to_groups(inverse, len(values))
return values, groups, inverse
return values, inverse


def _codes_to_groups(inverse: np.ndarray, N: int) -> T_GroupIndices:
def _codes_to_group_indices(inverse: np.ndarray, N: int) -> T_GroupIndices:
assert inverse.ndim == 1
groups: T_GroupIndices = [[] for _ in range(N)]
for n, g in enumerate(inverse):
if g >= 0:
Expand Down Expand Up @@ -356,7 +353,7 @@ def can_squeeze(self) -> bool:
return False

@abstractmethod
def factorize(self, group) -> T_FactorizeOut:
def factorize(self, group) -> EncodedGroups:
"""
Takes the group, and creates intermediates necessary for GroupBy.
These intermediates are
Expand All @@ -378,6 +375,27 @@ class Resampler(Grouper):
pass


@dataclass
class EncodedGroups:
"""
Dataclass for storing intermediate values for GroupBy operation.
Returned by factorize method on Grouper objects.
Parameters
----------
codes: integer codes for each group
full_index: pandas Index for the group coordinate
group_indices: optional, List of indices of array elements belonging
to each group. Inferred if not provided.
unique_coord: Unique group values present in dataset. Inferred if not provided
"""

codes: DataArray
full_index: pd.Index
group_indices: T_GroupIndices | None = field(default=None)
unique_coord: IndexVariable | _DummyGroup | None = field(default=None)


@dataclass
class ResolvedGrouper(Generic[T_DataWithCoords]):
"""
Expand All @@ -397,11 +415,11 @@ class ResolvedGrouper(Generic[T_DataWithCoords]):
group: T_Group
obj: T_DataWithCoords

# Defined by factorize:
# returned by factorize:
codes: DataArray = field(init=False)
full_index: pd.Index = field(init=False)
group_indices: T_GroupIndices = field(init=False)
unique_coord: IndexVariable | _DummyGroup = field(init=False)
full_index: pd.Index = field(init=False)

# _ensure_1d:
group1d: T_Group = field(init=False)
Expand Down Expand Up @@ -445,12 +463,26 @@ def dims(self):
return self.group1d.dims

def factorize(self) -> None:
(
self.codes,
self.group_indices,
self.unique_coord,
self.full_index,
) = self.grouper.factorize(self.group1d)
encoded = self.grouper.factorize(self.group1d)

self.codes = encoded.codes
self.full_index = encoded.full_index

if encoded.group_indices is not None:
self.group_indices = encoded.group_indices
else:
self.group_indices = [
g
for g in _codes_to_group_indices(self.codes.data, len(self.full_index))
if g
]
if encoded.unique_coord is None:
unique_values = self.full_index[np.unique(encoded.codes)]
self.unique_coord = IndexVariable(
self.group.name, unique_values, attrs=self.group.attrs
)
else:
self.unique_coord = encoded.unique_coord


@dataclass
Expand All @@ -477,34 +509,33 @@ def can_squeeze(self) -> bool:
is_dimension = self.group.dims == (self.group.name,)
return is_dimension and self.is_unique_and_monotonic

def factorize(self, group1d) -> T_FactorizeOut:
def factorize(self, group1d) -> EncodedGroups:
self.group = group1d

if self.can_squeeze:
return self._factorize_dummy()
else:
return self._factorize_unique()

def _factorize_unique(self) -> T_FactorizeOut:
def _factorize_unique(self) -> EncodedGroups:
# look through group to find the unique values
sort = not isinstance(self.group_as_index, pd.MultiIndex)
unique_values, group_indices, codes_ = unique_value_groups(
self.group_as_index, sort=sort
)
if len(group_indices) == 0:
unique_values, codes_ = unique_value_groups(self.group_as_index, sort=sort)
if (codes_ == -1).all():
raise ValueError(
"Failed to group data. Are you grouping by a variable that is all NaN?"
)
codes = self.group.copy(data=codes_)
group_indices = group_indices
unique_coord = IndexVariable(
self.group.name, unique_values, attrs=self.group.attrs
)
full_index = unique_coord

return codes, group_indices, unique_coord, full_index
return EncodedGroups(
codes=codes, full_index=full_index, unique_coord=unique_coord
)

def _factorize_dummy(self) -> T_FactorizeOut:
def _factorize_dummy(self) -> EncodedGroups:
size = self.group.size
# no need to factorize
# use slices to do views instead of fancy indexing
Expand All @@ -519,8 +550,12 @@ def _factorize_dummy(self) -> T_FactorizeOut:
full_index = IndexVariable(
self.group.name, unique_coord.values, self.group.attrs
)

return codes, group_indices, unique_coord, full_index
return EncodedGroups(
codes=codes,
group_indices=group_indices,
full_index=full_index,
unique_coord=unique_coord,
)


@dataclass
Expand All @@ -536,7 +571,7 @@ def __post_init__(self) -> None:
if duck_array_ops.isnull(self.bins).all():
raise ValueError("All bin edges are NaN.")

def factorize(self, group) -> T_FactorizeOut:
def factorize(self, group) -> EncodedGroups:
from xarray.core.dataarray import DataArray

data = group.data
Expand All @@ -554,20 +589,14 @@ def factorize(self, group) -> T_FactorizeOut:
full_index = binned.categories
uniques = np.sort(pd.unique(binned_codes))
unique_values = full_index[uniques[uniques != -1]]
group_indices = [
g for g in _codes_to_groups(binned_codes, len(full_index)) if g
]

if len(group_indices) == 0:
raise ValueError(
f"None of the data falls within bins with edges {self.bins!r}"
)

codes = DataArray(
binned_codes, getattr(group, "coords", None), name=new_dim_name
)
unique_coord = IndexVariable(new_dim_name, pd.Index(unique_values), group.attrs)
return codes, group_indices, unique_coord, full_index
return EncodedGroups(
codes=codes, full_index=full_index, unique_coord=unique_coord
)


@dataclass
Expand Down Expand Up @@ -672,7 +701,7 @@ def first_items(self) -> tuple[pd.Series, np.ndarray]:
_apply_loffset(self.loffset, first_items)
return first_items, codes

def factorize(self, group) -> T_FactorizeOut:
def factorize(self, group) -> EncodedGroups:
self._init_properties(group)
full_index, first_items, codes_ = self._get_index_and_items()
sbins = first_items.values.astype(np.int64)
Expand All @@ -684,7 +713,12 @@ def factorize(self, group) -> T_FactorizeOut:
unique_coord = IndexVariable(group.name, first_items.index, group.attrs)
codes = group.copy(data=codes_)

return codes, group_indices, unique_coord, full_index
return EncodedGroups(
codes=codes,
group_indices=group_indices,
full_index=full_index,
unique_coord=unique_coord,
)


def _validate_groupby_squeeze(squeeze: bool | None) -> None:
Expand Down

0 comments on commit 3dcfa31

Please sign in to comment.