Skip to content

Commit 46a227e

Browse files
committed
Fix
1 parent 3a299aa commit 46a227e

File tree

1 file changed

+53
-29
lines changed
  • sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer

1 file changed

+53
-29
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala

Lines changed: 53 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -338,20 +338,15 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
338338
object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
339339
/**
340340
* Extract all correlated scalar subqueries from an expression. The subqueries are collected using
341-
* the given collector. To avoid the reuse of `exprId`s, this method generates new `exprId`
342-
* for the subqueries and rewrite references in the given `expression`.
343-
* This method returns extracted subqueries and the corresponding `exprId`s and these values
344-
* will be used later in `constructLeftJoins` for building the child plan that
345-
* returns subquery output with the `exprId`s.
341+
* the given collector. The expression is rewritten and returned.
346342
*/
347343
private def extractCorrelatedScalarSubqueries[E <: Expression](
348344
expression: E,
349-
subqueries: ArrayBuffer[(ScalarSubquery, ExprId)]): E = {
345+
subqueries: ArrayBuffer[ScalarSubquery]): E = {
350346
val newExpression = expression transform {
351347
case s: ScalarSubquery if s.children.nonEmpty =>
352-
val newExprId = NamedExpression.newExprId
353-
subqueries += s -> newExprId
354-
s.plan.output.head.withExprId(newExprId)
348+
subqueries += s
349+
s.plan.output.head
355350
}
356351
newExpression.asInstanceOf[E]
357352
}
@@ -512,19 +507,23 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
512507

