@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
20
20
import scala .collection .mutable .ArrayBuffer
21
21
22
22
import org .apache .spark .sql .AnalysisException
23
- import org .apache .spark .sql .catalyst .analysis .{ Analyzer , CleanupAliases }
23
+ import org .apache .spark .sql .catalyst .analysis .CleanupAliases
24
24
import org .apache .spark .sql .catalyst .expressions ._
25
25
import org .apache .spark .sql .catalyst .expressions .SubExprUtils ._
26
26
import org .apache .spark .sql .catalyst .expressions .aggregate ._
@@ -342,11 +342,12 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
342
342
*/
343
343
private def extractCorrelatedScalarSubqueries [E <: Expression ](
344
344
expression : E ,
345
- subqueries : ArrayBuffer [ScalarSubquery ]): E = {
345
+ subqueries : ArrayBuffer [( ScalarSubquery , ExprId ) ]): E = {
346
346
val newExpression = expression transform {
347
347
case s : ScalarSubquery if s.children.nonEmpty =>
348
- subqueries += s
349
- s.plan.output.head
348
+ val newExprId = NamedExpression .newExprId
349
+ subqueries += s -> newExprId
350
+ s.plan.output.head.withExprId(newExprId)
350
351
}
351
352
newExpression.asInstanceOf [E ]
352
353
}
@@ -513,17 +514,16 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
513
514
*/
514
515
private def constructLeftJoins (
515
516
child : LogicalPlan ,
516
- subqueries : ArrayBuffer [ScalarSubquery ]): (LogicalPlan , Seq [(LogicalPlan , LogicalPlan )]) = {
517
- val rewritePlanMap = ArrayBuffer [(LogicalPlan , LogicalPlan )]()
518
- val newPlan = subqueries.foldLeft(child) {
519
- case (currentChild, ScalarSubquery (query, conditions, _)) =>
517
+ subqueries : ArrayBuffer [(ScalarSubquery , ExprId )]): LogicalPlan = {
518
+ subqueries.foldLeft(child) {
519
+ case (currentChild, (ScalarSubquery (query, conditions, _), newExprId)) =>
520
520
val origOutput = query.output.head
521
521
522
522
val resultWithZeroTups = evalSubqueryOnZeroTups(query)
523
523
if (resultWithZeroTups.isEmpty) {
524
524
// CASE 1: Subquery guaranteed not to have the COUNT bug
525
525
Project (
526
- currentChild.output :+ origOutput,
526
+ currentChild.output :+ Alias ( origOutput, origOutput.name)(exprId = newExprId) ,
527
527
Join (currentChild, query, LeftOuter , conditions.reduceOption(And ), JoinHint .NONE ))
528
528
} else {
529
529
// Subquery might have the COUNT bug. Add appropriate corrections.
@@ -543,23 +543,16 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
543
543
544
544
if (havingNode.isEmpty) {
545
545
// CASE 2: Subquery with no HAVING clause
546
- val joinPlan = Join (currentChild,
547
- Project (query.output :+ alwaysTrueExpr, query),
548
- LeftOuter , conditions.reduceOption(And ), JoinHint .NONE )
549
-
550
- def buildPlan (exprId : ExprId ): LogicalPlan = {
551
- Project (
552
- currentChild.output :+
553
- Alias (
554
- If (IsNull (alwaysTrueRef),
555
- resultWithZeroTups.get,
556
- aggValRef), origOutput.name)(exprId),
557
- joinPlan)
558
- }
546
+ Project (
547
+ currentChild.output :+
548
+ Alias (
549
+ If (IsNull (alwaysTrueRef),
550
+ resultWithZeroTups.get,
551
+ aggValRef), origOutput.name)(exprId = newExprId),
552
+ Join (currentChild,
553
+ Project (query.output :+ alwaysTrueExpr, query),
554
+ LeftOuter , conditions.reduceOption(And ), JoinHint .NONE ))
559
555
560
- val newPlan = buildPlan(origOutput.exprId)
561
- rewritePlanMap += newPlan -> buildPlan(NamedExpression .newExprId)
562
- newPlan
563
556
} else {
564
557
// CASE 3: Subquery with HAVING clause. Pull the HAVING clause above the join.
565
558
// Need to modify any operators below the join to pass through all columns
@@ -575,85 +568,66 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
575
568
case op => sys.error(s " Unexpected operator $op in corelated subquery " )
576
569
}
577
570
578
- val joinPlan = Join (currentChild,
579
- Project (subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot),
580
- LeftOuter , conditions.reduceOption(And ), JoinHint .NONE )
581
-
582
- def buildPlan (exprId : ExprId ): LogicalPlan = {
583
- // CASE WHEN alwaysTrue IS NULL THEN resultOnZeroTups
584
- // WHEN NOT (original HAVING clause expr) THEN CAST(null AS <type of aggVal>)
585
- // ELSE (aggregate value) END AS (original column name)
586
- val caseExpr = Alias (CaseWhen (Seq (
587
- (IsNull (alwaysTrueRef), resultWithZeroTups.get),
588
- (Not (havingNode.get.condition), Literal .create(null , aggValRef.dataType))),
589
- aggValRef),
590
- origOutput.name)(exprId)
591
-
592
- Project (
593
- currentChild.output :+ caseExpr,
594
- joinPlan)
595
- }
571
+ // CASE WHEN alwaysTrue IS NULL THEN resultOnZeroTups
572
+ // WHEN NOT (original HAVING clause expr) THEN CAST(null AS <type of aggVal>)
573
+ // ELSE (aggregate value) END AS (original column name)
574
+ val caseExpr = Alias (CaseWhen (Seq (
575
+ (IsNull (alwaysTrueRef), resultWithZeroTups.get),
576
+ (Not (havingNode.get.condition), Literal .create(null , aggValRef.dataType))),
577
+ aggValRef),
578
+ origOutput.name)(exprId = newExprId)
579
+
580
+ Project (
581
+ currentChild.output :+ caseExpr,
582
+ Join (currentChild,
583
+ Project (subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot),
584
+ LeftOuter , conditions.reduceOption(And ), JoinHint .NONE ))
596
585
597
- val newPlan = buildPlan(origOutput.exprId)
598
- rewritePlanMap += newPlan -> buildPlan(NamedExpression .newExprId)
599
- newPlan
600
586
}
601
587
}
602
588
}
603
-
604
- (newPlan, rewritePlanMap)
605
589
}
606
590
607
591
/**
608
592
* Rewrite [[Filter ]], [[Project ]] and [[Aggregate ]] plans containing correlated scalar
609
593
* subqueries.
610
594
*/
611
- def apply (plan : LogicalPlan ): LogicalPlan = {
612
- val rewritePlanMap = ArrayBuffer [(LogicalPlan , LogicalPlan )]()
613
- val newPlan = plan transform {
614
- case a @ Aggregate (grouping, expressions, child) =>
615
- val subqueries = ArrayBuffer .empty[ScalarSubquery ]
616
- val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
617
- if (subqueries.nonEmpty) {
618
- // We currently only allow correlated subqueries in an aggregate if they are part of the
619
- // grouping expressions. As a result we need to replace all the scalar subqueries in the
620
- // grouping expressions by their result.
621
- val newGrouping = grouping.map { e =>
622
- subqueries.find(_.semanticEquals(e)).map(_.plan.output.head).getOrElse(e)
623
- }
624
- val (newChild, rewriteMap) = constructLeftJoins(child, subqueries)
625
- rewritePlanMap ++= rewriteMap
626
- Aggregate (newGrouping, newExpressions, newChild)
627
- } else {
628
- a
629
- }
630
- case p @ Project (expressions, child) =>
631
- val subqueries = ArrayBuffer .empty[ScalarSubquery ]
632
- val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
633
- if (subqueries.nonEmpty) {
634
- val (newChild, rewriteMap) = constructLeftJoins(child, subqueries)
635
- rewritePlanMap ++= rewriteMap
636
- Project (newExpressions, newChild)
637
- } else {
638
- p
595
+ def apply (plan : LogicalPlan ): LogicalPlan = plan transformUpWithNewOutput {
596
+ case a @ Aggregate (grouping, expressions, child) =>
597
+ val subqueries = ArrayBuffer .empty[(ScalarSubquery , ExprId )]
598
+ val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
599
+ if (subqueries.nonEmpty) {
600
+ // We currently only allow correlated subqueries in an aggregate if they are part of the
601
+ // grouping expressions. As a result we need to replace all the scalar subqueries in the
602
+ // grouping expressions by their result.
603
+ val newGrouping = grouping.map { e =>
604
+ subqueries.find(_._1.semanticEquals(e)).map(_._1.plan.output.head).getOrElse(e)
639
605
}
640
- case f @ Filter (condition, child) =>
641
- val subqueries = ArrayBuffer .empty[ScalarSubquery ]
642
- val newCondition = extractCorrelatedScalarSubqueries(condition, subqueries)
643
- if (subqueries.nonEmpty) {
644
- val (newChild, rewriteMap) = constructLeftJoins(child, subqueries)
645
- rewritePlanMap ++= rewriteMap
646
- Project (f.output, Filter (newCondition, newChild))
647
- } else {
648
- f
649
- }
650
- }
651
-
652
- if (rewritePlanMap.nonEmpty) {
653
- assert(! plan.fastEquals(newPlan))
654
- Analyzer .rewritePlan(newPlan, rewritePlanMap.toMap)._1
655
- } else {
656
- newPlan
657
- }
606
+ val newAgg = Aggregate (newGrouping, newExpressions, constructLeftJoins(child, subqueries))
607
+ val attrMapping = a.output.zip(newAgg.output)
608
+ newAgg -> attrMapping
609
+ } else {
610
+ a -> Nil
611
+ }
612
+ case p @ Project (expressions, child) =>
613
+ val subqueries = ArrayBuffer .empty[(ScalarSubquery , ExprId )]
614
+ val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
615
+ if (subqueries.nonEmpty) {
616
+ val newProj = Project (newExpressions, constructLeftJoins(child, subqueries))
617
+ val attrMapping = p.output.zip(newProj.output)
618
+ newProj -> attrMapping
619
+ } else {
620
+ p -> Nil
621
+ }
622
+ case f @ Filter (condition, child) =>
623
+ val subqueries = ArrayBuffer .empty[(ScalarSubquery , ExprId )]
624
+ val newCondition = extractCorrelatedScalarSubqueries(condition, subqueries)
625
+ if (subqueries.nonEmpty) {
626
+ val newProj = Project (f.output, Filter (newCondition, constructLeftJoins(child, subqueries)))
627
+ val attrMapping = f.output.zip(newProj.output)
628
+ newProj -> attrMapping
629
+ } else {
630
+ f -> Nil
631
+ }
658
632
}
659
633
}
0 commit comments