
Commit 3f40af5
[SPARK-11946][SQL] Audit pivot API for 1.6.
Currently pivot's signature looks like:

```scala
scala.annotation.varargs
def pivot(pivotColumn: Column, values: Column*): GroupedData

scala.annotation.varargs
def pivot(pivotColumn: String, values: Any*): GroupedData
```

I think we can remove the one that takes `Column` types, since callers should always be passing in literals. It'd also be clearer if the values were not varargs, but rather a `Seq` or `java.util.List`.

I also made similar changes for Python.

Author: Reynold Xin <rxin@databricks.com>

Closes #9929 from rxin/SPARK-11946.

(cherry picked from commit f315272)
Signed-off-by: Reynold Xin <rxin@databricks.com>
1 parent 0419fd3 commit 3f40af5
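
For quick reference, the pivot API after this commit consists of three non-varargs overloads on `GroupedData` (signatures taken from the GroupedData.scala diff below; the comments are editorial):

```scala
// Computes the distinct pivot values itself (capped by
// SQLConf.DATAFRAME_PIVOT_MAX_VALUES), then delegates to the Seq overload.
def pivot(pivotColumn: String): GroupedData

// Scala callers pass the pivot values as a Seq.
def pivot(pivotColumn: String, values: Seq[Any]): GroupedData

// Java-friendly variant; delegates to the Seq overload via asScala.
def pivot(pivotColumn: String, values: java.util.List[Any]): GroupedData
```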

7 files changed (+125, -81 lines)

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 0 additions & 1 deletion
```diff
@@ -1574,7 +1574,6 @@ class DAGScheduler(
   }
 
   def stop() {
-    logInfo("Stopping DAGScheduler")
     messageScheduler.shutdownNow()
     eventProcessLoop.stop()
     taskScheduler.stop()
```

python/pyspark/sql/group.py

Lines changed: 8 additions & 4 deletions
```diff
@@ -168,20 +168,24 @@ def sum(self, *cols):
         """
 
     @since(1.6)
-    def pivot(self, pivot_col, *values):
+    def pivot(self, pivot_col, values=None):
         """Pivots a column of the current DataFrame and preform the specified aggregation.
 
         :param pivot_col: Column to pivot
         :param values: Optional list of values of pivotColumn that will be translated to columns in
             the output data frame. If values are not provided the method with do an immediate call
             to .distinct() on the pivot column.
-        >>> df4.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings").collect()
+
+        >>> df4.groupBy("year").pivot("course", ["dotNET", "Java"]).sum("earnings").collect()
         [Row(year=2012, dotNET=15000, Java=20000), Row(year=2013, dotNET=48000, Java=30000)]
+
         >>> df4.groupBy("year").pivot("course").sum("earnings").collect()
         [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)]
         """
-        jgd = self._jdf.pivot(_to_java_column(pivot_col),
-                              _to_seq(self.sql_ctx._sc, values, _create_column_from_literal))
+        if values is None:
+            jgd = self._jdf.pivot(pivot_col)
+        else:
+            jgd = self._jdf.pivot(pivot_col, values)
         return GroupedData(jgd, self.sql_ctx)
```
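
Note the call-site change this implies for Python users: values are now passed as one optional list (`pivot("course", ["dotNET", "Java"])`) instead of varargs (`pivot("course", "dotNET", "Java")`), matching the updated doctest above; omitting `values` still triggers the distinct-values path on the JVM side.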

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/literals.scala

Lines changed: 1 addition & 0 deletions
```diff
@@ -44,6 +44,7 @@ object Literal {
     case a: Array[Byte] => Literal(a, BinaryType)
     case i: CalendarInterval => Literal(i, CalendarIntervalType)
     case null => Literal(null, NullType)
+    case v: Literal => v
     case _ =>
       throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v)
   }
```
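
This one-line addition makes `Literal.apply` idempotent, which the new `GroupedData.pivot(pivotColumn, values)` relies on when it maps user-supplied values through `Literal.apply` unconditionally (see the GroupedData.scala diff below). A minimal standalone sketch of the pattern, using simplified stand-in types rather than the real `Literal` hierarchy:

```scala
sealed trait Lit
case class IntLit(value: Int) extends Lit
case class StrLit(value: String) extends Lit

object Lit {
  // Convert a runtime value into a literal. Values that are already
  // literals pass through, so Lit(Lit(x)) == Lit(x) and callers can
  // wrap unconditionally without checking first.
  def apply(v: Any): Lit = v match {
    case i: Int    => IntLit(i)
    case s: String => StrLit(s)
    case l: Lit    => l // the pass-through arm added by this commit, in miniature
    case other     => throw new RuntimeException("Unsupported literal type " + other)
  }
}

object LitDemo extends App {
  assert(Lit(42) == Lit(Lit(42))) // idempotent: double wrapping is a no-op
}
```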

sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala

Lines changed: 88 additions & 66 deletions
```diff
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedAli
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Rollup, Cube, Aggregate}
-import org.apache.spark.sql.types.{StringType, NumericType}
+import org.apache.spark.sql.types.NumericType
 
 
 /**
@@ -282,74 +282,96 @@ class GroupedData protected[sql](
   }
 
   /**
-   * (Scala-specific) Pivots a column of the current [[DataFrame]] and preform the specified
-   * aggregation.
-   * {{{
-   *   // Compute the sum of earnings for each year by course with each course as a separate column
-   *   df.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings"))
-   *   // Or without specifying column values
-   *   df.groupBy($"year").pivot($"course").agg(sum($"earnings"))
-   * }}}
-   * @param pivotColumn Column to pivot
-   * @param values Optional list of values of pivotColumn that will be translated to columns in the
-   *               output data frame. If values are not provided the method with do an immediate
-   *               call to .distinct() on the pivot column.
-   * @since 1.6.0
-   */
-  @scala.annotation.varargs
-  def pivot(pivotColumn: Column, values: Column*): GroupedData = groupType match {
-    case _: GroupedData.PivotType =>
-      throw new UnsupportedOperationException("repeated pivots are not supported")
-    case GroupedData.GroupByType =>
-      val pivotValues = if (values.nonEmpty) {
-        values.map {
-          case Column(literal: Literal) => literal
-          case other =>
-            throw new UnsupportedOperationException(
-              s"The values of a pivot must be literals, found $other")
-        }
-      } else {
-        // This is to prevent unintended OOM errors when the number of distinct values is large
-        val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES)
-        // Get the distinct values of the column and sort them so its consistent
-        val values = df.select(pivotColumn)
-          .distinct()
-          .sort(pivotColumn)
-          .map(_.get(0))
-          .take(maxValues + 1)
-          .map(Literal(_)).toSeq
-        if (values.length > maxValues) {
-          throw new RuntimeException(
-            s"The pivot column $pivotColumn has more than $maxValues distinct values, " +
-              "this could indicate an error. " +
-              "If this was intended, set \"" + SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key + "\" " +
-              s"to at least the number of distinct values of the pivot column.")
-        }
-        values
-      }
-      new GroupedData(df, groupingExprs, GroupedData.PivotType(pivotColumn.expr, pivotValues))
-    case _ =>
-      throw new UnsupportedOperationException("pivot is only supported after a groupBy")
+   * Pivots a column of the current [[DataFrame]] and preform the specified aggregation.
+   * There are two versions of pivot function: one that requires the caller to specify the list
+   * of distinct values to pivot on, and one that does not. The latter is more concise but less
+   * efficient, because Spark needs to first compute the list of distinct values internally.
+   *
+   * {{{
+   *   // Compute the sum of earnings for each year by course with each course as a separate column
+   *   df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
+   *
+   *   // Or without specifying column values (less efficient)
+   *   df.groupBy("year").pivot("course").sum("earnings")
+   * }}}
+   *
+   * @param pivotColumn Name of the column to pivot.
+   * @since 1.6.0
+   */
+  def pivot(pivotColumn: String): GroupedData = {
+    // This is to prevent unintended OOM errors when the number of distinct values is large
+    val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES)
+    // Get the distinct values of the column and sort them so its consistent
+    val values = df.select(pivotColumn)
+      .distinct()
+      .sort(pivotColumn)
+      .map(_.get(0))
+      .take(maxValues + 1)
+      .toSeq
+
+    if (values.length > maxValues) {
+      throw new AnalysisException(
+        s"The pivot column $pivotColumn has more than $maxValues distinct values, " +
+          "this could indicate an error. " +
+          s"If this was intended, set ${SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key} " +
+          "to at least the number of distinct values of the pivot column.")
+    }
+
+    pivot(pivotColumn, values)
   }
 
   /**
-   * Pivots a column of the current [[DataFrame]] and preform the specified aggregation.
-   * {{{
-   *   // Compute the sum of earnings for each year by course with each course as a separate column
-   *   df.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings")
-   *   // Or without specifying column values
-   *   df.groupBy("year").pivot("course").sum("earnings")
-   * }}}
-   * @param pivotColumn Column to pivot
-   * @param values Optional list of values of pivotColumn that will be translated to columns in the
-   *               output data frame. If values are not provided the method with do an immediate
-   *               call to .distinct() on the pivot column.
-   * @since 1.6.0
-   */
-  @scala.annotation.varargs
-  def pivot(pivotColumn: String, values: Any*): GroupedData = {
-    val resolvedPivotColumn = Column(df.resolve(pivotColumn))
-    pivot(resolvedPivotColumn, values.map(functions.lit): _*)
+   * Pivots a column of the current [[DataFrame]] and preform the specified aggregation.
+   * There are two versions of pivot function: one that requires the caller to specify the list
+   * of distinct values to pivot on, and one that does not. The latter is more concise but less
+   * efficient, because Spark needs to first compute the list of distinct values internally.
+   *
+   * {{{
+   *   // Compute the sum of earnings for each year by course with each course as a separate column
+   *   df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
+   *
+   *   // Or without specifying column values (less efficient)
+   *   df.groupBy("year").pivot("course").sum("earnings")
+   * }}}
+   *
+   * @param pivotColumn Name of the column to pivot.
+   * @param values List of values that will be translated to columns in the output DataFrame.
+   * @since 1.6.0
+   */
+  def pivot(pivotColumn: String, values: Seq[Any]): GroupedData = {
+    groupType match {
+      case GroupedData.GroupByType =>
+        new GroupedData(
+          df,
+          groupingExprs,
+          GroupedData.PivotType(df.resolve(pivotColumn), values.map(Literal.apply)))
+      case _: GroupedData.PivotType =>
+        throw new UnsupportedOperationException("repeated pivots are not supported")
+      case _ =>
+        throw new UnsupportedOperationException("pivot is only supported after a groupBy")
+    }
+  }
+
+  /**
+   * Pivots a column of the current [[DataFrame]] and preform the specified aggregation.
+   * There are two versions of pivot function: one that requires the caller to specify the list
+   * of distinct values to pivot on, and one that does not. The latter is more concise but less
+   * efficient, because Spark needs to first compute the list of distinct values internally.
+   *
+   * {{{
+   *   // Compute the sum of earnings for each year by course with each course as a separate column
+   *   df.groupBy("year").pivot("course", Arrays.<Object>asList("dotNET", "Java")).sum("earnings");
+   *
+   *   // Or without specifying column values (less efficient)
+   *   df.groupBy("year").pivot("course").sum("earnings");
+   * }}}
+   *
+   * @param pivotColumn Name of the column to pivot.
+   * @param values List of values that will be translated to columns in the output DataFrame.
+   * @since 1.6.0
+   */
+  def pivot(pivotColumn: String, values: java.util.List[Any]): GroupedData = {
+    pivot(pivotColumn, values.asScala)
   }
 }
```
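
Taken together, the audited API reads like this at a call site (a sketch against the post-commit 1.6 API; `courseSales` is assumed to be a DataFrame with `year`, `course`, and `earnings` columns, matching the test data registered in SQLTestData.scala below):

```scala
// Explicit values: no extra Spark job to discover the distinct courses.
courseSales.groupBy("year")
  .pivot("course", Seq("dotNET", "Java"))
  .sum("earnings")

// No values: Spark first runs .distinct().sort() on the pivot column and
// throws AnalysisException if the count exceeds
// SQLConf.DATAFRAME_PIVOT_MAX_VALUES, to prevent an unintended OOM.
courseSales.groupBy("year")
  .pivot("course")
  .sum("earnings")
```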

sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java

Lines changed: 16 additions & 0 deletions
```diff
@@ -282,4 +282,20 @@ public void testSampleBy() {
     Assert.assertEquals(1, actual[1].getLong(0));
     Assert.assertTrue(2 <= actual[1].getLong(1) && actual[1].getLong(1) <= 13);
   }
+
+  @Test
+  public void pivot() {
+    DataFrame df = context.table("courseSales");
+    Row[] actual = df.groupBy("year")
+      .pivot("course", Arrays.<Object>asList("dotNET", "Java"))
+      .agg(sum("earnings")).orderBy("year").collect();
+
+    Assert.assertEquals(2012, actual[0].getInt(0));
+    Assert.assertEquals(15000.0, actual[0].getDouble(1), 0.01);
+    Assert.assertEquals(20000.0, actual[0].getDouble(2), 0.01);
+
+    Assert.assertEquals(2013, actual[1].getInt(0));
+    Assert.assertEquals(48000.0, actual[1].getDouble(1), 0.01);
+    Assert.assertEquals(30000.0, actual[1].getDouble(2), 0.01);
+  }
 }
```
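
One Java subtlety in this test: the overload's parameter type `java.util.List[Any]` surfaces in Java as `List<Object>`, and since Java generics are invariant, a plain `Arrays.asList("dotNET", "Java")` would infer `List<String>` and fail to compile, hence the explicit `Arrays.<Object>asList(...)` type witness.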

sql/core/src/test/scala/org/apache/spark/sql/DataFramePivotSuite.scala

Lines changed: 11 additions & 10 deletions
```diff
@@ -25,22 +25,23 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{
 
   test("pivot courses with literals") {
     checkAnswer(
-      courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java"))
+      courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java"))
         .agg(sum($"earnings")),
       Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
     )
   }
 
   test("pivot year with literals") {
     checkAnswer(
-      courseSales.groupBy($"course").pivot($"year", lit(2012), lit(2013)).agg(sum($"earnings")),
+      courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).agg(sum($"earnings")),
       Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
     )
   }
 
   test("pivot courses with literals and multiple aggregations") {
     checkAnswer(
-      courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java"))
+      courseSales.groupBy($"year")
+        .pivot("course", Seq("dotNET", "Java"))
         .agg(sum($"earnings"), avg($"earnings")),
       Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) ::
         Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil
@@ -49,37 +50,37 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{
 
   test("pivot year with string values (cast)") {
     checkAnswer(
-      courseSales.groupBy("course").pivot("year", "2012", "2013").sum("earnings"),
+      courseSales.groupBy("course").pivot("year", Seq("2012", "2013")).sum("earnings"),
       Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
     )
   }
 
   test("pivot year with int values") {
     checkAnswer(
-      courseSales.groupBy("course").pivot("year", 2012, 2013).sum("earnings"),
+      courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).sum("earnings"),
       Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
     )
   }
 
   test("pivot courses with no values") {
     // Note Java comes before dotNet in sorted order
     checkAnswer(
-      courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")),
+      courseSales.groupBy("year").pivot("course").agg(sum($"earnings")),
       Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil
     )
   }
 
   test("pivot year with no values") {
     checkAnswer(
-      courseSales.groupBy($"course").pivot($"year").agg(sum($"earnings")),
+      courseSales.groupBy("course").pivot("year").agg(sum($"earnings")),
       Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
     )
   }
 
-  test("pivot max values inforced") {
+  test("pivot max values enforced") {
     sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, 1)
-    intercept[RuntimeException](
-      courseSales.groupBy($"year").pivot($"course")
+    intercept[AnalysisException](
+      courseSales.groupBy("year").pivot("course")
     )
     sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES,
       SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get)
```
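
Two behavioral shifts show up in these tests: pivot values are now plain Scala values in a `Seq` rather than `lit(...)` columns, and exceeding `SQLConf.DATAFRAME_PIVOT_MAX_VALUES` now raises `AnalysisException` (a user-facing error) instead of `RuntimeException`.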

sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala

Lines changed: 1 addition & 0 deletions
```diff
@@ -281,6 +281,7 @@ private[sql] trait SQLTestData { self =>
     person
     salary
     complexData
+    courseSales
   }
 }
```
