[SPARK-11946][SQL] Audit pivot API for 1.6. #9929
@@ -168,20 +168,24 @@ def sum(self, *cols):
        """

    @since(1.6)
    def pivot(self, pivot_col, *values):
    def pivot(self, pivot_col, values=None):
        """Pivots a column of the current DataFrame and performs the specified aggregation.

        :param pivot_col: Column to pivot
        :param values: Optional list of values of the pivot column that will be translated to
            columns in the output DataFrame. If values are not provided, the method will do an
            immediate call to .distinct() on the pivot column.

        >>> df4.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings").collect()
        >>> df4.groupBy("year").pivot("course", ["dotNET", "Java"]).sum("earnings").collect()
        [Row(year=2012, dotNET=15000, Java=20000), Row(year=2013, dotNET=48000, Java=30000)]

        >>> df4.groupBy("year").pivot("course").sum("earnings").collect()
        [Row(year=2012, Java=20000, dotNET=15000), Row(year=2013, Java=30000, dotNET=48000)]
        """
        jgd = self._jdf.pivot(_to_java_column(pivot_col),
                              _to_seq(self.sql_ctx._sc, values, _create_column_from_literal))
        if values is None:
Review comment: cc @davies

            jgd = self._jdf.pivot(pivot_col)

Review comment: Should we use …

        else:
            jgd = self._jdf.pivot(pivot_col, values)
        return GroupedData(jgd, self.sql_ctx)
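For context, the df4 fixture used by the doctests above is not part of this diff. A rough Scala equivalent of the same scenario, using the Scala API introduced below, is sketched here; the row values are assumptions chosen only so the sums match the doctest output, and `sqlContext` is assumed to be in scope:

    // Sketch only: fixture values mirror the expected doctest results, not the PR's test data.
    import sqlContext.implicits._

    val courses = Seq(
      (2012, "dotNET", 15000),
      (2012, "Java", 20000),
      (2013, "dotNET", 48000),
      (2013, "Java", 30000)
    ).toDF("year", "course", "earnings")

    // Explicit pivot values: output columns are exactly dotNET and Java.
    courses.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings").show()

    // No values given: Spark first computes the distinct courses, then pivots.
    courses.groupBy("year").pivot("course").sum("earnings").show()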
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.{UnresolvedFunction, UnresolvedAli
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate._
import org.apache.spark.sql.catalyst.plans.logical.{Pivot, Rollup, Cube, Aggregate}
import org.apache.spark.sql.types.{StringType, NumericType}
import org.apache.spark.sql.types.NumericType

  /**

@@ -282,74 +282,96 @@ class GroupedData protected[sql](
  }

  /**
   * (Scala-specific) Pivots a column of the current [[DataFrame]] and performs the specified
   * aggregation.
   * {{{
   *   // Compute the sum of earnings for each year by course with each course as a separate column
   *   df.groupBy($"year").pivot($"course", "dotNET", "Java").agg(sum($"earnings"))
   *   // Or without specifying column values
   *   df.groupBy($"year").pivot($"course").agg(sum($"earnings"))
   * }}}
   * @param pivotColumn Column to pivot
   * @param values Optional list of values of pivotColumn that will be translated to columns in the
   *               output data frame. If values are not provided, the method will do an immediate
   *               call to .distinct() on the pivot column.
   * @since 1.6.0
   */
  @scala.annotation.varargs
  def pivot(pivotColumn: Column, values: Column*): GroupedData = groupType match {
    case _: GroupedData.PivotType =>
      throw new UnsupportedOperationException("repeated pivots are not supported")
    case GroupedData.GroupByType =>
      val pivotValues = if (values.nonEmpty) {
        values.map {
          case Column(literal: Literal) => literal
          case other =>
            throw new UnsupportedOperationException(
              s"The values of a pivot must be literals, found $other")
        }
      } else {
        // This is to prevent unintended OOM errors when the number of distinct values is large
        val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES)
        // Get the distinct values of the column and sort them so it's consistent
        val values = df.select(pivotColumn)
          .distinct()
          .sort(pivotColumn)
          .map(_.get(0))
          .take(maxValues + 1)
          .map(Literal(_)).toSeq
        if (values.length > maxValues) {
          throw new RuntimeException(
            s"The pivot column $pivotColumn has more than $maxValues distinct values, " +
              "this could indicate an error. " +
              "If this was intended, set \"" + SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key + "\" " +
              s"to at least the number of distinct values of the pivot column.")
        }
        values
      }
      new GroupedData(df, groupingExprs, GroupedData.PivotType(pivotColumn.expr, pivotValues))
    case _ =>
      throw new UnsupportedOperationException("pivot is only supported after a groupBy")
   * Pivots a column of the current [[DataFrame]] and performs the specified aggregation.
   * There are two versions of the pivot function: one that requires the caller to specify the list
   * of distinct values to pivot on, and one that does not. The latter is more concise but less
   * efficient, because Spark needs to first compute the list of distinct values internally.
   *
   * {{{
   *   // Compute the sum of earnings for each year by course with each course as a separate column
   *   df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
   *
   *   // Or without specifying column values (less efficient)
   *   df.groupBy("year").pivot("course").sum("earnings")
   * }}}
   *
   * @param pivotColumn Name of the column to pivot.
   * @since 1.6.0
   */
  def pivot(pivotColumn: String): GroupedData = {
    // This is to prevent unintended OOM errors when the number of distinct values is large
    val maxValues = df.sqlContext.conf.getConf(SQLConf.DATAFRAME_PIVOT_MAX_VALUES)
    // Get the distinct values of the column and sort them so it's consistent
    val values = df.select(pivotColumn)
      .distinct()
      .sort(pivotColumn)
Review comment: @aray do you know why we have a "sort" in here?
Reply: The sort is there to ensure that the output columns are in a consistent logical order.
Reply: ok thanks - i'm going to add a comment there to explain.

      .map(_.get(0))
      .take(maxValues + 1)
      .toSeq

    if (values.length > maxValues) {
      throw new AnalysisException(
        s"The pivot column $pivotColumn has more than $maxValues distinct values, " +
          "this could indicate an error. " +
          s"If this was intended, set ${SQLConf.DATAFRAME_PIVOT_MAX_VALUES.key} " +
          "to at least the number of distinct values of the pivot column.")
    }

    pivot(pivotColumn, values)
  }
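When the cap is exceeded, the AnalysisException above points the caller at SQLConf.DATAFRAME_PIVOT_MAX_VALUES. A hedged sketch of how a user might raise the limit, assuming the key behind that setting is "spark.sql.pivotMaxValues" and the high cardinality is intentional:

    // Assumption: "spark.sql.pivotMaxValues" is the key of SQLConf.DATAFRAME_PIVOT_MAX_VALUES;
    // check SQLConf in your Spark version before relying on the literal.
    sqlContext.setConf("spark.sql.pivotMaxValues", "50000")

    // A pivot column with up to 50000 distinct values will now pass the check above.
    // "wideColumn" is a hypothetical column name used only for illustration.
    df.groupBy("year").pivot("wideColumn").sum("earnings")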

  /**
   * Pivots a column of the current [[DataFrame]] and performs the specified aggregation.
   * {{{
   *   // Compute the sum of earnings for each year by course with each course as a separate column
   *   df.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings")
   *   // Or without specifying column values
   *   df.groupBy("year").pivot("course").sum("earnings")
   * }}}
   * @param pivotColumn Column to pivot
   * @param values Optional list of values of pivotColumn that will be translated to columns in the
   *               output data frame. If values are not provided, the method will do an immediate
   *               call to .distinct() on the pivot column.
   * @since 1.6.0
   */
  @scala.annotation.varargs
  def pivot(pivotColumn: String, values: Any*): GroupedData = {
    val resolvedPivotColumn = Column(df.resolve(pivotColumn))
    pivot(resolvedPivotColumn, values.map(functions.lit): _*)
   * Pivots a column of the current [[DataFrame]] and performs the specified aggregation.
   * There are two versions of the pivot function: one that requires the caller to specify the list
   * of distinct values to pivot on, and one that does not. The latter is more concise but less
   * efficient, because Spark needs to first compute the list of distinct values internally.
   *
   * {{{
   *   // Compute the sum of earnings for each year by course with each course as a separate column
   *   df.groupBy("year").pivot("course", Seq("dotNET", "Java")).sum("earnings")
   *
   *   // Or without specifying column values (less efficient)
   *   df.groupBy("year").pivot("course").sum("earnings")
   * }}}
   *
   * @param pivotColumn Name of the column to pivot.
   * @param values List of values that will be translated to columns in the output DataFrame.
   * @since 1.6.0
   */
  def pivot(pivotColumn: String, values: Seq[Any]): GroupedData = {
    groupType match {
      case GroupedData.GroupByType =>
        new GroupedData(
          df,
          groupingExprs,
          GroupedData.PivotType(df.resolve(pivotColumn), values.map(Literal.apply)))
      case _: GroupedData.PivotType =>
        throw new UnsupportedOperationException("repeated pivots are not supported")
      case _ =>
        throw new UnsupportedOperationException("pivot is only supported after a groupBy")
    }
  }
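The groupType match above only admits a plain groupBy. A short sketch of the two rejection paths, with a hypothetical df, assuming rollup and cube produce the non-GroupBy group types as elsewhere in GroupedData:

    // Valid: pivot directly after groupBy.
    val pivoted = df.groupBy("year").pivot("course", Seq("dotNET", "Java"))

    // Rejected by the PivotType case: pivoting an already-pivoted GroupedData.
    // pivoted.pivot("course")            // UnsupportedOperationException: repeated pivots are not supported

    // Rejected by the default case: pivot after rollup (or cube) instead of groupBy.
    // df.rollup("year").pivot("course")  // UnsupportedOperationException: pivot is only supported after a groupBy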

  /**
   * Pivots a column of the current [[DataFrame]] and performs the specified aggregation.
   * There are two versions of the pivot function: one that requires the caller to specify the list
   * of distinct values to pivot on, and one that does not. The latter is more concise but less
   * efficient, because Spark needs to first compute the list of distinct values internally.
   *
   * {{{
   *   // Compute the sum of earnings for each year by course with each course as a separate column
   *   df.groupBy("year").pivot("course", Arrays.<Object>asList("dotNET", "Java")).sum("earnings");
   *
   *   // Or without specifying column values (less efficient)
   *   df.groupBy("year").pivot("course").sum("earnings");
   * }}}
   *
   * @param pivotColumn Name of the column to pivot.
   * @param values List of values that will be translated to columns in the output DataFrame.
   * @since 1.6.0
   */
  def pivot(pivotColumn: String, values: java.util.List[Any]): GroupedData = {
    pivot(pivotColumn, values.asScala)
  }
}
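The java.util.List overload above simply converts with asScala and delegates to the Seq variant. It is aimed at Java callers, but as a sketch (not from the PR) it can also be exercised from Scala via JavaConverters:

    import scala.collection.JavaConverters._

    // Passing a java.util.List[Any] selects the Java-friendly overload,
    // which converts it with asScala and delegates to pivot(String, Seq[Any]).
    df.groupBy("year").pivot("course", Seq[Any]("dotNET", "Java").asJava).sum("earnings")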
@@ -281,6 +281,7 @@ private[sql] trait SQLTestData { self =>
    person
    salary
    complexData
    courseSales
  }
}

Review comment: this is done as part of #9603 (comment) but it is way too small to deserve its own pr.
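The courseSales entry registered above is the fixture the new pivot tests run against; its definition is not shown in this diff. A purely hypothetical sketch of such a fixture, with names and row values chosen only to match the earnings totals in the examples (the real SQLTestData definition may differ):

    // Hypothetical fixture shape; assumes the trait's implicits are in scope for toDF.
    protected lazy val courseSales: DataFrame = {
      val df = sqlContext.sparkContext.parallelize(
          ("dotNET", 2012, 15000.0) ::
          ("Java", 2012, 20000.0) ::
          ("dotNET", 2013, 48000.0) ::
          ("Java", 2013, 30000.0) :: Nil)
        .toDF("course", "year", "earnings")
      df.registerTempTable("courseSales")
      df
    }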