Skip to content

Commit b8ff688

Browse files
arayyhuai
authored andcommitted
[SPARK-8992][SQL] Add pivot to dataframe api
This adds a pivot method to the dataframe api. Following the lead of cube and rollup this adds a Pivot operator that is translated into an Aggregate by the analyzer. Currently the syntax is like: ~~courseSales.pivot(Seq($"year"), $"course", Seq("dotNET", "Java"), sum($"earnings"))~~ ~~Would we be interested in the following syntax also/alternatively? and~~ courseSales.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings")) //or courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")) Later we can add it to `SQLParser`, but as Hive doesn't support it we cant add it there, right? ~~Also what would be the suggested Java friendly method signature for this?~~ Author: Andrew Ray <ray.andrew@gmail.com> Closes #7841 from aray/sql-pivot.
1 parent 1a21be1 commit b8ff688

File tree

6 files changed

+255
-10
lines changed

6 files changed

+255
-10
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,7 @@ class Analyzer(
7272
ResolveRelations ::
7373
ResolveReferences ::
7474
ResolveGroupingAnalytics ::
75+
ResolvePivot ::
7576
ResolveSortReferences ::
7677
ResolveGenerate ::
7778
ResolveFunctions ::
@@ -166,6 +167,10 @@ class Analyzer(
166167
case g: GroupingAnalytics if g.child.resolved && hasUnresolvedAlias(g.aggregations) =>
167168
g.withNewAggs(assignAliases(g.aggregations))
168169

170+
case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child)
171+
if child.resolved && hasUnresolvedAlias(groupByExprs) =>
172+
Pivot(assignAliases(groupByExprs), pivotColumn, pivotValues, aggregates, child)
173+
169174
case Project(projectList, child) if child.resolved && hasUnresolvedAlias(projectList) =>
170175
Project(assignAliases(projectList), child)
171176
}
@@ -248,6 +253,43 @@ class Analyzer(
248253
}
249254
}
250255

256+
object ResolvePivot extends Rule[LogicalPlan] {
257+
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
258+
case p: Pivot if !p.childrenResolved => p
259+
case Pivot(groupByExprs, pivotColumn, pivotValues, aggregates, child) =>
260+
val singleAgg = aggregates.size == 1
261+
val pivotAggregates: Seq[NamedExpression] = pivotValues.flatMap { value =>
262+
def ifExpr(expr: Expression) = {
263+
If(EqualTo(pivotColumn, value), expr, Literal(null))
264+
}
265+
aggregates.map { aggregate =>
266+
val filteredAggregate = aggregate.transformDown {
267+
// Assumption is the aggregate function ignores nulls. This is true for all current
268+
// AggregateFunction's with the exception of First and Last in their default mode
269+
// (which we handle) and possibly some Hive UDAF's.
270+
case First(expr, _) =>
271+
First(ifExpr(expr), Literal(true))
272+
case Last(expr, _) =>
273+
Last(ifExpr(expr), Literal(true))
274+
case a: AggregateFunction =>
275+
a.withNewChildren(a.children.map(ifExpr))
276+
}
277+
if (filteredAggregate.fastEquals(aggregate)) {
278+
throw new AnalysisException(
279+
s"Aggregate expression required for pivot, found '$aggregate'")
280+
}
281+
val name = if (singleAgg) value.toString else value + "_" + aggregate.prettyString
282+
Alias(filteredAggregate, name)()
283+
}
284+
}
285+
val newGroupByExprs = groupByExprs.map {
286+
case UnresolvedAlias(e) => e
287+
case e => e
288+
}
289+
Aggregate(newGroupByExprs, groupByExprs ++ pivotAggregates, child)
290+
}
291+
}
292+
251293
/**
252294
* Replaces [[UnresolvedRelation]]s with concrete relations from the catalog.
253295
*/

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,6 +386,20 @@ case class Rollup(
386386
this.copy(aggregations = aggs)
387387
}
388388

389+
case class Pivot(
390+
groupByExprs: Seq[NamedExpression],
391+
pivotColumn: Expression,
392+
pivotValues: Seq[Literal],
393+
aggregates: Seq[Expression],
394+
child: LogicalPlan) extends UnaryNode {
395+
override def output: Seq[Attribute] = groupByExprs.map(_.toAttribute) ++ aggregates match {
396+
case agg :: Nil => pivotValues.map(value => AttributeReference(value.toString, agg.dataType)())
397+
case _ => pivotValues.flatMap{ value =>
398+
aggregates.map(agg => AttributeReference(value + "_" + agg.prettyString, agg.dataType)())
399+
}
400+
}
401+
}
402+
389403
case class Limit(limitExpr: Expression, child: LogicalPlan) extends UnaryNode {
390404
override def output: Seq[Attribute] = child.output
391405

sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala

Lines changed: 93 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ import org.apache.spark.annotation.Experimental
2424
import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedAlias, UnresolvedAttribute, Star}
2525
import org.apache.spark.sql.catalyst.expressions._
2626
import org.apache.spark.sql.catalyst.expressions.aggregate._
27-
import org.apache.spark.sql.catalyst.plans.logical.{Rollup, Cube, Aggregate}
28-
import org.apache.spark.sql.types.NumericType
27+
import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Rollup, Cube, Aggregate}
28+
import org.apache.spark.sql.types.{StringType, NumericType}
2929

