Skip to content

Commit 4be569e

Browse files
REFACTOR-#1879: Move logic for groupby.agg into query compiler
Signed-off-by: Devin Petersohn <devin.petersohn@gmail.com>
1 parent c43a580 commit 4be569e

File tree

2 files changed

+53
-3
lines changed

2 files changed

+53
-3
lines changed

modin/backends/pandas/query_compiler.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2032,6 +2032,31 @@ def _callable_func(self, func, axis, *args, **kwargs):
20322032
lambda df, **kwargs: pandas.DataFrame(df.size()), lambda df, **kwargs: df.sum()
20332033
)
20342034

2035+
def groupby_dict_agg(self, by, func_dict, groupby_args, agg_args, drop=False):
2036+
"""Apply aggregation functions to a grouped dataframe per-column.
2037+
2038+
Parameters
2039+
----------
2040+
by : PandasQueryCompiler
2041+
The column to group by
2042+
func_dict : dict of str, callable/string
2043+
The dictionary mapping of column to function
2044+
groupby_args : dict
2045+
The dictionary of keyword arguments for the group by.
2046+
agg_args : dict
2047+
The dictionary of keyword arguments for the aggregation functions
2048+
drop : bool
2049+
Whether or not to drop the column from the data.
2050+
2051+
Returns
2052+
-------
2053+
PandasQueryCompiler
2054+
The result of the per-column aggregations on the grouped dataframe.
2055+
"""
2056+
return self.default_to_pandas(
2057+
lambda df: df.groupby(by=by, **groupby_args).agg(func_dict, **agg_args)
2058+
)
2059+
20352060
def groupby_agg(self, by, axis, agg_func, groupby_args, agg_args, drop=False):
20362061
# since we're going to modify `groupby_args` dict in a `groupby_agg_builder`,
20372062
# we want to copy it to not propagate these changes into source dict, in case

modin/pandas/groupby.py

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -345,19 +345,44 @@ def aggregate(self, func=None, *args, **kwargs):
345345
# This is not implemented in pandas,
346346
# so we throw a different message
347347
raise NotImplementedError("axis other than 0 is not supported")
348+
if isinstance(func, dict) or func is None:
349+
if func is None:
350+
func = {}
351+
else:
352+
if any(i not in self._df.columns for i in func.keys()):
353+
from pandas.core.base import SpecificationError
348354

349-
if func is None or is_list_like(func):
355+
raise SpecificationError("nested renamer is not supported")
356+
if isinstance(self._by, type(self._query_compiler)):
357+
by = list(self._by.columns)
358+
else:
359+
by = self._by
360+
# We convert to the string version of the function for simplicity.
361+
func_dict = {
362+
k: v if not callable(v) or v.__name__ not in dir(self) else v.__name__
363+
for k, v in func.items()
364+
}
365+
subset_cols = list(func_dict.keys()) + (
366+
list(self._by.columns)
367+
if isinstance(self._by, type(self._query_compiler))
368+
and all(c in self._df.columns for c in self._by.columns)
369+
else []
370+
)
371+
return type(self._df)(
372+
query_compiler=self._df[subset_cols]._query_compiler.groupby_dict_agg(
373+
by, func_dict, self._kwargs, kwargs, drop=self._drop
374+
)
375+
)
376+
if is_list_like(func):
350377
return self._default_to_pandas(
351378
lambda df, *args, **kwargs: df.aggregate(func, *args, **kwargs),
352379
*args,
353380
**kwargs,
354381
)
355-
356382
if isinstance(func, str):
357383
agg_func = getattr(self, func, None)
358384
if callable(agg_func):
359385
return agg_func(*args, **kwargs)
360-
361386
return self._apply_agg_function(
362387
lambda df, *args, **kwargs: df.aggregate(func, *args, **kwargs),
363388
drop=self._as_index,

0 commit comments

Comments
 (0)