Skip to content

Commit a0b66c4

Browse files
REFACTOR-#1879: Move logic for groupby.agg into query compiler
Signed-off-by: Devin Petersohn <devin.petersohn@gmail.com>
1 parent 790b196 commit a0b66c4

File tree

2 files changed

+53
-3
lines changed

2 files changed

+53
-3
lines changed

modin/backends/pandas/query_compiler.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2192,6 +2192,31 @@ def _callable_func(self, func, axis, *args, **kwargs):
21922192
lambda df, **kwargs: pandas.DataFrame(df.size()), lambda df, **kwargs: df.sum()
21932193
)
21942194

2195+
def groupby_dict_agg(self, by, func_dict, groupby_args, agg_args, drop=False):
2196+
"""Apply aggregation functions to a grouped dataframe per-column.
2197+
2198+
Parameters
2199+
----------
2200+
by : PandasQueryCompiler
2201+
The column to group by
2202+
func_dict : dict of str, callable/string
2203+
The dictionary mapping of column to function
2204+
groupby_args : dict
2205+
The dictionary of keyword arguments for the group by.
2206+
agg_args : dict
2207+
The dictionary of keyword arguments for the aggregation functions
2208+
drop : bool
2209+
Whether or not to drop the column from the data.
2210+
2211+
Returns
2212+
-------
2213+
PandasQueryCompiler
2214+
The result of the per-column aggregations on the grouped dataframe.
2215+
"""
2216+
return self.default_to_pandas(
2217+
lambda df: df.groupby(by=by, **groupby_args).agg(func_dict, **agg_args)
2218+
)
2219+
21952220
def groupby_agg(self, by, axis, agg_func, groupby_args, agg_args, drop=False):
21962221
# since we're going to modify `groupby_args` dict in a `groupby_agg_builder`,
21972222
# we want to copy it to not propagate these changes into source dict, in case

modin/pandas/groupby.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -363,19 +363,44 @@ def aggregate(self, func=None, *args, **kwargs):
363363
# This is not implemented in pandas,
364364
# so we throw a different message
365365
raise NotImplementedError("axis other than 0 is not supported")
366+
if isinstance(func, dict) or func is None:
367+
if func is None:
368+
func = {}
369+
else:
370+
if any(i not in self._df.columns for i in func.keys()):
371+
from pandas.core.base import SpecificationError
366372

367-
if func is None or is_list_like(func):
373+
raise SpecificationError("nested renamer is not supported")
374+
if isinstance(self._by, type(self._query_compiler)):
375+
by = list(self._by.columns)
376+
else:
377+
by = self._by
378+
# We convert to the string version of the function for simplicity.
379+
func_dict = {
380+
k: v if not callable(v) or v.__name__ not in dir(self) else v.__name__
381+
for k, v in func.items()
382+
}
383+
subset_cols = list(func_dict.keys()) + (
384+
list(self._by.columns)
385+
if isinstance(self._by, type(self._query_compiler))
386+
and all(c in self._df.columns for c in self._by.columns)
387+
else []
388+
)
389+
return type(self._df)(
390+
query_compiler=self._df[subset_cols]._query_compiler.groupby_dict_agg(
391+
by, func_dict, self._kwargs, kwargs, drop=self._drop
392+
)
393+
)
394+
if is_list_like(func):
368395
return self._default_to_pandas(
369396
lambda df, *args, **kwargs: df.aggregate(func, *args, **kwargs),
370397
*args,
371398
**kwargs,
372399
)
373-
374400
if isinstance(func, str):
375401
agg_func = getattr(self, func, None)
376402
if callable(agg_func):
377403
return agg_func(*args, **kwargs)
378-
379404
return self._apply_agg_function(
380405
lambda df, *args, **kwargs: df.aggregate(func, *args, **kwargs),
381406
drop=self._as_index,

0 commit comments

Comments
 (0)