3030

3131
/**
@@ -50,14 +50,8 @@ class GroupedData protected[sql](
5050
aggExprs
5151
}
5252

53-
val aliasedAgg = aggregates.map {
54-
// Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
55-
// will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
56-
// make it a NamedExpression.
57-
case u: UnresolvedAttribute => UnresolvedAlias(u)
58-
case expr: NamedExpression => expr
59-
case expr: Expression => Alias(expr, expr.prettyString)()
60-
}
53+
val aliasedAgg = aggregates.map(alias)
54+
6155
groupType match {
6256
case GroupedData.GroupByType =>
6357
DataFrame(
@@ -68,9 +62,22 @@ class GroupedData protected[sql](
6862
case GroupedData.CubeType =>
6963
DataFrame(
7064
df.sqlContext, Cube(groupingExprs, df.logicalPlan, aliasedAgg))
65+
case GroupedData.PivotType(pivotCol, values) =>
66+
val aliasedGrps = groupingExprs.map(alias)
67+
DataFrame(
68+
df.sqlContext, Pivot(aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan))
7169
}
7270
}
7371

72+
// Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
73+
// will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
74+
// make it a NamedExpression.
75+
private[this] def alias(expr: Expression): NamedExpression = expr match {
76+
case u: UnresolvedAttribute => UnresolvedAlias(u)
77+
case expr: NamedExpression => expr
78+
case expr: Expression => Alias(expr, expr.prettyString)()
79+
}
80+
7481
private[this] def aggregateNumericColumns(colNames: String*)(f: Expression => AggregateFunction)
7582
: DataFrame = {
7683

@@ -273,6 +280,77 @@ class GroupedData protected[sql](
273280
def sum(colNames: String*): DataFrame = {
274281
aggregateNumericColumns(colNames : _*)(Sum)
275282
}
283+
284+
/**
285+
* (Scala-specific) Pivots a column of the current [[DataFrame]] and preform the specified
286+
* aggregation.
287+
* {{{
288+
* // Compute the sum of earnings for each year by course with each course as a separate column
289+
* df.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings"))
290+
* // Or without specifying column values
291+
* df.groupBy($"year").pivot($"course").agg(sum($"earnings"))
292+
* }}}
293+
* @param pivotColumn Column to pivot
294+
* @param values Optional list of values of pivotColumn that will be translated to columns in the
295+
* output data frame. If values are not provided the method with do an immediate
296+
* call to .distinct() on the pivot column.
297+
* @since 1.6.0
298+
*/
299+
@scala.annotation.varargs
300+
def pivot(pivotColumn: Column, values: Column*): GroupedData = groupType match {
301+
case _: GroupedData.PivotType =>
302+
throw new UnsupportedOperationException("repeated pivots are not supported")
303+
case GroupedData.GroupByType =>
304+
val pivotValues = if (values.nonEmpty) {
305+
values.map {
306+
case Column(literal: Literal) => literal
307+
case other =>
308+
throw new UnsupportedOperationException(
309+
s"The values of a pivot must be literals, found $other")
310+
}
311+
} else {
312+
// This is to prevent unintended OOM errors when the number of distinct values is large
313+
val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES)
314+
// Get the distinct values of the column and sort them so its consistent
315+
val values = df.select(pivotColumn)
316+
.distinct()
317+
.sort(pivotColumn)
318+
.map(_.get(0))
319+
.take(maxValues + 1)
320+
.map(Literal(_)).toSeq
321+
if (values.length > maxValues) {
322+
throw new RuntimeException(
323+
s"The pivot column $pivotColumn has more than $maxValues distinct values, " +
324+
"this could indicate an error. " +
325+
"If this was intended, set \"" + SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key + "\" " +
326+
s"to at least the number of distinct values of the pivot column.")
327+
}
328+
values
329+
}
330+
new GroupedData(df, groupingExprs, GroupedData.PivotType(pivotColumn.expr, pivotValues))
331+
case _ =>
332+
throw new UnsupportedOperationException("pivot is only supported after a groupBy")
333+
}
334+
335+
/**
336+
* Pivots a column of the current [[DataFrame]] and preform the specified aggregation.
337+
* {{{
338+
* // Compute the sum of earnings for each year by course with each course as a separate column
339+
* df.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings")
340+
* // Or without specifying column values
341+
* df.groupBy("year").pivot("course").sum("earnings")
342+
* }}}
343+
* @param pivotColumn Column to pivot
344+
* @param values Optional list of values of pivotColumn that will be translated to columns in the
345+
* output data frame. If values are not provided the method with do an immediate
346+
* call to .distinct() on the pivot column.
347+
* @since 1.6.0
348+
*/
349+
@scala.annotation.varargs
350+
def pivot(pivotColumn: String, values: Any*): GroupedData = {
351+
val resolvedPivotColumn = Column(df.resolve(pivotColumn))
352+
pivot(resolvedPivotColumn, values.map(functions.lit): _*)
353+
}
276354
}
277355

