Skip to content

Commit 7cc2cbc

Browse files
committed
#58 Add support for df.agg()
### What changes were proposed in this pull request? Add support for `df.Agg()` and `df.AggWithMap()`. ### Why are the changes needed? Compatibility ### Does this PR introduce _any_ user-facing change? New functions ### How was this patch tested? Added test Closes #100 from grundprinzip/df_agg. Authored-by: Martin Grund <martin.grund@databricks.com> Signed-off-by: Martin Grund <martin.grund@databricks.com>
1 parent 229c570 commit 7cc2cbc

File tree

3 files changed

+39
-2
lines changed

3 files changed

+39
-2
lines changed

internal/tests/integration/functions_test.go

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,25 @@ func TestIntegration_BuiltinFunctions(t *testing.T) {
4141
assert.Equal(t, 10, len(res))
4242
}
4343

44+
func TestAggregationFunctions_Agg(t *testing.T) {
45+
ctx, spark := connect()
46+
df, err := spark.Sql(ctx, "select id, 1, 2, 3 from range(100)")
47+
assert.NoError(t, err)
48+
49+
res, err := df.Agg(ctx, functions.Count(functions.Col("id")))
50+
assert.NoError(t, err)
51+
cnt, err := res.Count(ctx)
52+
assert.NoError(t, err)
53+
assert.Equal(t, int64(1), cnt)
54+
55+
res, err = df.AggWithMap(ctx, map[string]string{"id": "sum"})
56+
assert.NoError(t, err)
57+
rows, err := res.Collect(ctx)
58+
assert.NoError(t, err)
59+
assert.Len(t, rows, 1)
60+
assert.Equal(t, int64(4950), rows[0].At(0))
61+
}
62+
4463
func TestIntegration_ColumnGetItem(t *testing.T) {
4564
ctx := context.Background()
4665
spark, err := sql.NewSessionBuilder().Remote("sc://localhost").Build(ctx)

spark/sql/dataframe.go

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ type ResultCollector interface {
4141
type DataFrame interface {
4242
// PlanId returns the plan id of the data frame.
4343
PlanId() int64
44+
Agg(ctx context.Context, exprs ...column.Convertible) (DataFrame, error)
45+
AggWithMap(ctx context.Context, exprs map[string]string) (DataFrame, error)
4446
// Alias creates a new DataFrame with the specified subquery alias
4547
Alias(ctx context.Context, alias string) DataFrame
4648
// Cache persists the DataFrame with the default storage level.
@@ -1542,6 +1544,22 @@ func (df *dataFrameImpl) FillNaWithValues(ctx context.Context,
15421544
return makeDataframeWithFillNaRelation(df, valueLiterals, columns), nil
15431545
}
15441546

1547+
func (df *dataFrameImpl) Agg(ctx context.Context, cols ...column.Convertible) (DataFrame, error) {
1548+
return df.GroupBy().Agg(ctx, cols...)
1549+
}
1550+
1551+
func (df *dataFrameImpl) AggWithMap(ctx context.Context, exprs map[string]string) (DataFrame, error) {
1552+
funs := make([]column.Convertible, 0)
1553+
for k, v := range exprs {
1554+
// Convert the column name to a column expression.
1555+
col := column.OfDF(df, k)
1556+
// Convert the value string to an unresolved function name.
1557+
fun := column.NewUnresolvedFunctionWithColumns(v, col)
1558+
funs = append(funs, fun)
1559+
}
1560+
return df.Agg(ctx, funs...)
1561+
}
1562+
15451563
func (df *dataFrameImpl) DropNa(ctx context.Context, subset ...string) (DataFrame, error) {
15461564
rel := &proto.Relation{
15471565
Common: &proto.RelationCommon{

spark/sql/group.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ type GroupedData struct {
3737

3838
// Agg computes aggregates and returns the result as a DataFrame. The aggregate expressions
3939
// are passed as column.Convertible arguments.
40-
func (gd *GroupedData) Agg(ctx context.Context, exprs ...column.Column) (DataFrame, error) {
40+
func (gd *GroupedData) Agg(ctx context.Context, exprs ...column.Convertible) (DataFrame, error) {
4141
if len(exprs) == 0 {
4242
return nil, sparkerrors.WithString(sparkerrors.InvalidInputError, "exprs should not be empty")
4343
}
@@ -144,7 +144,7 @@ func (gd *GroupedData) numericAgg(ctx context.Context, name string, cols ...stri
144144
aggCols = numericCols
145145
}
146146

147-
finalColumns := make([]column.Column, len(aggCols))
147+
finalColumns := make([]column.Convertible, len(aggCols))
148148
for i, col := range aggCols {
149149
finalColumns[i] = column.NewColumn(column.NewUnresolvedFunctionWithColumns(name, functions.Col(col)))
150150
}

0 commit comments

Comments
 (0)