Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEA] Support named aggregations in df.groupby().agg() #16528

Merged
merged 16 commits into from
Aug 15, 2024
Merged
Previous commit
Next commit
Address review
  • Loading branch information
Matt711 committed Aug 14, 2024
commit 107bfc749d1d498482eff19b4bc707a683d298d3
7 changes: 0 additions & 7 deletions python/cudf/cudf/core/column_accessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@

import itertools
import sys
import warnings
from collections import abc
from functools import cached_property, reduce
from typing import TYPE_CHECKING, Any, Callable, Mapping
Expand Down Expand Up @@ -654,12 +653,6 @@ def rename_column(x):
return x

if level is None:
warnings.warn(
"Renaming columns with MultiIndex assuming level=0. "
"Specify the level keyword argument to rename using "
"a different level eg. df.rename(..., level=1)",
UserWarning,
)
level = 0
new_col_names = (rename_column(k) for k in self.keys())

Expand Down
2 changes: 1 addition & 1 deletion python/cudf/cudf/core/groupby/groupby.py
Original file line number Diff line number Diff line change
Expand Up @@ -1298,7 +1298,7 @@ def _normalize_aggs(
aggs_per_column = (aggs,) * len(columns)
elif not aggs and kwargs:
column_names, aggs_per_column = kwargs.keys(), kwargs.values()
columns = tuple(self.obj._data[x[1][0]] for x in kwargs.items())
columns = tuple(self.obj._data[x[0]] for x in kwargs.values())
aggs_per_column = tuple(x[1] for x in kwargs.values())
else:
raise TypeError("Must provide at least one aggregation function.")
Expand Down
10 changes: 5 additions & 5 deletions python/cudf/cudf/tests/groupby/test_agg.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# Copyright (c) 2023-2024, NVIDIA CORPORATION.
import numpy as np
import pandas as pd
import pytest

import cudf
from cudf.testing import assert_eq


@pytest.mark.parametrize(
Expand Down Expand Up @@ -38,21 +38,21 @@ def test_dataframe_agg(attr, func):
agg = getattr(df.groupby("a"), attr)(func)
pd_agg = getattr(pdf.groupby(["a"]), attr)(func)

pd.testing.assert_frame_equal(agg.to_pandas(), pd_agg)
assert_eq(agg, pd_agg)

agg = getattr(df.groupby("a"), attr)({"b": func})
pd_agg = getattr(pdf.groupby(["a"]), attr)({"b": func})

pd.testing.assert_frame_equal(agg.to_pandas(), pd_agg)
assert_eq(agg, pd_agg)

agg = getattr(df.groupby("a"), attr)([func])
pd_agg = getattr(pdf.groupby(["a"]), attr)([func])

pd.testing.assert_frame_equal(agg.to_pandas(), pd_agg)
assert_eq(agg, pd_agg)

agg = getattr(df.groupby("a"), attr)(foo=("b", func), bar=("a", func))
pd_agg = getattr(pdf.groupby(["a"]), attr)(
foo=("b", func), bar=("a", func)
)

pd.testing.assert_frame_equal(agg.to_pandas(), pd_agg)
assert_eq(agg, pd_agg)
Loading