@@ -24,8 +24,8 @@ import org.apache.spark.annotation.Experimental
24
24
import org .apache .spark .sql .catalyst .analysis .{UnresolvedFunction , UnresolvedAlias , UnresolvedAttribute , Star }
25
25
import org .apache .spark .sql .catalyst .expressions ._
26
26
import org .apache .spark .sql .catalyst .expressions .aggregate ._
27
- import org .apache .spark .sql .catalyst .plans .logical .{Rollup , Cube , Aggregate }
28
- import org .apache .spark .sql .types .NumericType
27
+ import org .apache .spark .sql .catalyst .plans .logical .{Pivot , Rollup , Cube , Aggregate }
28
+ import org .apache .spark .sql .types .{ StringType , NumericType }
29
29
30
30
31
31
/**
@@ -50,14 +50,8 @@ class GroupedData protected[sql](
50
50
aggExprs
51
51
}
52
52
53
- val aliasedAgg = aggregates.map {
54
- // Wrap UnresolvedAttribute with UnresolvedAlias, as when we resolve UnresolvedAttribute, we
55
- // will remove intermediate Alias for ExtractValue chain, and we need to alias it again to
56
- // make it a NamedExpression.
57
- case u : UnresolvedAttribute => UnresolvedAlias (u)
58
- case expr : NamedExpression => expr
59
- case expr : Expression => Alias (expr, expr.prettyString)()
60
- }
53
+ val aliasedAgg = aggregates.map(alias)
54
+
61
55
groupType match {
62
56
case GroupedData .GroupByType =>
63
57
DataFrame (
@@ -68,9 +62,22 @@ class GroupedData protected[sql](
68
62
case GroupedData .CubeType =>
69
63
DataFrame (
70
64
df.sqlContext, Cube (groupingExprs, df.logicalPlan, aliasedAgg))
65
+ case GroupedData .PivotType (pivotCol, values) =>
66
+ val aliasedGrps = groupingExprs.map(alias)
67
+ DataFrame (
68
+ df.sqlContext, Pivot (aliasedGrps, pivotCol, values, aggExprs, df.logicalPlan))
71
69
}
72
70
}
73
71
72
+ // Makes an expression a NamedExpression: a NamedExpression is returned as is and any other
73
+ // expression is aliased by its prettyString. UnresolvedAttribute is wrapped in UnresolvedAlias
74
+ // because resolving it removes the intermediate Alias for an ExtractValue chain.
75
+ private [this ] def alias (expr : Expression ): NamedExpression = expr match {
76
+ case u : UnresolvedAttribute => UnresolvedAlias (u)
77
+ case expr : NamedExpression => expr
78
+ case expr : Expression => Alias (expr, expr.prettyString)()
79
+ }
80
+
74
81
private [this ] def aggregateNumericColumns (colNames : String * )(f : Expression => AggregateFunction )
75
82
: DataFrame = {
76
83
@@ -273,6 +280,77 @@ class GroupedData protected[sql](
273
280
def sum (colNames : String * ): DataFrame = {
274
281
aggregateNumericColumns(colNames : _* )(Sum )
275
282
}
283
+
284
+ /**
285
+ * (Scala-specific) Pivots a column of the current [[DataFrame]] and performs the specified
286
+ * aggregation.
287
+ * {{{
288
+ * // Compute the sum of earnings for each year by course with each course as a separate column
289
+ * df.groupBy($"year").pivot($"course", lit("dotNET"), lit("Java")).agg(sum($"earnings"))
290
+ * // Or without specifying column values
291
+ * df.groupBy($"year").pivot($"course").agg(sum($"earnings"))
292
+ * }}}
293
+ * @param pivotColumn Column to pivot
294
+ * @param values Optional list of values of pivotColumn that will be translated to columns in the
295
+ * output data frame. If values are not provided the method will do an immediate
296
+ * call to .distinct() on the pivot column. Provided values must be literals.
297
+ * @since 1.6.0
298
+ */
299
+ @ scala.annotation.varargs
300
+ def pivot (pivotColumn : Column , values : Column * ): GroupedData = groupType match {
301
+ case _ : GroupedData .PivotType =>
302
+ throw new UnsupportedOperationException (" repeated pivots are not supported" )
303
+ case GroupedData .GroupByType =>
304
+ val pivotValues = if (values.nonEmpty) {
305
+ values.map {
306
+ case Column (literal : Literal ) => literal
307
+ case other =>
308
+ throw new UnsupportedOperationException (
309
+ s " The values of a pivot must be literals, found $other" )
310
+ }
311
+ } else {
312
+ // Cap the number of pivot values to prevent unintended OOM errors when there are many of them
313
+ val maxValues = df.sqlContext.conf.getConf(SQLConf .DATAFRAME_PIVOT_MAX_VALUES )
314
+ // Get the distinct values of the column and sort them so the result is consistent
315
+ val values = df.select(pivotColumn)
316
+ .distinct()
317
+ .sort(pivotColumn)
318
+ .map(_.get(0 ))
319
+ .take(maxValues + 1 ) // one row beyond the limit, so the length check below can detect overflow
320
+ .map(Literal (_)).toSeq
321
+ if (values.length > maxValues) {
322
+ throw new RuntimeException (
323
+ s " The pivot column $pivotColumn has more than $maxValues distinct values, " +
324
+ " this could indicate an error. " +
325
+ " If this was intended, set \" " + SQLConf .DATAFRAME_PIVOT_MAX_VALUES .key + " \" " +
326
+ s " to at least the number of distinct values of the pivot column. " )
327
+ }
328
+ values
329
+ }
330
+ new GroupedData (df, groupingExprs, GroupedData .PivotType (pivotColumn.expr, pivotValues))
331
+ case _ =>
332
+ throw new UnsupportedOperationException (" pivot is only supported after a groupBy" )
333
+ }
334
+
335
+ /**
336
+ * Pivots a column of the current [[DataFrame]] and performs the specified aggregation.
337
+ * {{{
338
+ * // Compute the sum of earnings for each year by course with each course as a separate column
339
+ * df.groupBy("year").pivot("course", "dotNET", "Java").sum("earnings")
340
+ * // Or without specifying column values
341
+ * df.groupBy("year").pivot("course").sum("earnings")
342
+ * }}}
343
+ * @param pivotColumn Name of the column to pivot
344
+ * @param values Optional list of values of pivotColumn that will be translated to columns in the
345
+ * output data frame. If values are not provided the method will do an immediate
346
+ * call to .distinct() on the pivot column. Each value is wrapped as a literal.
347
+ * @since 1.6.0
348
+ */
349
+ @ scala.annotation.varargs
350
+ def pivot (pivotColumn : String , values : Any * ): GroupedData = {
351
+ val resolvedPivotColumn = Column (df.resolve(pivotColumn))
352
+ pivot(resolvedPivotColumn, values.map(functions.lit): _* )
353
+ }
276
354
}
277
355
278
356
@@ -307,4 +385,9 @@ private[sql] object GroupedData {
307
385
* To indicate it's the ROLLUP
308
386
*/
309
387
private [sql] object RollupType extends GroupType
388
+
389
+ /**
390
+ * To indicate it's the PIVOT; carries the pivot column expression and the literal pivot values.
391
+ */
392
+ private [sql] case class PivotType (pivotCol : Expression , values : Seq [Literal ]) extends GroupType
310
393
}
0 commit comments