Skip to content

Commit db6afe2

Browse files
committed
Introduce SchemaRDD#aggregate() for simple aggregations
rdd.aggregate(Sum('val)) is just shorthand for rdd.groupBy()(Sum('val)), but seems to be more natural than doing a groupBy with no grouping expressions when you really just want an aggregation over all rows. Did not add a JavaSchemaRDD or Python API, as these are already missing several other methods like groupBy() -- leaving that cleanup for future patches.
1 parent 55fddf9 commit db6afe2

File tree

2 files changed

+23
-2
lines changed

2 files changed

+23
-2
lines changed

sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ import java.util.{Map => JMap}
5959
* // Importing the SQL context gives access to all the SQL functions and implicit conversions.
6060
* import sqlContext._
6161
*
62-
* val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_\$i")))
62+
* val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i")))
6363
* // Any RDD containing case classes can be registered as a table. The schema of the table is
6464
* // automatically inferred using scala reflection.
6565
* rdd.registerAsTable("records")
@@ -204,6 +204,19 @@ class SchemaRDD(
204204
new SchemaRDD(sqlContext, Aggregate(groupingExprs, aliasedExprs, logicalPlan))
205205
}
206206

207+
/**
208+
* Performs an aggregation over all Rows in this RDD.
209+
*
210+
* {{{
211+
* schemaRDD.aggregate(Sum('sales) as 'totalSales)
212+
* }}}
213+
*
214+
* @group Query
215+
*/
216+
def aggregate(aggregateExprs: Expression*): SchemaRDD = {
217+
groupBy()(aggregateExprs: _*)
218+
}
219+
207220
/**
208221
* Applies a qualifier to the attributes of this relation. Can be used to disambiguate attributes
209222
* with the same name, for example, when performing self-joins.
@@ -281,7 +294,7 @@ class SchemaRDD(
281294
* supports features such as filter pushdown.
282295
*/
283296
@Experimental
284-
override def count(): Long = groupBy()(Count(Literal(1))).collect().head.getLong(0)
297+
override def count(): Long = aggregate(Count(Literal(1))).collect().head.getLong(0)
285298

286299
/**
287300
* :: Experimental ::

sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,14 @@ class DslQuerySuite extends QueryTest {
3939
testData2.groupBy('a)('a, Sum('b)),
4040
Seq((1,3),(2,3),(3,3))
4141
)
42+
checkAnswer(
43+
testData2.groupBy('a)('a, Sum('b) as 'totB).aggregate(Sum('totB)),
44+
9
45+
)
46+
checkAnswer(
47+
testData2.aggregate(Sum('b)),
48+
9
49+
)
4250
}
4351

4452
test("select *") {

0 commit comments

Comments
 (0)