
Commit eb128b0

bersprockets authored and cloud-fan committed
[SPARK-50091][SQL] Handle case of aggregates in left-hand operand of IN-subquery
### What changes were proposed in this pull request?

This PR adds code to `RewritePredicateSubquery#apply` to explicitly handle the case where an `Aggregate` node contains an aggregate expression in the left-hand operand of an IN-subquery expression. The explicit handler moves the IN-subquery expressions out of the `Aggregate` and into a parent `Project` node. The `Aggregate` continues to perform the aggregations that were used as operands to the IN-subquery expression, but no longer includes the IN-subquery expression itself.

After pulling up IN-subquery expressions into a `Project` node, `RewritePredicateSubquery#apply` is called again to handle the `Project` as a `UnaryNode`. The `Join` is now inserted between the `Project` and the `Aggregate` node, and the join condition uses an attribute rather than an aggregate expression, e.g.:

```
Project [col1#32, exists#42 AS (sum(col2) IN (listquery()))#40]
+- Join ExistenceJoin(exists#42), (sum(col2)#41L = c2#39L)
   :- Aggregate [col1#32], [col1#32, sum(col2#33) AS sum(col2)#41L]
   :  +- LocalRelation [col1#32, col2#33]
   +- LocalRelation [c2#39L]
```

`sum(col2)#41L` in the above join condition, despite how it looks, is the name of an attribute, not an aggregate expression.

### Why are the changes needed?

The following query fails:

```
create or replace temp view v1(c1, c2) as values (1, 2), (1, 3), (2, 2), (3, 7), (3, 1);
create or replace temp view v2(col1, col2) as values (1, 2), (1, 3), (2, 2), (3, 7), (3, 1);

select col1, sum(col2) in (select c2 from v1)
from v2
group by col1;
```

It fails with this error:

```
[INTERNAL_ERROR] Cannot generate code for expression: sum(input[1, int, false]) SQLSTATE: XX000
```

With SPARK_TESTING=1, it fails with this error:

```
[PLAN_VALIDATION_FAILED_RULE_IN_BATCH] Rule org.apache.spark.sql.catalyst.optimizer.RewritePredicateSubquery in batch RewriteSubquery generated an invalid plan:
Special expressions are placed in the wrong plan:
Aggregate [col1#11], [col1#11, first(exists#20, false) AS (sum(col2) IN (listquery()))#19]
+- Join ExistenceJoin(exists#20), (sum(col2#12) = c2#18L)
   :- LocalRelation [col1#11, col2#12]
   +- LocalRelation [c2#18L]
```

The issue is that `RewritePredicateSubquery` builds a `Join` operator whose join condition contains an aggregate expression. The bug is in the handler for `UnaryNode` in `RewritePredicateSubquery#apply`, which adds a `Join` below the `Aggregate` and assumes that the left-hand operand of the IN-subquery can be used in the join condition. This works fine for most cases, but not when the left-hand operand is an aggregate expression.

This PR moves the offending IN-subqueries to a `Project` node, with the aggregates replaced by attributes referring to the aggregate expressions. The resulting join condition then uses those attributes rather than the actual aggregate expressions.

### Does this PR introduce _any_ user-facing change?

No, other than allowing this type of query to succeed.

### How was this patch tested?

New unit tests.

### Was this patch authored or co-authored using generative AI tooling?

No.

Closes #48627 from bersprockets/aggregate_in_set_issue.

Authored-by: Bruce Robbins <bersprockets@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
(cherry picked from commit e02ff1c)
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
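For reference, the repro above can be exercised end-to-end in spark-shell. A minimal sketch (assuming a build that includes this patch; the views and query are taken verbatim from the commit message):

```scala
// Recreate the views from the repro, then run the previously failing query.
spark.sql("create or replace temp view v1(c1, c2) as values (1, 2), (1, 3), (2, 2), (3, 7), (3, 1)")
spark.sql("create or replace temp view v2(col1, col2) as values (1, 2), (1, 3), (2, 2), (3, 7), (3, 1)")

val df = spark.sql(
  "select col1, sum(col2) in (select c2 from v1) from v2 group by col1")

// Without the patch this fails at codegen with [INTERNAL_ERROR]; with it, the
// optimized plan shows an ExistenceJoin whose condition compares attributes.
df.explain(true)
df.show()
```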
1 parent 696a541 commit eb128b0

File tree

3 files changed: +168 −40 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala

Lines changed: 120 additions & 39 deletions

```diff
@@ -27,6 +27,7 @@ import org.apache.spark.sql.catalyst.expressions.ScalarSubquery._
 import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
 import org.apache.spark.sql.catalyst.expressions.aggregate._
 import org.apache.spark.sql.catalyst.optimizer.RewriteCorrelatedScalarSubquery.splitSubquery
+import org.apache.spark.sql.catalyst.planning.PhysicalAggregation
 import org.apache.spark.sql.catalyst.plans._
 import org.apache.spark.sql.catalyst.plans.logical._
 import org.apache.spark.sql.catalyst.rules._
@@ -115,6 +116,26 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
     }
   }
 
+  def exprsContainsAggregateInSubquery(exprs: Seq[Expression]): Boolean = {
+    exprs.exists { expr =>
+      exprContainsAggregateInSubquery(expr)
+    }
+  }
+
+  def exprContainsAggregateInSubquery(expr: Expression): Boolean = {
+    expr.exists {
+      case InSubquery(values, _) =>
+        values.exists { v =>
+          v.exists {
+            case _: AggregateExpression => true
+            case _ => false
+          }
+        }
+      case _ => false
+    }
+  }
+
+
   def apply(plan: LogicalPlan): LogicalPlan = plan.transformWithPruning(
     _.containsAnyPattern(EXISTS_SUBQUERY, LIST_SUBQUERY)) {
     case Filter(condition, child)
@@ -246,46 +267,106 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
       }
     }
 
+    // Handle the case where the left-hand side of an IN-subquery contains an aggregate.
+    //
+    // If an Aggregate node contains such an IN-subquery, this handler will pull up all
+    // expressions from the Aggregate node into a new Project node. The new Project node
+    // will then be handled by the Unary node handler.
+    //
+    // The Unary node handler uses the left-hand side of the IN-subquery in a
+    // join condition. Thus, without this pre-transformation, the join condition
+    // contains an aggregate, which is illegal. With this pre-transformation, the
+    // join condition contains an attribute from the left-hand side of the
+    // IN-subquery contained in the Project node.
+    //
+    // For example:
+    //
+    //   SELECT SUM(col2) IN (SELECT c3 FROM v1) AND SUM(col3) > -1 AS x
+    //   FROM v2;
+    //
+    // The above query has this plan on entry to RewritePredicateSubquery#apply:
+    //
+    //   Aggregate [(sum(col2#18) IN (list#12 []) AND (sum(col3#19) > -1)) AS x#13]
+    //   :  +- LocalRelation [c3#28L]
+    //   +- LocalRelation [col2#18, col3#19]
+    //
+    // Note that the Aggregate node contains the IN-subquery and the left-hand
+    // side of the IN-subquery is an aggregate expression (sum(col2#18)).
+    //
+    // This handler transforms the above plan into the following:
+    // scalastyle:off line.size.limit
+    //
+    //   Project [(_aggregateexpression#20L IN (list#12 []) AND (_aggregateexpression#21L > -1)) AS x#13]
+    //   :  +- LocalRelation [c3#28L]
+    //   +- Aggregate [sum(col2#18) AS _aggregateexpression#20L, sum(col3#19) AS _aggregateexpression#21L]
+    //      +- LocalRelation [col2#18, col3#19]
+    //
+    // scalastyle:on
+    // Note that both the IN-subquery and the greater-than expressions have been
+    // pulled up into the Project node. These expressions use attributes
+    // (_aggregateexpression#20L and _aggregateexpression#21L) to refer to the aggregations
+    // which are still performed in the Aggregate node (sum(col2#18) and sum(col3#19)).
+    case p @ PhysicalAggregation(
+        groupingExpressions, aggregateExpressions, resultExpressions, child)
+        if exprsContainsAggregateInSubquery(p.expressions) =>
+      val aggExprs = aggregateExpressions.map(
+        ae => Alias(ae, "_aggregateexpression")(ae.resultId))
+      val aggExprIds = aggExprs.map(_.exprId).toSet
+      val resExprs = resultExpressions.map(_.transform {
+        case a: AttributeReference if aggExprIds.contains(a.exprId) =>
+          a.withName("_aggregateexpression")
+      }.asInstanceOf[NamedExpression])
+      // Rewrite the projection and the aggregate separately and then piece them together.
+      val newAgg = Aggregate(groupingExpressions, groupingExpressions ++ aggExprs, child)
+      val newProj = Project(resExprs, newAgg)
+      handleUnaryNode(newProj)
+
     case u: UnaryNode if u.expressions.exists(
-        SubqueryExpression.hasInOrCorrelatedExistsSubquery) =>
-      var newChild = u.child
-      var introducedAttrs = Seq.empty[Attribute]
-      val updatedNode = u.mapExpressions(expr => {
-        val (newExpr, p, newAttrs) = rewriteExistentialExprWithAttrs(Seq(expr), newChild)
-        newChild = p
-        introducedAttrs ++= newAttrs
-        // The newExpr can not be None
-        newExpr.get
-      }).withNewChildren(Seq(newChild))
-      updatedNode match {
-        case a: Aggregate if conf.getConf(WRAP_EXISTS_IN_AGGREGATE_FUNCTION) =>
-          // If we have introduced new `exists`-attributes that are referenced by
-          // aggregateExpressions within a non-aggregateFunction expression, we wrap them in
-          // first() aggregate function. first() is Spark's executable version of any_value()
-          // aggregate function.
-          // We do this to keep the aggregation valid, i.e avoid references outside of aggregate
-          // functions that are not in grouping expressions.
-          // Note that the same `exists` attr will never appear in groupingExpressions due to
-          // PullOutGroupingExpressions rule.
-          // Also note: the value of `exists` is functionally determined by grouping expressions,
-          // so applying any aggregate function is semantically safe.
-          val aggFunctionReferences = a.aggregateExpressions.
-            flatMap(extractAggregateExpressions).
-            flatMap(_.references).toSet
-          val nonAggFuncReferences =
-            a.aggregateExpressions.flatMap(_.references).filterNot(aggFunctionReferences.contains)
-          val toBeWrappedExistsAttrs = introducedAttrs.filter(nonAggFuncReferences.contains)
-
-          // Replace all eligible `exists` by `First(exists)` among aggregateExpressions.
-          val newAggregateExpressions = a.aggregateExpressions.map { aggExpr =>
-            aggExpr.transformUp {
-              case attr: Attribute if toBeWrappedExistsAttrs.contains(attr) =>
-                new First(attr).toAggregateExpression()
-            }.asInstanceOf[NamedExpression]
-          }
-          a.copy(aggregateExpressions = newAggregateExpressions)
-        case _ => updatedNode
-      }
+        SubqueryExpression.hasInOrCorrelatedExistsSubquery) => handleUnaryNode(u)
+  }
+
+  /**
+   * Handle the unary node case
+   */
+  private def handleUnaryNode(u: UnaryNode): LogicalPlan = {
+    var newChild = u.child
+    var introducedAttrs = Seq.empty[Attribute]
+    val updatedNode = u.mapExpressions(expr => {
+      val (newExpr, p, newAttrs) = rewriteExistentialExprWithAttrs(Seq(expr), newChild)
+      newChild = p
+      introducedAttrs ++= newAttrs
+      // The newExpr can not be None
+      newExpr.get
+    }).withNewChildren(Seq(newChild))
+    updatedNode match {
+      case a: Aggregate if conf.getConf(WRAP_EXISTS_IN_AGGREGATE_FUNCTION) =>
+        // If we have introduced new `exists`-attributes that are referenced by
+        // aggregateExpressions within a non-aggregateFunction expression, we wrap them in
+        // first() aggregate function. first() is Spark's executable version of any_value()
+        // aggregate function.
+        // We do this to keep the aggregation valid, i.e avoid references outside of aggregate
+        // functions that are not in grouping expressions.
+        // Note that the same `exists` attr will never appear in groupingExpressions due to
+        // PullOutGroupingExpressions rule.
+        // Also note: the value of `exists` is functionally determined by grouping expressions,
+        // so applying any aggregate function is semantically safe.
+        val aggFunctionReferences = a.aggregateExpressions.
+          flatMap(extractAggregateExpressions).
+          flatMap(_.references).toSet
+        val nonAggFuncReferences =
+          a.aggregateExpressions.flatMap(_.references).filterNot(aggFunctionReferences.contains)
+        val toBeWrappedExistsAttrs = introducedAttrs.filter(nonAggFuncReferences.contains)
+
+        // Replace all eligible `exists` by `First(exists)` among aggregateExpressions.
+        val newAggregateExpressions = a.aggregateExpressions.map { aggExpr =>
+          aggExpr.transformUp {
+            case attr: Attribute if toBeWrappedExistsAttrs.contains(attr) =>
+              new First(attr).toAggregateExpression()
+          }.asInstanceOf[NamedExpression]
+        }
+        a.copy(aggregateExpressions = newAggregateExpressions)
+      case _ => updatedNode
+    }
   }
 
   /**
```
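The two detection helpers added above are a short recursive walk over the expression tree. As a rough illustration of the same idea, here is a standalone sketch against a toy expression ADT (the types below are hypothetical stand-ins, not the real Catalyst `Expression`, `InSubquery`, or `AggregateExpression` classes):

```scala
sealed trait Expr {
  def children: Seq[Expr]
  // Catalyst's Expression.exists walks the tree in the same pre-order fashion.
  def exists(p: Expr => Boolean): Boolean = p(this) || children.exists(_.exists(p))
}
case class Attr(name: String) extends Expr { val children: Seq[Expr] = Nil }
case class Sum(child: Expr) extends Expr { val children: Seq[Expr] = Seq(child) }
case class InSubq(values: Seq[Expr]) extends Expr { val children: Seq[Expr] = values }
case class And(left: Expr, right: Expr) extends Expr { val children: Seq[Expr] = Seq(left, right) }

// Same shape as exprContainsAggregateInSubquery: true iff some IN-subquery's
// left-hand values contain an aggregate (modeled here by Sum) beneath them.
def containsAggInSubquery(expr: Expr): Boolean = expr.exists {
  case InSubq(values) =>
    values.exists(_.exists { case _: Sum => true; case _ => false })
  case _ => false
}

// sum(col2) IN (...) AND col3  =>  detected
assert(containsAggInSubquery(And(InSubq(Seq(Sum(Attr("col2")))), Attr("col3"))))
// col2 IN (...)  =>  not detected: no aggregate in the left-hand operand
assert(!containsAggInSubquery(InSubq(Seq(Attr("col2")))))
```

The real helpers short-circuit the same way: the guard on the new `PhysicalAggregation` case fires only when at least one IN-subquery has an aggregate somewhere in its left-hand operand.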

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteSubquerySuite.scala

Lines changed: 18 additions & 1 deletion

```diff
@@ -20,10 +20,11 @@ package org.apache.spark.sql.catalyst.optimizer
 import org.apache.spark.sql.catalyst.QueryPlanningTracker
 import org.apache.spark.sql.catalyst.dsl.expressions._
 import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{IsNull, ListQuery, Not}
+import org.apache.spark.sql.catalyst.expressions.{Cast, IsNull, ListQuery, Not}
 import org.apache.spark.sql.catalyst.plans.{ExistenceJoin, LeftSemi, PlanTest}
 import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan}
 import org.apache.spark.sql.catalyst.rules.RuleExecutor
+import org.apache.spark.sql.types.LongType
 
 
 class RewriteSubquerySuite extends PlanTest {
@@ -79,4 +80,20 @@ class RewriteSubquerySuite extends PlanTest {
     Optimize.executeAndTrack(query.analyze, tracker)
     assert(tracker.rules(RewritePredicateSubquery.ruleName).numEffectiveInvocations == 0)
   }
+
+  test("SPARK-50091: Don't put aggregate expression in join condition") {
+    val relation1 = LocalRelation($"c1".int, $"c2".int, $"c3".int)
+    val relation2 = LocalRelation($"col1".int, $"col2".int, $"col3".int)
+    val plan = relation2.groupBy()(sum($"col2").in(ListQuery(relation1.select($"c3"))))
+    val optimized = Optimize.execute(plan.analyze)
+    val aggregate = relation2
+      .select($"col2")
+      .groupBy()(sum($"col2").as("_aggregateexpression"))
+    val correctAnswer = aggregate
+      .join(relation1.select(Cast($"c3", LongType).as("c3")),
+        ExistenceJoin($"exists".boolean.withNullability(false)),
+        Some($"_aggregateexpression" === $"c3"))
+      .select($"exists".as("(sum(col2) IN (listquery()))")).analyze
+    comparePlans(optimized, correctAnswer)
+  }
 }
```
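One detail in the expected plan above: `sum` over an `int` column produces a `LongType` result, which is why the expected join input casts `c3` to long before the equality. A quick way to confirm the widening in spark-shell (a sketch; `spark.implicits._` is pre-imported there):

```scala
// sum over IntegerType widens to LongType, hence Cast($"c3", LongType)
// in the expected plan of the test above.
val summed = Seq(1, 2, 3).toDF("col2").selectExpr("sum(col2) AS s")
summed.printSchema()  // prints: |-- s: long (nullable = true)
```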

sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala

Lines changed: 30 additions & 0 deletions

```diff
@@ -2800,4 +2800,34 @@ class SubquerySuite extends QueryTest
       checkAnswer(df3, Row(7))
     }
   }
+
+  test("SPARK-50091: Handle aggregates in left-hand operand of IN-subquery") {
+    withView("v1", "v2") {
+      Seq((1, 2, 2), (1, 5, 3), (2, 0, 4), (3, 7, 7), (3, 8, 8))
+        .toDF("c1", "c2", "c3")
+        .createOrReplaceTempView("v1")
+      Seq((1, 2, 2), (1, 3, 3), (2, 2, 4), (3, 7, 7), (3, 1, 1))
+        .toDF("col1", "col2", "col3")
+        .createOrReplaceTempView("v2")
+
+      val df1 = sql("SELECT col1, SUM(col2) IN (SELECT c3 FROM v1) FROM v2 GROUP BY col1")
+      checkAnswer(df1,
+        Row(1, false) :: Row(2, true) :: Row(3, true) :: Nil)
+
+      val df2 = sql("""SELECT
+                      | col1,
+                      | SUM(col2) IN (SELECT c3 FROM v1) and SUM(col3) IN (SELECT c2 FROM v1) AS x
+                      |FROM v2 GROUP BY col1
+                      |ORDER BY col1""".stripMargin)
+      checkAnswer(df2,
+        Row(1, false) :: Row(2, false) :: Row(3, true) :: Nil)
+
+      val df3 = sql("""SELECT col1, (SUM(col2), SUM(col3)) IN (SELECT c3, c2 FROM v1) AS x
+                      |FROM v2
+                      |GROUP BY col1
+                      |ORDER BY col1""".stripMargin)
+      checkAnswer(df3,
+        Row(1, false) :: Row(2, false) :: Row(3, true) :: Nil)
+    }
+  }
 }
```
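As a sanity check on `df1`'s expected rows: grouping v2 by col1 gives col2 sums of 5, 2, and 8, while v1's c3 column holds {2, 3, 4, 7, 8}, so only groups 2 and 3 find their sum in the subquery. In plain Scala:

```scala
// Worked check of df1's expectation, using the literal data from the test.
val v1c3 = Set(2, 3, 4, 7, 8)                       // c3 values in v1
val sums = Map(1 -> (2 + 3), 2 -> 2, 3 -> (7 + 1))  // sum(col2) per col1 in v2
val expected = sums.map { case (g, s) => g -> v1c3.contains(s) }
// Map(1 -> false, 2 -> true, 3 -> true), matching
// Row(1, false) :: Row(2, true) :: Row(3, true) :: Nil
```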
