Skip to content

Commit 9d062eb

Browse files
committed
Vectorize groupby math
1 parent 7e14f10 commit 9d062eb

File tree

1 file changed

+56
-26
lines changed

1 file changed

+56
-26
lines changed

xarray/core/groupby.py

Lines changed: 56 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -259,6 +259,9 @@ class GroupBy:
259259
"_stacked_dim",
260260
"_unique_coord",
261261
"_dims",
262+
"_original_obj",
263+
"_original_group",
264+
"_bins",
262265
)
263266

264267
def __init__(
@@ -320,6 +323,11 @@ def __init__(
320323
if getattr(group, "name", None) is None:
321324
group.name = "group"
322325

326+
# save object before any stacking
327+
self._original_obj = obj
328+
self._original_group = group
329+
self._bins = bins
330+
323331
group, obj, stacked_dim, inserted_dims = _ensure_1d(group, obj)
324332
(group_dim,) = group.dims
325333

@@ -473,34 +481,56 @@ def _infer_concat_args(self, applied_example):
473481

474482
def _binary_op(self, other, f, reflexive=False):
475483
g = f if not reflexive else lambda x, y: f(y, x)
476-
applied = self._yield_binary_applied(g, other)
477-
return self._combine(applied)
478-
479-
def _yield_binary_applied(self, func, other):
480-
dummy = None
481484

482-
for group_value, obj in self:
483-
try:
484-
other_sel = other.sel(**{self._group.name: group_value})
485-
except AttributeError:
486-
raise TypeError(
487-
"GroupBy objects only support binary ops "
488-
"when the other argument is a Dataset or "
489-
"DataArray"
485+
obj = self._original_obj
486+
group = self._original_group
487+
name = group.name
488+
dim = self._group_dim
489+
# import IPython; IPython.core.debugger.set_trace()
490+
try:
491+
if self._bins is not None:
492+
other = other.sel({f"{name}_bins": self._group})
493+
if isinstance(group, _DummyGroup):
494+
# When binning by unindexed coordinate we need to reindex obj.
495+
# _full_index is IntervalIndex, so idx will be -1 where
496+
# a value does not belong to any bin. Using IntervalIndex
497+
# accounts for any non-default cut_kwargs passed to the constructor
498+
idx = pd.cut(obj[dim], bins=self._full_index).codes
499+
obj = obj.isel({dim: np.arange(group.size)[idx != -1]})
500+
else:
501+
if isinstance(group, _DummyGroup):
502+
group = obj[dim]
503+
other = other.sel({name: group})
504+
except AttributeError:
505+
raise TypeError(
506+
"GroupBy objects only support binary ops "
507+
"when the other argument is a Dataset or "
508+
"DataArray"
509+
)
510+
except (KeyError, ValueError):
511+
if name not in other.dims:
512+
raise ValueError(
513+
"incompatible dimensions for a grouped "
514+
f"binary operation: the group variable {name!r} "
515+
"is not a dimension on the other argument"
490516
)
491-
except (KeyError, ValueError):
492-
if self._group.name not in other.dims:
493-
raise ValueError(
494-
"incompatible dimensions for a grouped "
495-
f"binary operation: the group variable {self._group.name!r} "
496-
"is not a dimension on the other argument"
497-
)
498-
if dummy is None:
499-
dummy = _dummy_copy(other)
500-
other_sel = dummy
501-
502-
result = func(obj, other_sel)
503-
yield result
517+
# some labels are absent i.e. other is not aligned
518+
# so we align by reindexing and then rename dimensions.
519+
# TODO: probably need to copy some coordinates over
520+
other = (
521+
other.reindex({name: group.data})
522+
.rename({name: dim})
523+
.assign_coords({dim: obj[dim]})
524+
)
525+
526+
result = g(obj, other)
527+
528+
# backcompat: concat during the "combine" step places
529+
# `dim` as the first dimension
530+
if dim in result.dims:
531+
# guards against self._group_dim being "stacked"
532+
result = result.transpose(dim, ...)
533+
return result
504534

505535
def _maybe_restore_empty_groups(self, combined):
506536
"""Our index contained empty groups (e.g., from a resampling). If we

0 commit comments

Comments
 (0)