Skip to content

Commit c3576ff

Browse files
aarondav authored and rxin committed
[SQL] Minor: Introduce SchemaRDD#aggregate() for simple aggregations
`rdd.aggregate(Sum('val))` is just shorthand for `rdd.groupBy()(Sum('val))` but seems to be more natural than doing a groupBy with no grouping expressions when you really just want an aggregation over all rows. Did not add a JavaSchemaRDD or Python API, as these seem to be lacking several other methods like groupBy() already -- leaving that cleanup for future patches. Author: Aaron Davidson <aaron@databricks.com> Closes #874 from aarondav/schemardd and squashes the following commits: e9e68ee [Aaron Davidson] Add comment db6afe2 [Aaron Davidson] Introduce SchemaRDD#aggregate() for simple aggregations
1 parent 0659529 commit c3576ff

File tree

2 files changed

+24
-2
lines changed

2 files changed

+24
-2
lines changed

sql/core/src/main/scala/org/apache/spark/sql/SchemaRDD.scala

Lines changed: 16 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -59,7 +59,7 @@ import java.util.{Map => JMap}
5959
* // Importing the SQL context gives access to all the SQL functions and implicit conversions.
6060
* import sqlContext._
6161
*
62-
* val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_\$i")))
62+
* val rdd = sc.parallelize((1 to 100).map(i => Record(i, s"val_$i")))
6363
* // Any RDD containing case classes can be registered as a table. The schema of the table is
6464
* // automatically inferred using scala reflection.
6565
* rdd.registerAsTable("records")
@@ -204,6 +204,20 @@ class SchemaRDD(
204204
new SchemaRDD(sqlContext, Aggregate(groupingExprs, aliasedExprs, logicalPlan))
205205
}
206206

207+
/**
208+
* Performs an aggregation over all Rows in this RDD.
209+
* This is equivalent to a groupBy with no grouping expressions.
210+
*
211+
* {{{
212+
* schemaRDD.aggregate(Sum('sales) as 'totalSales)
213+
* }}}
214+
*
215+
* @group Query
216+
*/
217+
def aggregate(aggregateExprs: Expression*): SchemaRDD = {
218+
groupBy()(aggregateExprs: _*)
219+
}
220+
207221
/**
208222
* Applies a qualifier to the attributes of this relation. Can be used to disambiguate attributes
209223
* with the same name, for example, when performing self-joins.
@@ -281,7 +295,7 @@ class SchemaRDD(
281295
* supports features such as filter pushdown.
282296
*/
283297
@Experimental
284-
override def count(): Long = groupBy()(Count(Literal(1))).collect().head.getLong(0)
298+
override def count(): Long = aggregate(Count(Literal(1))).collect().head.getLong(0)
285299

286300
/**
287301
* :: Experimental ::

sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -39,6 +39,14 @@ class DslQuerySuite extends QueryTest {
3939
testData2.groupBy('a)('a, Sum('b)),
4040
Seq((1,3),(2,3),(3,3))
4141
)
42+
checkAnswer(
43+
testData2.groupBy('a)('a, Sum('b) as 'totB).aggregate(Sum('totB)),
44+
9
45+
)
46+
checkAnswer(
47+
testData2.aggregate(Sum('b)),
48+
9
49+
)
4250
}
4351

4452
test("select *") {

0 commit comments

Comments
 (0)