Skip to content

Commit 2293fd4

Browse files
committed
move logic to the beginning of optimization, simplify test
1 parent 5ab9f75 commit 2293fd4

File tree

3 files changed

+34
-19
lines changed

3 files changed

+34
-19
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -147,6 +147,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
147147
EliminateView,
148148
ReplaceExpressions,
149149
RewriteNonCorrelatedExists,
150+
WrapGroupingExpressions,
150151
ComputeCurrentTime,
151152
GetCurrentDatabaseAndCatalog(catalogManager)) ::
152153
//////////////////////////////////////////////////////////////////////////////////////////
@@ -870,19 +871,8 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper {
870871
if (haveCommonNonDeterministicOutput(p.projectList, agg.aggregateExpressions)) {
871872
p
872873
} else {
873-
val complexGroupingExpressions =
874-
ExpressionSet(agg.groupingExpressions.filter(_.children.nonEmpty))
875-
876-
def wrapGroupingExpression(e: Expression): Expression = e match {
877-
case _: AggregateExpression => e
878-
case _ if complexGroupingExpressions.contains(e) => GroupingExpression(e)
879-
case _ => e.mapChildren(wrapGroupingExpression)
880-
}
881-
882-
val wrappedAggregateExpressions =
883-
agg.aggregateExpressions.map(wrapGroupingExpression(_).asInstanceOf[NamedExpression])
884-
agg.copy(aggregateExpressions =
885-
buildCleanedProjectList(p.projectList, wrappedAggregateExpressions))
874+
agg.copy(aggregateExpressions = buildCleanedProjectList(
875+
p.projectList, agg.aggregateExpressions))
886876
}
887877
case Project(l1, g @ GlobalLimit(_, limit @ LocalLimit(_, p2 @ Project(l2, _))))
888878
if isRenaming(l1, l2) =>

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,34 @@ object RewriteNonCorrelatedExists extends Rule[LogicalPlan] {
6666
}
6767
}
6868

69+
/**
70+
* Wrap complex grouping expression in aggregate expressions without aggregate function into
71+
* `GroupingExpression` nodes so as to avoid further optimizations between the expression and its
72+
* parent.
73+
*
74+
* This is required as further optimizations could change the grouping expression and so make the
75+
* aggregate expression invalid.
76+
*/
77+
object WrapGroupingExpressions extends Rule[LogicalPlan] {
78+
override def apply(plan: LogicalPlan): LogicalPlan = {
79+
plan transform {
80+
case a: Aggregate =>
81+
val complexGroupingExpressions =
82+
ExpressionSet(a.groupingExpressions.filter(_.children.nonEmpty))
83+
84+
def wrapGroupingExpression(e: Expression): Expression = e match {
85+
case _: GroupingExpression => e
86+
case _: AggregateExpression => e
87+
case _ if complexGroupingExpressions.contains(e) => GroupingExpression(e)
88+
case _ => e.mapChildren(wrapGroupingExpression)
89+
}
90+
91+
a.copy(aggregateExpressions =
92+
a.aggregateExpressions.map(wrapGroupingExpression(_).asInstanceOf[NamedExpression]))
93+
}
94+
}
95+
}
96+
6997
/**
7098
* Computes the current date and time to make sure we return the same result in a single query.
7199
*/

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4123,12 +4123,9 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
41234123

41244124
val df = spark.sql(
41254125
"""
4126-
|SELECT not(id), c
4127-
|FROM (
4128-
| SELECT t.id IS NULL AS id, count(*) AS c
4129-
| FROM t
4130-
| GROUP BY t.id IS NULL
4131-
|) t
4126+
|SELECT not(t.id IS NULL), count(*) AS c
4127+
|FROM t
4128+
|GROUP BY t.id IS NULL
41324129
|""".stripMargin)
41334130
checkAnswer(df, Row(true, 3) :: Row(false, 2) :: Nil)
41344131
}

0 commit comments

Comments
 (0)