
Commit ae1186f

[SPARK-34581][SQL] Don't optimize out grouping expressions from aggregate expressions
1 parent 908318f commit ae1186f


3 files changed: +48 -3 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/grouping.scala

Lines changed: 18 additions & 1 deletion
@@ -18,7 +18,7 @@
 package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodegenFallback, ExprCode}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types._
 
@@ -155,3 +155,20 @@ object GroupingID {
     if (SQLConf.get.integerGroupingIdEnabled) IntegerType else LongType
   }
 }
+
+/**
+ * Wrapper expression to avoid further optimizations of child
+ */
+case class GroupingExpression(child: Expression) extends UnaryExpression {
+  override def eval(input: InternalRow): Any = {
+    child.eval(input)
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    child.genCode(ctx)
+  }
+
+  override def dataType: DataType = {
+    child.dataType
+  }
+}
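
The wrapper is intentionally transparent at runtime: eval and doGenCode simply delegate to the child, so wrapping changes the shape of the expression tree without changing its semantics. A minimal sketch of that property, assuming a Spark build that contains this patch on the classpath (the object name is made up for illustration):

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{GroupingExpression, IsNull, Literal}

// Illustration only: GroupingExpression behaves exactly like its child at runtime.
object GroupingExpressionSketch {
  def main(args: Array[String]): Unit = {
    val child   = IsNull(Literal(null))           // a "complex" grouping expression
    val wrapped = GroupingExpression(child)       // same value, different tree shape

    val row = InternalRow.empty
    assert(wrapped.eval(row) == child.eval(row))  // both evaluate to true
    assert(wrapped.dataType == child.dataType)    // the child's data type is passed through
  }
}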

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 13 additions & 2 deletions
@@ -870,8 +870,19 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper {
       if (haveCommonNonDeterministicOutput(p.projectList, agg.aggregateExpressions)) {
         p
       } else {
-        agg.copy(aggregateExpressions = buildCleanedProjectList(
-          p.projectList, agg.aggregateExpressions))
+        val complexGroupingExpressions =
+          ExpressionSet(agg.groupingExpressions.filter(_.children.nonEmpty))
+
+        def wrapGroupingExpression(e: Expression): Expression = e match {
+          case _: AggregateExpression => e
+          case _ if complexGroupingExpressions.contains(e) => GroupingExpression(e)
+          case _ => e.mapChildren(wrapGroupingExpression)
+        }
+
+        val wrappedAggregateExpressions =
+          agg.aggregateExpressions.map(wrapGroupingExpression(_).asInstanceOf[NamedExpression])
+        agg.copy(aggregateExpressions =
+          buildCleanedProjectList(p.projectList, wrappedAggregateExpressions))
       }
     case Project(l1, g @ GlobalLimit(_, limit @ LocalLimit(_, p2 @ Project(l2, _))))
         if isRenaming(l1, l2) =>
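
With this change, before CollapseProject merges a parent Project into an Aggregate, every occurrence of a complex grouping expression inside the aggregate expressions is wrapped first, so later rules that match on the grouping expression's own shape (for example, one that would rewrite NOT(t.id IS NULL) into t.id IS NOT NULL) no longer see it directly and cannot turn it into something that no longer matches the grouping expression. A rough sketch of the wrapping step in isolation, assuming a build with this patch; the object name and the simplified wrap helper (which omits the AggregateExpression short-circuit of the real rule) are illustrative:

import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ExpressionSet, GroupingExpression, IsNull, Not}
import org.apache.spark.sql.types.IntegerType

// Illustration only: mirrors wrapGroupingExpression on a hand-built expression.
object WrapGroupingSketch {
  def main(args: Array[String]): Unit = {
    val id       = AttributeReference("id", IntegerType)()  // stands in for a column of t
    val grouping = IsNull(id)                                // the complex grouping expression
    val complex  = ExpressionSet(Seq(grouping))

    def wrap(e: Expression): Expression = e match {
      case _ if complex.contains(e) => GroupingExpression(e)
      case _ => e.mapChildren(wrap)
    }

    // NOT(isnull(id)) becomes NOT(GroupingExpression(isnull(id)))
    println(wrap(Not(grouping)))
  }
}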

sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala

Lines changed: 17 additions & 0 deletions
@@ -4116,6 +4116,23 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
       }
     }
   }
+
+  test("SPARK-34581: Don't optimize out grouping expressions from aggregate expressions") {
+    withTempView("t") {
+      Seq[Integer](null, 1, 2, 3, null).toDF("id").createOrReplaceTempView("t")
+
+      val df = spark.sql(
+        """
+          |SELECT not(id), c
+          |FROM (
+          |  SELECT t.id IS NULL AS id, count(*) AS c
+          |  FROM t
+          |  GROUP BY t.id IS NULL
+          |) t
+          |""".stripMargin)
+      checkAnswer(df, Row(true, 3) :: Row(false, 2) :: Nil)
+    }
+  }
 }
 
 case class Foo(bar: Option[String])
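
The test data has two NULL ids and three non-NULL ids, so grouping on t.id IS NULL produces counts 2 and 3, and after the outer not(id) the expected rows are (true, 3) and (false, 2), as asserted by checkAnswer above. A standalone sketch of the same query outside the test suite (the object name is hypothetical; spark-sql is assumed on the classpath):

import org.apache.spark.sql.SparkSession

// Illustrative reproduction of the query covered by the new test.
object Spark34581Repro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("SPARK-34581").getOrCreate()
    import spark.implicits._

    Seq[Integer](null, 1, 2, 3, null).toDF("id").createOrReplaceTempView("t")

    spark.sql(
      """
        |SELECT not(id), c
        |FROM (
        |  SELECT t.id IS NULL AS id, count(*) AS c
        |  FROM t
        |  GROUP BY t.id IS NULL
        |) t
        |""".stripMargin).show()  // expected: (true, 3) and (false, 2)

    spark.stop()
  }
}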