513508
/**
514509
* Construct a new child plan by left joining the given subqueries to a base plan.
510+
* This method returns the child plan and an attribute mapping
511+
* for the updated `ExprId`s of subqueries. If a non-empty mapping is returned,
512+
* this rule will rewrite subquery references in a parent plan based on it.
515513
*/
516514
private def constructLeftJoins(
517515
child: LogicalPlan,
518-
subqueries: ArrayBuffer[(ScalarSubquery, ExprId)]): LogicalPlan = {
519-
subqueries.foldLeft(child) {
520-
case (currentChild, (ScalarSubquery(query, conditions, _), newExprId)) =>
516+
subqueries: ArrayBuffer[ScalarSubquery]): (LogicalPlan, AttributeMap[Attribute]) = {
517+
val subqueryAttrMapping = ArrayBuffer[(Attribute, Attribute)]()
518+
val newChild = subqueries.foldLeft(child) {
519+
case (currentChild, ScalarSubquery(query, conditions, _)) =>
521520
val origOutput = query.output.head
522521

523522
val resultWithZeroTups = evalSubqueryOnZeroTups(query)
524523
if (resultWithZeroTups.isEmpty) {
525524
// CASE 1: Subquery guaranteed not to have the COUNT bug
526525
Project(
527-
currentChild.output :+ Alias(origOutput, origOutput.name)(exprId = newExprId),
526+
currentChild.output :+ origOutput,
528527
Join(currentChild, query, LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
529528
} else {
530529
// Subquery might have the COUNT bug. Add appropriate corrections.
@@ -544,12 +543,13 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
544543

545544
if (havingNode.isEmpty) {
546545
// CASE 2: Subquery with no HAVING clause
546+
val subqueryResultExpr =
547+
Alias(If(IsNull(alwaysTrueRef),
548+
resultWithZeroTups.get,
549+
aggValRef), origOutput.name)()
550+
subqueryAttrMapping += ((origOutput, subqueryResultExpr.toAttribute))
547551
Project(
548-
currentChild.output :+
549-
Alias(
550-
If(IsNull(alwaysTrueRef),
551-
resultWithZeroTups.get,
552-
aggValRef), origOutput.name)(exprId = newExprId),
552+
currentChild.output :+ subqueryResultExpr,
553553
Join(currentChild,
554554
Project(query.output :+ alwaysTrueExpr, query),
555555
LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
@@ -576,7 +576,9 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
576576
(IsNull(alwaysTrueRef), resultWithZeroTups.get),
577577
(Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))),
578578
aggValRef),
579-
origOutput.name)(exprId = newExprId)
579+
origOutput.name)()
580+
581+
subqueryAttrMapping += ((origOutput, caseExpr.toAttribute))
580582

581583
Project(
582584
currentChild.output :+ caseExpr,
@@ -587,6 +589,22 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
587589
}
588590
}
589591
}
592+
(newChild, AttributeMap(subqueryAttrMapping))
593+
}
594+
595+
private def updateAttrs[E <: Expression](
596+
exprs: Seq[E],
597+
attrMap: AttributeMap[Attribute]): Seq[E] = {
598+
if (attrMap.nonEmpty) {
599+
val newExprs = exprs.map { _.transform {
600+
case a: AttributeReference if attrMap.contains(a) =>
601+
val exprId = attrMap.getOrElse(a, a).exprId
602+
a.withExprId(exprId)
603+
}}
604+
newExprs.asInstanceOf[Seq[E]]
605+
} else {
606+
exprs
607+
}
590608
}
591609

592610
/**
@@ -595,36 +613,42 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
595613
*/
596614
def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput {
597615
case a @ Aggregate(grouping, expressions, child) =>
598-
val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)]
599-
val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
616+
val subqueries = ArrayBuffer.empty[ScalarSubquery]
617+
val rewriteExprs = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
600618
if (subqueries.nonEmpty) {
601619
// We currently only allow correlated subqueries in an aggregate if they are part of the
602620
// grouping expressions. As a result we need to replace all the scalar subqueries in the
603621
// grouping expressions by their result.
604622
val newGrouping = grouping.map { e =>
605-
subqueries.find(_._1.semanticEquals(e)).map(_._1.plan.output.head).getOrElse(e)
623+
subqueries.find(_.semanticEquals(e)).map(_.plan.output.head).getOrElse(e)
606624
}
607-
val newAgg = Aggregate(newGrouping, newExpressions, constructLeftJoins(child, subqueries))
625+
val (newChild, subqueryAttrMapping) = constructLeftJoins(child, subqueries)
626+
val newExprs = updateAttrs(rewriteExprs, subqueryAttrMapping)
627+
val newAgg = Aggregate(newGrouping, newExprs, newChild)
608628
val attrMapping = a.output.zip(newAgg.output)
609629
newAgg -> attrMapping
610630
} else {
611631
a -> Nil
612632
}
613633
case p @ Project(expressions, child) =>
614-
val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)]
615-
val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
634+
val subqueries = ArrayBuffer.empty[ScalarSubquery]
635+
val rewriteExprs = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
616636
if (subqueries.nonEmpty) {
617-
val newProj = Project(newExpressions, constructLeftJoins(child, subqueries))
637+
val (newChild, subqueryAttrMapping) = constructLeftJoins(child, subqueries)
638+
val newExprs = updateAttrs(rewriteExprs, subqueryAttrMapping)
639+
val newProj = Project(newExprs, newChild)
618640
val attrMapping = p.output.zip(newProj.output)
619641
newProj -> attrMapping
620642
} else {
621643
p -> Nil
622644
}
623645
case f @ Filter(condition, child) =>
624-
val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)]
625-
val newCondition = extractCorrelatedScalarSubqueries(condition, subqueries)
646+
val subqueries = ArrayBuffer.empty[ScalarSubquery]
647+
val rewriteCondition = extractCorrelatedScalarSubqueries(condition, subqueries)
626648
if (subqueries.nonEmpty) {
627-
val newProj = Project(f.output, Filter(newCondition, constructLeftJoins(child, subqueries)))
649+
val (newChild, subqueryAttrMapping) = constructLeftJoins(child, subqueries)
650+
val newCondition = updateAttrs(Seq(rewriteCondition), subqueryAttrMapping).head
651+
val newProj = Project(f.output, Filter(newCondition, newChild))
628652
val attrMapping = f.output.zip(newProj.output)
629653
newProj -> attrMapping
630654
} else {

0 commit comments

Comments
 (0)