
Commit 7572505

maryannxue authored and gatorsmile committed
[SPARK-24790][SQL] Allow complex aggregate expressions in Pivot
## What changes were proposed in this pull request?

Relax the check to allow complex aggregate expressions, like `ceil(sum(col1))` or `sum(col1) + 1`; that is, roughly any aggregate expression that could appear in an Aggregate plan, except pandas UDFs (which are not yet supported in pivot).

## How was this patch tested?

Added 2 tests in pivot.sql

Author: maryannxue <maryannxue@apache.org>

Closes apache#21753 from maryannxue/pivot-relax-syntax.
1 parent 1138489 commit 7572505
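
For illustration only (not part of the patch): after this change, a pivot whose aggregates are complex expressions such as ceil(sum(earnings)) or avg(earnings) + 1 analyzes successfully. A minimal sketch, assuming a spark-shell session where `spark` is the SparkSession and a courseSales table with year, course, and earnings columns, as in the test suite:

// Illustration only: complex aggregate expressions in PIVOT, allowed after this change.
// Assumes `spark` is the active SparkSession (e.g. in spark-shell) and that a
// `courseSales` table with year, course and earnings columns exists, as in the tests.
val pivoted = spark.sql("""
  SELECT * FROM (
    SELECT year, course, earnings FROM courseSales
  )
  PIVOT (
    ceil(sum(earnings)), avg(earnings) + 1 AS a1
    FOR course IN ('dotNET', 'Java')
  )
""")
pivoted.show()

Previously the analyzer rejected such queries because each pivot aggregate had to be a bare AggregateExpression, optionally aliased.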


3 files changed: +62, -14 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 12 additions & 12 deletions
@@ -509,12 +509,7 @@ class Analyzer(
           || !p.pivotColumn.resolved => p
       case Pivot(groupByExprsOpt, pivotColumn, pivotValues, aggregates, child) =>
         // Check all aggregate expressions.
-        aggregates.foreach { e =>
-          if (!isAggregateExpression(e)) {
-            throw new AnalysisException(
-              s"Aggregate expression required for pivot, found '$e'")
-          }
-        }
+        aggregates.foreach(checkValidAggregateExpression)
         // Group-by expressions coming from SQL are implicit and need to be deduced.
         val groupByExprs = groupByExprsOpt.getOrElse(
           (child.outputSet -- aggregates.flatMap(_.references) -- pivotColumn.references).toSeq)
@@ -586,12 +581,17 @@
       }
     }
 
-    private def isAggregateExpression(expr: Expression): Boolean = {
-      expr match {
-        case Alias(e, _) => isAggregateExpression(e)
-        case AggregateExpression(_, _, _, _) => true
-        case _ => false
-      }
+    // Support any aggregate expression that can appear in an Aggregate plan except Pandas UDF.
+    // TODO: Support Pandas UDF.
+    private def checkValidAggregateExpression(expr: Expression): Unit = expr match {
+      case _: AggregateExpression => // OK and leave the argument check to CheckAnalysis.
+      case expr: PythonUDF if PythonUDF.isGroupedAggPandasUDF(expr) =>
+        failAnalysis("Pandas UDF aggregate expressions are currently not supported in pivot.")
+      case e: Attribute =>
+        failAnalysis(
+          s"Aggregate expression required for pivot, but '${e.sql}' " +
+            s"did not appear in any aggregate function.")
+      case e => e.children.foreach(checkValidAggregateExpression)
     }
   }

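A minimal sketch of the recursion behind the new checkValidAggregateExpression, using toy types rather than Catalyst's Expression, AggregateExpression, and Attribute: the walk stops as soon as it reaches an aggregate, and only fails when it hits a raw column that is not wrapped in one (the Pandas UDF case is omitted here).

// Toy model of the check added above; the real rule pattern-matches on Catalyst classes.
sealed trait Expr { def children: Seq[Expr] }
case class Agg(fn: String, child: Expr) extends Expr { def children: Seq[Expr] = Seq(child) } // ~ AggregateExpression
case class Col(name: String) extends Expr { def children: Seq[Expr] = Nil }                   // ~ Attribute
case class Lit(value: Any) extends Expr { def children: Seq[Expr] = Nil }                     // ~ Literal
case class Call(fn: String, args: Expr*) extends Expr { def children: Seq[Expr] = args }      // ceil(...), a + b, ...

def checkValidAggregate(e: Expr): Unit = e match {
  case _: Agg => // OK: an aggregate was found on this path, stop descending
  case c: Col =>
    throw new IllegalArgumentException(
      s"Aggregate expression required for pivot, but '${c.name}' did not appear in any aggregate function.")
  case other => other.children.foreach(checkValidAggregate)
}

checkValidAggregate(Call("ceil", Agg("sum", Col("earnings"))))      // ok: ceil(sum(earnings))
checkValidAggregate(Call("+", Agg("avg", Col("earnings")), Lit(1))) // ok: avg(earnings) + 1
// checkValidAggregate(Call("abs", Col("earnings")))                // would throw: no aggregate above the column
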
sql/core/src/test/resources/sql-tests/inputs/pivot.sql

Lines changed: 18 additions & 0 deletions
@@ -111,3 +111,21 @@ PIVOT (
   sum(earnings)
   FOR year IN (2012, 2013)
 );
+
+-- pivot with complex aggregate expressions
+SELECT * FROM (
+  SELECT year, course, earnings FROM courseSales
+)
+PIVOT (
+  ceil(sum(earnings)), avg(earnings) + 1 as a1
+  FOR course IN ('dotNET', 'Java')
+);
+
+-- pivot with invalid arguments in aggregate expressions
+SELECT * FROM (
+  SELECT year, course, earnings FROM courseSales
+)
+PIVOT (
+  sum(avg(earnings))
+  FOR course IN ('dotNET', 'Java')
+);

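For comparison, a rough DataFrame-API counterpart of the first new test case. A sketch only, assuming courseSales is also available as a DataFrame with the same three columns:

import org.apache.spark.sql.functions.{avg, ceil, sum}

// Sketch only: roughly the same pivot as the new SQL test, via the DataFrame API.
// Assumes `courseSales` is a DataFrame with year, course and earnings columns.
val byYear = courseSales
  .groupBy("year")
  .pivot("course", Seq("dotNET", "Java"))
  .agg(ceil(sum("earnings")), (avg("earnings") + 1).as("a1"))

byYear.show()
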
sql/core/src/test/resources/sql-tests/results/pivot.sql.out

Lines changed: 32 additions & 2 deletions
@@ -1,5 +1,5 @@
 -- Automatically generated by SQLQueryTestSuite
--- Number of queries: 13
+-- Number of queries: 15
 
 
 -- !query 0
@@ -176,7 +176,7 @@ PIVOT (
 struct<>
 -- !query 11 output
 org.apache.spark.sql.AnalysisException
-Aggregate expression required for pivot, found 'abs(earnings#x)';
+Aggregate expression required for pivot, but 'coursesales.`earnings`' did not appear in any aggregate function.;
 
 
 -- !query 12
@@ -192,3 +192,33 @@ struct<>
 -- !query 12 output
 org.apache.spark.sql.AnalysisException
 cannot resolve '`year`' given input columns: [__auto_generated_subquery_name.course, __auto_generated_subquery_name.earnings]; line 4 pos 0
+
+
+-- !query 13
+SELECT * FROM (
+  SELECT year, course, earnings FROM courseSales
+)
+PIVOT (
+  ceil(sum(earnings)), avg(earnings) + 1 as a1
+  FOR course IN ('dotNET', 'Java')
+)
+-- !query 13 schema
+struct<year:int,dotNET_CEIL(sum(CAST(earnings AS BIGINT))):bigint,dotNET_a1:double,Java_CEIL(sum(CAST(earnings AS BIGINT))):bigint,Java_a1:double>
+-- !query 13 output
+2012	15000	7501.0	20000	20001.0
+2013	48000	48001.0	30000	30001.0
+
+
+-- !query 14
+SELECT * FROM (
+  SELECT year, course, earnings FROM courseSales
+)
+PIVOT (
+  sum(avg(earnings))
+  FOR course IN ('dotNET', 'Java')
+)
+-- !query 14 schema
+struct<>
+-- !query 14 output
+org.apache.spark.sql.AnalysisException
+It is not allowed to use an aggregate function in the argument of another aggregate function. Please use the inner aggregate function in a sub-query.;

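As the error for query 14 advises, a nested aggregate like sum(avg(earnings)) has to be split so that the inner aggregate is computed in a sub-query first. A sketch of one possible rewrite (not part of the test suite), again assuming the courseSales table and an active SparkSession `spark`:

// Sketch only: compute the inner aggregate in a sub-query, then pivot over it,
// as the AnalysisException for query 14 suggests. Assumes the same courseSales table.
val rewritten = spark.sql("""
  SELECT * FROM (
    SELECT year, course, avg(earnings) AS avg_earnings
    FROM courseSales
    GROUP BY year, course
  )
  PIVOT (
    sum(avg_earnings)
    FOR course IN ('dotNET', 'Java')
  )
""")
rewritten.show()
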