278356

@@ -307,4 +385,9 @@ private[sql] object GroupedData {
307385
* To indicate it's the ROLLUP
308386
*/
309387
private[sql] object RollupType extends GroupType
388+
389+
/**
390+
* To indicate it's the PIVOT
391+
*/
392+
private[sql] case class PivotType(pivotCol: Expression, values: Seq[Literal]) extends GroupType
310393
}

sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -437,6 +437,13 @@ private[spark] object SQLConf {
437437
defaultValue = Some(true),
438438
isPublic = false)
439439

440+
val DATAFRAME_PIVOT_MAX_VALUES = intConf(
441+
"spark.sql.pivotMaxValues",
442+
defaultValue = Some(10000),
443+
doc = "When doing a pivot without specifying values for the pivot column this is the maximum " +
444+
"number of (distinct) values that will be collected without error."
445+
)
446+
440447
val RUN_SQL_ON_FILES = booleanConf("spark.sql.runSQLOnFiles",
441448
defaultValue = Some(true),
442449
isPublic = false,
Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql
19+
20+
import org.apache.spark.sql.functions._
21+
import org.apache.spark.sql.test.SharedSQLContext
22+
23+
class DataFramePivotSuite extends QueryTest with SharedSQLContext{
24+
import testImplicits._
25+
26+
test("pivot courses with literals") {
27+
checkAnswer(
28+
courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java"))
29+
.agg(sum($"earnings")),
30+
Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
31+
)
32+
}
33+
34+
test("pivot year with literals") {
35+
checkAnswer(
36+
courseSales.groupBy($"course").pivot($"year", lit(2012), lit(2013)).agg(sum($"earnings")),
37+
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
38+
)
39+
}
40+
41+
test("pivot courses with literals and multiple aggregations") {
42+
checkAnswer(
43+
courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java"))
44+
.agg(sum($"earnings"), avg($"earnings")),
45+
Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) ::
46+
Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil
47+
)
48+
}
49+
50+
test("pivot year with string values (cast)") {
51+
checkAnswer(
52+
courseSales.groupBy("course").pivot("year", "2012", "2013").sum("earnings"),
53+
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
54+
)
55+
}
56+
57+
test("pivot year with int values") {
58+
checkAnswer(
59+
courseSales.groupBy("course").pivot("year", 2012, 2013).sum("earnings"),
60+
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
61+
)
62+
}
63+
64+
test("pivot courses with no values") {
65+
// Note Java comes before dotNet in sorted order
66+
checkAnswer(
67+
courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")),
68+
Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil
69+
)
70+
}
71+
72+
test("pivot year with no values") {
73+
checkAnswer(
74+
courseSales.groupBy($"course").pivot($"year").agg(sum($"earnings")),
75+
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
76+
)
77+
}
78+
79+
test("pivot max values inforced") {
80+
sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, 1)
81+
intercept[RuntimeException](
82+
courseSales.groupBy($"year").pivot($"course")
83+
)
84+
sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES,
85+
SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get)
86+
}
87+
}

sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,17 @@ private[sql] trait SQLTestData { self =>
242242
df
243243
}
244244

245+
protected lazy val courseSales: DataFrame = {
246+
val df = sqlContext.sparkContext.parallelize(
247+
CourseSales("dotNET", 2012, 10000) ::
248+
CourseSales("Java", 2012, 20000) ::
249+
CourseSales("dotNET", 2012, 5000) ::
250+
CourseSales("dotNET", 2013, 48000) ::
251+
CourseSales("Java", 2013, 30000) :: Nil).toDF()
252+
df.registerTempTable("courseSales")
253+
df
254+
}
255+
245256
/**
246257
* Initialize all test data such that all temp tables are properly registered.
247258
*/
@@ -295,4 +306,5 @@ private[sql] object SQLTestData {
295306
case class Person(id: Int, name: String, age: Int)
296307
case class Salary(personId: Int, salary: Double)
297308
case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean)
309+
case class CourseSales(course: String, year: Int, earnings: Double)
298310
}

0 commit comments

Comments
 (0)