@@ -338,20 +338,15 @@ object PullupCorrelatedPredicates extends Rule[LogicalPlan] with PredicateHelper
338
338
object RewriteCorrelatedScalarSubquery extends Rule [LogicalPlan ] {
339
339
/**
340
340
* Extract all correlated scalar subqueries from an expression. The subqueries are collected using
341
- * the given collector. To avoid the reuse of `exprId`s, this method generates new `exprId`
342
- * for the subqueries and rewrite references in the given `expression`.
343
- * This method returns extracted subqueries and the corresponding `exprId`s and these values
344
- * will be used later in `constructLeftJoins` for building the child plan that
345
- * returns subquery output with the `exprId`s.
341
+ * the given collector. The expression is rewritten and returned.
346
342
*/
347
343
private def extractCorrelatedScalarSubqueries [E <: Expression ](
348
344
expression : E ,
349
- subqueries : ArrayBuffer [( ScalarSubquery , ExprId ) ]): E = {
345
+ subqueries : ArrayBuffer [ScalarSubquery ]): E = {
350
346
val newExpression = expression transform {
351
347
case s : ScalarSubquery if s.children.nonEmpty =>
352
- val newExprId = NamedExpression .newExprId
353
- subqueries += s -> newExprId
354
- s.plan.output.head.withExprId(newExprId)
348
+ subqueries += s
349
+ s.plan.output.head
355
350
}
356
351
newExpression.asInstanceOf [E ]
357
352
}
@@ -512,19 +507,23 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
512
507
513
508
/**
514
509
* Construct a new child plan by left joining the given subqueries to a base plan.
510
+ * This method returns the child plan and an attribute mapping
511
+ * for the updated `ExprId`s of subqueries. If the non-empty mapping returned,
512
+ * this rule will rewrite subquery references in a parent plan based on it.
515
513
*/
516
514
private def constructLeftJoins (
517
515
child : LogicalPlan ,
518
- subqueries : ArrayBuffer [(ScalarSubquery , ExprId )]): LogicalPlan = {
519
- subqueries.foldLeft(child) {
520
- case (currentChild, (ScalarSubquery (query, conditions, _), newExprId)) =>
516
+ subqueries : ArrayBuffer [ScalarSubquery ]): (LogicalPlan , AttributeMap [Attribute ]) = {
517
+ val subqueryAttrMapping = ArrayBuffer [(Attribute , Attribute )]()
518
+ val newChild = subqueries.foldLeft(child) {
519
+ case (currentChild, ScalarSubquery (query, conditions, _)) =>
521
520
val origOutput = query.output.head
522
521
523
522
val resultWithZeroTups = evalSubqueryOnZeroTups(query)
524
523
if (resultWithZeroTups.isEmpty) {
525
524
// CASE 1: Subquery guaranteed not to have the COUNT bug
526
525
Project (
527
- currentChild.output :+ Alias ( origOutput, origOutput.name)(exprId = newExprId) ,
526
+ currentChild.output :+ origOutput,
528
527
Join (currentChild, query, LeftOuter , conditions.reduceOption(And ), JoinHint .NONE ))
529
528
} else {
530
529
// Subquery might have the COUNT bug. Add appropriate corrections.
@@ -544,12 +543,13 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
544
543
545
544
if (havingNode.isEmpty) {
546
545
// CASE 2: Subquery with no HAVING clause
546
+ val subqueryResultExpr =
547
+ Alias (If (IsNull (alwaysTrueRef),
548
+ resultWithZeroTups.get,
549
+ aggValRef), origOutput.name)()
550
+ subqueryAttrMapping += ((origOutput, subqueryResultExpr.toAttribute))
547
551
Project (
548
- currentChild.output :+
549
- Alias (
550
- If (IsNull (alwaysTrueRef),
551
- resultWithZeroTups.get,
552
- aggValRef), origOutput.name)(exprId = newExprId),
552
+ currentChild.output :+ subqueryResultExpr,
553
553
Join (currentChild,
554
554
Project (query.output :+ alwaysTrueExpr, query),
555
555
LeftOuter , conditions.reduceOption(And ), JoinHint .NONE ))
@@ -576,7 +576,9 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
576
576
(IsNull (alwaysTrueRef), resultWithZeroTups.get),
577
577
(Not (havingNode.get.condition), Literal .create(null , aggValRef.dataType))),
578
578
aggValRef),
579
- origOutput.name)(exprId = newExprId)
579
+ origOutput.name)()
580
+
581
+ subqueryAttrMapping += ((origOutput, caseExpr.toAttribute))
580
582
581
583
Project (
582
584
currentChild.output :+ caseExpr,
@@ -587,6 +589,22 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
587
589
}
588
590
}
589
591
}
592
+ (newChild, AttributeMap (subqueryAttrMapping))
593
+ }
594
+
595
+ private def updateAttrs [E <: Expression ](
596
+ exprs : Seq [E ],
597
+ attrMap : AttributeMap [Attribute ]): Seq [E ] = {
598
+ if (attrMap.nonEmpty) {
599
+ val newExprs = exprs.map { _.transform {
600
+ case a : AttributeReference if attrMap.contains(a) =>
601
+ val exprId = attrMap.getOrElse(a, a).exprId
602
+ a.withExprId(exprId)
603
+ }}
604
+ newExprs.asInstanceOf [Seq [E ]]
605
+ } else {
606
+ exprs
607
+ }
590
608
}
591
609
592
610
/**
@@ -595,36 +613,42 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
595
613
*/
596
614
def apply (plan : LogicalPlan ): LogicalPlan = plan transformUpWithNewOutput {
597
615
case a @ Aggregate (grouping, expressions, child) =>
598
- val subqueries = ArrayBuffer .empty[( ScalarSubquery , ExprId ) ]
599
- val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
616
+ val subqueries = ArrayBuffer .empty[ScalarSubquery ]
617
+ val rewriteExprs = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
600
618
if (subqueries.nonEmpty) {
601
619
// We currently only allow correlated subqueries in an aggregate if they are part of the
602
620
// grouping expressions. As a result we need to replace all the scalar subqueries in the
603
621
// grouping expressions by their result.
604
622
val newGrouping = grouping.map { e =>
605
- subqueries.find(_._1. semanticEquals(e)).map(_._1 .plan.output.head).getOrElse(e)
623
+ subqueries.find(_.semanticEquals(e)).map(_.plan.output.head).getOrElse(e)
606
624
}
607
- val newAgg = Aggregate (newGrouping, newExpressions, constructLeftJoins(child, subqueries))
625
+ val (newChild, subqueryAttrMapping) = constructLeftJoins(child, subqueries)
626
+ val newExprs = updateAttrs(rewriteExprs, subqueryAttrMapping)
627
+ val newAgg = Aggregate (newGrouping, newExprs, newChild)
608
628
val attrMapping = a.output.zip(newAgg.output)
609
629
newAgg -> attrMapping
610
630
} else {
611
631
a -> Nil
612
632
}
613
633
case p @ Project (expressions, child) =>
614
- val subqueries = ArrayBuffer .empty[( ScalarSubquery , ExprId ) ]
615
- val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
634
+ val subqueries = ArrayBuffer .empty[ScalarSubquery ]
635
+ val rewriteExprs = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
616
636
if (subqueries.nonEmpty) {
617
- val newProj = Project (newExpressions, constructLeftJoins(child, subqueries))
637
+ val (newChild, subqueryAttrMapping) = constructLeftJoins(child, subqueries)
638
+ val newExprs = updateAttrs(rewriteExprs, subqueryAttrMapping)
639
+ val newProj = Project (newExprs, newChild)
618
640
val attrMapping = p.output.zip(newProj.output)
619
641
newProj -> attrMapping
620
642
} else {
621
643
p -> Nil
622
644
}
623
645
case f @ Filter (condition, child) =>
624
- val subqueries = ArrayBuffer .empty[( ScalarSubquery , ExprId ) ]
625
- val newCondition = extractCorrelatedScalarSubqueries(condition, subqueries)
646
+ val subqueries = ArrayBuffer .empty[ScalarSubquery ]
647
+ val rewriteCondition = extractCorrelatedScalarSubqueries(condition, subqueries)
626
648
if (subqueries.nonEmpty) {
627
- val newProj = Project (f.output, Filter (newCondition, constructLeftJoins(child, subqueries)))
649
+ val (newChild, subqueryAttrMapping) = constructLeftJoins(child, subqueries)
650
+ val newCondition = updateAttrs(Seq (rewriteCondition), subqueryAttrMapping).head
651
+ val newProj = Project (f.output, Filter (newCondition, newChild))
628
652
val attrMapping = f.output.zip(newProj.output)
629
653
newProj -> attrMapping
630
654
} else {
0 commit comments