[SPARK-11946][SQL] Audit pivot API for 1.6. #9929

Closed · wants to merge 2 commits

@@ -1574,7 +1574,6 @@ class DAGScheduler(
}

def stop() {
logInfo("Stopping DAGScheduler")
messageScheduler.shutdownNow()
[Contributor Author]
This is done as part of #9603 (comment), but it is way too small to deserve its own PR.

eventProcessLoop.stop()
taskScheduler.stop()
12 changes: 8 additions & 4 deletions python/pyspark/sql/group.py
@@ -168,20 +168,24 @@ def sum(self, *cols):
"""

@since(1.6)
def pivot(self, pivot_col, *values):
def pivot(self, pivot_col, values=None):
"""Pivots a column of the current DataFrame and preform the specified aggregation.

:param pivot_col: Column to pivot
:param values: Optional list of values of pivotColumn that will be translated to columns in
the output DataFrame. If values are not provided, the method will do an immediate call
to .distinct() on the pivot column.
>>> df4.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings").collect()

>>> df4.groupBy("year").pivot("course", ["dotNET", "Java"]).sum("earnings").collect()
[Row(year=2012, dotNET=15000, Java=20000), Row(year=2013, dotNET=48000, Java=30000)]

>>> df4.groupBy("year").pivot("course").sum("earnings").collect()
[Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)]
"""
jgd = self._jdf.pivot(_to_java_column(pivot_col),
_to_seq(self.sql_ctx._sc, values, _create_column_from_literal))
if values is None:
[Contributor Author]
cc @davies

jgd = self._jdf.pivot(pivot_col)
[Contributor]
Should we use _to_java_column(pivot_col) and _to_seq() here? Otherwise df.pivot(df.a) may fail.

else:
jgd = self._jdf.pivot(pivot_col, values)
return GroupedData(jgd, self.sql_ctx)


@@ -44,6 +44,7 @@ object Literal {
case a: Array[Byte] => Literal(a, BinaryType)
case i: CalendarInterval => Literal(i, CalendarIntervalType)
case null => Literal(null, NullType)
case v: Literal => v
case _ =>
throw new RuntimeException("Unsupported literal type " + v.getClass + " " + v)
}
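The new case arm makes Literal.apply idempotent: wrapping a value that is already a Literal now returns it unchanged instead of falling through to the "Unsupported literal type" error. A minimal sketch of the behaviour, assuming the spark-catalyst classes are on the classpath:

import org.apache.spark.sql.catalyst.expressions.Literal

// Before this change, wrapping an existing Literal fell through to the
// "Unsupported literal type" branch; now the inner literal is returned as-is.
val inner = Literal(1)      // a Literal of IntegerType
val outer = Literal(inner)  // hits the new `case v: Literal => v` arm
assert(outer eq inner)      // the same instance comes straight back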
154 changes: 88 additions & 66 deletions sql/core/src/main/scala/org/apache/spark/sql/GroupedData.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedAli
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Rollup, Cube, Aggregate}
import org.apache.spark.sql.types.{StringType, NumericType}
import org.apache.spark.sql.types.NumericType


/**
@@ -282,74 +282,96 @@ class GroupedData protected[sql](
}

/**
* (Scala-specific) Pivots a column of the current [[DataFrame]] and performs the specified
* aggregation.
* {{{
* // Compute the sum of earnings for each year by course with each course as a separate column
* df.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings"))
* // Or without specifying column values
* df.groupBy($"year").pivot($"course").agg(sum($"earnings"))
* }}}
* @param pivotColumn Column to pivot
* @param values Optional list of values of pivotColumn that will be translated to columns in the
* output DataFrame. If values are not provided, the method will do an immediate
* call to .distinct() on the pivot column.
* @since 1.6.0
*/
@scala.annotation.varargs
def pivot(pivotColumn: Column, values: Column*): GroupedData = groupType match {
case _: GroupedData.PivotType =>
throw new UnsupportedOperationException("repeated pivots are not supported")
case GroupedData.GroupByType =>
val pivotValues = if (values.nonEmpty) {
values.map {
case Column(literal: Literal) => literal
case other =>
throw new UnsupportedOperationException(
s"The values of a pivot must be literals, found $other")
}
} else {
// This is to prevent unintended OOM errors when the number of distinct values is large
val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES)
// Get the distinct values of the column and sort them so it's consistent
val values = df.select(pivotColumn)
.distinct()
.sort(pivotColumn)
.map(_.get(0))
.take(maxValues + 1)
.map(Literal(_)).toSeq
if (values.length > maxValues) {
throw new RuntimeException(
s"The pivot column $pivotColumn has more than $maxValues distinct values, " +
"this could indicate an error. " +
"If this was intended, set \"" + SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key + "\" " +
s"to at least the number of distinct values of the pivot column.")
}
values
}
new GroupedData(df, groupingExprs, GroupedData.PivotType(pivotColumn.expr, pivotValues))
case _ =>
throw new UnsupportedOperationException("pivot is only supported after a groupBy")
* Pivots a column of the current [[DataFrame]] and performs the specified aggregation.
* There are two versions of the pivot function: one that requires the caller to specify the list
* of distinct values to pivot on, and one that does not. The latter is more concise but less
* efficient, because Spark needs to first compute the list of distinct values internally.
*
* {{{
* // Compute the sum of earnings for each year by course with each course as a separate column
* df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
*
* // Or without specifying column values (less efficient)
* df.groupBy("year").pivot("course").sum("earnings")
* }}}
*
* @param pivotColumn Name of the column to pivot.
* @since 1.6.0
*/
def pivot(pivotColumn: String): GroupedData = {
// This is to prevent unintended OOM errors when the number of distinct values is large
val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES)
// Get the distinct values of the column and sort them so it's consistent
val values = df.select(pivotColumn)
.distinct()
.sort(pivotColumn)
[Contributor Author]
@aray do you know why we have a "sort" in here?

[Contributor]
The sort is there to ensure that the output columns are in a consistent logical order.

[Contributor Author]
OK, thanks - I'm going to add a comment there to explain.

.map(_.get(0))
.take(maxValues + 1)
.toSeq

if (values.length > maxValues) {
throw new AnalysisException(
s"The pivot column $pivotColumn has more than $maxValues distinct values, " +
"this could indicate an error. " +
s"If this was intended, set ${SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key} " +
"to at least the number of distinct values of the pivot column.")
}

pivot(pivotColumn, values)
}

/**
* Pivots a column of the current [[DataFrame]] and performs the specified aggregation.
* {{{
* // Compute the sum of earnings for each year by course with each course as a separate column
* df.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings")
* // Or without specifying column values
* df.groupBy("year").pivot("course").sum("earnings")
* }}}
* @param pivotColumn Column to pivot
* @param values Optional list of values of pivotColumn that will be translated to columns in the
* output DataFrame. If values are not provided, the method will do an immediate
* call to .distinct() on the pivot column.
* @since 1.6.0
*/
@scala.annotation.varargs
def pivot(pivotColumn: String, values: Any*): GroupedData = {
val resolvedPivotColumn = Column(df.resolve(pivotColumn))
pivot(resolvedPivotColumn, values.map(functions.lit): _*)
* Pivots a column of the current [[DataFrame]] and performs the specified aggregation.
* There are two versions of the pivot function: one that requires the caller to specify the list
* of distinct values to pivot on, and one that does not. The latter is more concise but less
* efficient, because Spark needs to first compute the list of distinct values internally.
*
* {{{
* // Compute the sum of earnings for each year by course with each course as a separate column
* df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
*
* // Or without specifying column values (less efficient)
* df.groupBy("year").pivot("course").sum("earnings")
* }}}
*
* @param pivotColumn Name of the column to pivot.
* @param values List of values that will be translated to columns in the output DataFrame.
* @since 1.6.0
*/
def pivot(pivotColumn: String, values: Seq[Any]): GroupedData = {
groupType match {
case GroupedData.GroupByType =>
new GroupedData(
df,
groupingExprs,
GroupedData.PivotType(df.resolve(pivotColumn), values.map(Literal.apply)))
case _: GroupedData.PivotType =>
throw new UnsupportedOperationException("repeated pivots are not supported")
case _ =>
throw new UnsupportedOperationException("pivot is only supported after a groupBy")
}
}

/**
* Pivots a column of the current [[DataFrame]] and performs the specified aggregation.
* There are two versions of the pivot function: one that requires the caller to specify the list
* of distinct values to pivot on, and one that does not. The latter is more concise but less
* efficient, because Spark needs to first compute the list of distinct values internally.
*
* {{{
* // Compute the sum of earnings for each year by course with each course as a separate column
* df.groupBy("year").pivot("course", Arrays.<Object>asList("dotNET", "Java")).sum("earnings");
*
* // Or without specifying column values (less efficient)
* df.groupBy("year").pivot("course").sum("earnings");
* }}}
*
* @param pivotColumn Name of the column to pivot.
* @param values List of values that will be translated to columns in the output DataFrame.
* @since 1.6.0
*/
def pivot(pivotColumn: String, values: java.util.List[Any]): GroupedData = {
pivot(pivotColumn, values.asScala)
}
}

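Taken together, the audited API is string-based: one overload takes an explicit Seq of values (or a java.util.List from Java), and a values-free overload computes the distinct values itself, capped by SQLConf.DATAFRAME_PIVOT_MAX_VALUES. A short usage sketch, assuming a DataFrame named courseSales with columns course (String), year (Int), earnings (Double) and a SQLContext named sqlContext:

import org.apache.spark.sql.functions.sum

// Explicit values: no extra job is needed to discover the distinct courses.
val explicitPivot = courseSales
  .groupBy("year")
  .pivot("course", Seq("dotNET", "Java"))
  .agg(sum("earnings"))

// Inferred values: Spark first collects the distinct courses (sorted for a
// consistent column order) and fails if there are more than the configured cap.
val inferredPivot = courseSales
  .groupBy("year")
  .pivot("course")
  .agg(sum("earnings"))

// If the pivot column legitimately has many distinct values, raise the cap;
// "spark.sql.pivotMaxValues" is assumed to be the key behind
// SQLConf.DATAFRAME_PIVOT_MAX_VALUES.
sqlContext.setConf("spark.sql.pivotMaxValues", "50000")

Passing explicit values keeps the pivot a single pass over the data, which is why the scaladoc calls the values-free form less efficient.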
@@ -282,4 +282,20 @@ public void testSampleBy() {
Assert.assertEquals(1, actual[1].getLong(0));
Assert.assertTrue(2 <= actual[1].getLong(1) && actual[1].getLong(1) <= 13);
}

@Test
public void pivot() {
DataFrame df = context.table("courseSales");
Row[] actual = df.groupBy("year")
.pivot("course", Arrays.<Object>asList("dotNET", "Java"))
.agg(sum("earnings")).orderBy("year").collect();

Assert.assertEquals(2012, actual[0].getInt(0));
Assert.assertEquals(15000.0, actual[0].getDouble(1), 0.01);
Assert.assertEquals(20000.0, actual[0].getDouble(2), 0.01);

Assert.assertEquals(2013, actual[1].getInt(0));
Assert.assertEquals(48000.0, actual[1].getDouble(1), 0.01);
Assert.assertEquals(30000.0, actual[1].getDouble(2), 0.01);
}
}
@@ -25,22 +25,23 @@ class DataFramePivotSuite extends QueryTest with SharedSQLContext{

test("pivot courses with literals") {
checkAnswer(
courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java"))
courseSales.groupBy("year").pivot("course", Seq("dotNET", "Java"))
.agg(sum($"earnings")),
Row(2012, 15000.0, 20000.0) :: Row(2013, 48000.0, 30000.0) :: Nil
)
}

test("pivot year with literals") {
checkAnswer(
courseSales.groupBy($"course").pivot($"year", lit(2012), lit(2013)).agg(sum($"earnings")),
courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).agg(sum($"earnings")),
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
)
}

test("pivot courses with literals and multiple aggregations") {
checkAnswer(
courseSales.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java"))
courseSales.groupBy($"year")
.pivot("course", Seq("dotNET", "Java"))
.agg(sum($"earnings"), avg($"earnings")),
Row(2012, 15000.0, 7500.0, 20000.0, 20000.0) ::
Row(2013, 48000.0, 48000.0, 30000.0, 30000.0) :: Nil
@@ -49,37 +50,37 @@

test("pivot year with string values (cast)") {
checkAnswer(
courseSales.groupBy("course").pivot("year", "2012", "2013").sum("earnings"),
courseSales.groupBy("course").pivot("year", Seq("2012", "2013")).sum("earnings"),
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
)
}

test("pivot year with int values") {
checkAnswer(
courseSales.groupBy("course").pivot("year", 2012, 2013).sum("earnings"),
courseSales.groupBy("course").pivot("year", Seq(2012, 2013)).sum("earnings"),
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
)
}

test("pivot courses with no values") {
// Note Java comes before dotNet in sorted order
checkAnswer(
courseSales.groupBy($"year").pivot($"course").agg(sum($"earnings")),
courseSales.groupBy("year").pivot("course").agg(sum($"earnings")),
Row(2012, 20000.0, 15000.0) :: Row(2013, 30000.0, 48000.0) :: Nil
)
}

test("pivot year with no values") {
checkAnswer(
courseSales.groupBy($"course").pivot($"year").agg(sum($"earnings")),
courseSales.groupBy("course").pivot("year").agg(sum($"earnings")),
Row("dotNET", 15000.0, 48000.0) :: Row("Java", 20000.0, 30000.0) :: Nil
)
}

test("pivot max values inforced") {
test("pivot max values enforced") {
sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES, 1)
intercept[RuntimeException](
courseSales.groupBy($"year").pivot($"course")
intercept[AnalysisException](
courseSales.groupBy("year").pivot("course")
)
sqlContext.conf.setConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES,
SQLConf.DATAFRAME_PIVOT_MAX_VALUES.defaultValue.get)
@@ -281,6 +281,7 @@ private[sql] trait SQLTestData { self =>
person
salary
complexData
courseSales
}
}

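The courseSales fixture registered above is not part of this diff; a hypothetical reconstruction, consistent with the sums and averages the suites assert (e.g. dotNET in 2012 sums to 15000 with an average of 7500, so two rows), might look like:

// Hypothetical test data; the case class name and exact rows are assumptions
// inferred from the expected results, not taken from the actual SQLTestData.
// In practice the case class would live at the top level so Spark's reflection
// encoder can see it, and sqlContext is the test suite's SQLContext.
case class CourseSales(course: String, year: Int, earnings: Double)

protected lazy val courseSales: DataFrame = {
  import sqlContext.implicits._  // for toDF() on an RDD of case classes
  val df = sqlContext.sparkContext.parallelize(
    CourseSales("dotNET", 2012, 10000) ::
    CourseSales("Java", 2012, 20000) ::
    CourseSales("dotNET", 2012, 5000) ::
    CourseSales("dotNET", 2013, 48000) ::
    CourseSales("Java", 2013, 30000) :: Nil).toDF()
  df.registerTempTable("courseSales")
  df
}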