Skip to content

Commit 3f21635

Browse files
committed
Fix
1 parent b3524b6 commit 3f21635

File tree

2 files changed

+87
-124
lines changed

2 files changed

+87
-124
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 19 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1580,36 +1580,25 @@ object ReplaceDistinctWithAggregate extends Rule[LogicalPlan] {
15801580
* Replaces logical [[Deduplicate]] operator with an [[Aggregate]] operator.
15811581
*/
15821582
object ReplaceDeduplicateWithAggregate extends Rule[LogicalPlan] {
1583-
def apply(plan: LogicalPlan): LogicalPlan = {
1584-
val rewritePlanMap = mutable.ArrayBuffer[(LogicalPlan, LogicalPlan)]()
1585-
val newPlan = plan transform {
1586-
case Deduplicate(keys, child) if !child.isStreaming =>
1587-
val keyExprIds = keys.map(_.exprId)
1588-
val aggCols = child.output.map { attr =>
1589-
if (keyExprIds.contains(attr.exprId)) {
1590-
attr -> attr
1591-
} else {
1592-
val alias = Alias(new First(attr).toAggregateExpression(), attr.name)(attr.exprId)
1593-
alias -> alias.newInstance()
1594-
}
1595-
}.unzip
1596-
// SPARK-22951: Physical aggregate operators distinguish global aggregation and grouping
1597-
// aggregations by checking the number of grouping keys. The key difference here is that a
1598-
// global aggregation always returns at least one row even if there are no input rows. Here
1599-
// we append a literal when the grouping key list is empty so that the result aggregate
1600-
// operator is properly treated as a grouping aggregation.
1601-
val nonemptyKeys = if (keys.isEmpty) Literal(1) :: Nil else keys
1602-
val newAgg = Aggregate(nonemptyKeys, aggCols._1, child)
1603-
rewritePlanMap += newAgg -> Aggregate(nonemptyKeys, aggCols._2, child)
1604-
newAgg
1605-
}
1606-
1607-
if (rewritePlanMap.nonEmpty) {
1608-
assert(!plan.fastEquals(newPlan))
1609-
Analyzer.rewritePlan(newPlan, rewritePlanMap.toMap)._1
1610-
} else {
1611-
plan
1612-
}
1583+
def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput {
1584+
case d @ Deduplicate(keys, child) if !child.isStreaming =>
1585+
val keyExprIds = keys.map(_.exprId)
1586+
val aggCols = child.output.map { attr =>
1587+
if (keyExprIds.contains(attr.exprId)) {
1588+
attr
1589+
} else {
1590+
Alias(new First(attr).toAggregateExpression(), attr.name)()
1591+
}
1592+
}
1593+
// SPARK-22951: Physical aggregate operators distinguish global aggregation and grouping
1594+
// aggregations by checking the number of grouping keys. The key difference here is that a
1595+
// global aggregation always returns at least one row even if there are no input rows. Here
1596+
// we append a literal when the grouping key list is empty so that the result aggregate
1597+
// operator is properly treated as a grouping aggregation.
1598+
val nonemptyKeys = if (keys.isEmpty) Literal(1) :: Nil else keys
1599+
val newAgg = Aggregate(nonemptyKeys, aggCols, child)
1600+
val attrMapping = d.output.zip(newAgg.output)
1601+
newAgg -> attrMapping
16131602
}
16141603
}
16151604

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala

Lines changed: 68 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package org.apache.spark.sql.catalyst.optimizer
2020
import scala.collection.mutable.ArrayBuffer
2121

2222
import org.apache.spark.sql.AnalysisException
23-
import org.apache.spark.sql.catalyst.analysis.{Analyzer, CleanupAliases}
23+
import org.apache.spark.sql.catalyst.analysis.CleanupAliases
2424
import org.apache.spark.sql.catalyst.expressions._
2525
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
2626
import org.apache.spark.sql.catalyst.expressions.aggregate._
@@ -342,11 +342,12 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
342342
*/
343343
private def extractCorrelatedScalarSubqueries[E <: Expression](
344344
expression: E,
345-
subqueries: ArrayBuffer[ScalarSubquery]): E = {
345+
subqueries: ArrayBuffer[(ScalarSubquery, ExprId)]): E = {
346346
val newExpression = expression transform {
347347
case s: ScalarSubquery if s.children.nonEmpty =>
348-
subqueries += s
349-
s.plan.output.head
348+
val newExprId = NamedExpression.newExprId
349+
subqueries += s -> newExprId
350+
s.plan.output.head.withExprId(newExprId)
350351
}
351352
newExpression.asInstanceOf[E]
352353
}
@@ -513,17 +514,16 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
513514
*/
514515
private def constructLeftJoins(
515516
child: LogicalPlan,
516-
subqueries: ArrayBuffer[ScalarSubquery]): (LogicalPlan, Seq[(LogicalPlan, LogicalPlan)]) = {
517-
val rewritePlanMap = ArrayBuffer[(LogicalPlan, LogicalPlan)]()
518-
val newPlan = subqueries.foldLeft(child) {
519-
case (currentChild, ScalarSubquery(query, conditions, _)) =>
517+
subqueries: ArrayBuffer[(ScalarSubquery, ExprId)]): LogicalPlan = {
518+
subqueries.foldLeft(child) {
519+
case (currentChild, (ScalarSubquery(query, conditions, _), newExprId)) =>
520520
val origOutput = query.output.head
521521

522522
val resultWithZeroTups = evalSubqueryOnZeroTups(query)
523523
if (resultWithZeroTups.isEmpty) {
524524
// CASE 1: Subquery guaranteed not to have the COUNT bug
525525
Project(
526-
currentChild.output :+ origOutput,
526+
currentChild.output :+ Alias(origOutput, origOutput.name)(exprId = newExprId),
527527
Join(currentChild, query, LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
528528
} else {
529529
// Subquery might have the COUNT bug. Add appropriate corrections.
@@ -543,23 +543,16 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
543543

544544
if (havingNode.isEmpty) {
545545
// CASE 2: Subquery with no HAVING clause
546-
val joinPlan = Join(currentChild,
547-
Project(query.output :+ alwaysTrueExpr, query),
548-
LeftOuter, conditions.reduceOption(And), JoinHint.NONE)
549-
550-
def buildPlan(exprId: ExprId): LogicalPlan = {
551-
Project(
552-
currentChild.output :+
553-
Alias(
554-
If(IsNull(alwaysTrueRef),
555-
resultWithZeroTups.get,
556-
aggValRef), origOutput.name)(exprId),
557-
joinPlan)
558-
}
546+
Project(
547+
currentChild.output :+
548+
Alias(
549+
If(IsNull(alwaysTrueRef),
550+
resultWithZeroTups.get,
551+
aggValRef), origOutput.name)(exprId = newExprId),
552+
Join(currentChild,
553+
Project(query.output :+ alwaysTrueExpr, query),
554+
LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
559555

560-
val newPlan = buildPlan(origOutput.exprId)
561-
rewritePlanMap += newPlan -> buildPlan(NamedExpression.newExprId)
562-
newPlan
563556
} else {
564557
// CASE 3: Subquery with HAVING clause. Pull the HAVING clause above the join.
565558
// Need to modify any operators below the join to pass through all columns
@@ -575,85 +568,66 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
575568
case op => sys.error(s"Unexpected operator $op in corelated subquery")
576569
}
577570

578-
val joinPlan = Join(currentChild,
579-
Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot),
580-
LeftOuter, conditions.reduceOption(And), JoinHint.NONE)
581-
582-
def buildPlan(exprId: ExprId): LogicalPlan = {
583-
// CASE WHEN alwaysTrue IS NULL THEN resultOnZeroTups
584-
// WHEN NOT (original HAVING clause expr) THEN CAST(null AS <type of aggVal>)
585-
// ELSE (aggregate value) END AS (original column name)
586-
val caseExpr = Alias(CaseWhen(Seq(
587-
(IsNull(alwaysTrueRef), resultWithZeroTups.get),
588-
(Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))),
589-
aggValRef),
590-
origOutput.name)(exprId)
591-
592-
Project(
593-
currentChild.output :+ caseExpr,
594-
joinPlan)
595-
}
571+
// CASE WHEN alwaysTrue IS NULL THEN resultOnZeroTups
572+
// WHEN NOT (original HAVING clause expr) THEN CAST(null AS <type of aggVal>)
573+
// ELSE (aggregate value) END AS (original column name)
574+
val caseExpr = Alias(CaseWhen(Seq(
575+
(IsNull(alwaysTrueRef), resultWithZeroTups.get),
576+
(Not(havingNode.get.condition), Literal.create(null, aggValRef.dataType))),
577+
aggValRef),
578+
origOutput.name)(exprId = newExprId)
579+
580+
Project(
581+
currentChild.output :+ caseExpr,
582+
Join(currentChild,
583+
Project(subqueryRoot.output :+ alwaysTrueExpr, subqueryRoot),
584+
LeftOuter, conditions.reduceOption(And), JoinHint.NONE))
596585

597-
val newPlan = buildPlan(origOutput.exprId)
598-
rewritePlanMap += newPlan -> buildPlan(NamedExpression.newExprId)
599-
newPlan
600586
}
601587
}
602588
}
603-
604-
(newPlan, rewritePlanMap)
605589
}
606590

607591
/**
608592
* Rewrite [[Filter]], [[Project]] and [[Aggregate]] plans containing correlated scalar
609593
* subqueries.
610594
*/
611-
def apply(plan: LogicalPlan): LogicalPlan = {
612-
val rewritePlanMap = ArrayBuffer[(LogicalPlan, LogicalPlan)]()
613-
val newPlan = plan transform {
614-
case a @ Aggregate(grouping, expressions, child) =>
615-
val subqueries = ArrayBuffer.empty[ScalarSubquery]
616-
val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
617-
if (subqueries.nonEmpty) {
618-
// We currently only allow correlated subqueries in an aggregate if they are part of the
619-
// grouping expressions. As a result we need to replace all the scalar subqueries in the
620-
// grouping expressions by their result.
621-
val newGrouping = grouping.map { e =>
622-
subqueries.find(_.semanticEquals(e)).map(_.plan.output.head).getOrElse(e)
623-
}
624-
val (newChild, rewriteMap) = constructLeftJoins(child, subqueries)
625-
rewritePlanMap ++= rewriteMap
626-
Aggregate(newGrouping, newExpressions, newChild)
627-
} else {
628-
a
629-
}
630-
case p @ Project(expressions, child) =>
631-
val subqueries = ArrayBuffer.empty[ScalarSubquery]
632-
val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
633-
if (subqueries.nonEmpty) {
634-
val (newChild, rewriteMap) = constructLeftJoins(child, subqueries)
635-
rewritePlanMap ++= rewriteMap
636-
Project(newExpressions, newChild)
637-
} else {
638-
p
595+
def apply(plan: LogicalPlan): LogicalPlan = plan transformUpWithNewOutput {
596+
case a @ Aggregate(grouping, expressions, child) =>
597+
val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)]
598+
val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
599+
if (subqueries.nonEmpty) {
600+
// We currently only allow correlated subqueries in an aggregate if they are part of the
601+
// grouping expressions. As a result we need to replace all the scalar subqueries in the
602+
// grouping expressions by their result.
603+
val newGrouping = grouping.map { e =>
604+
subqueries.find(_._1.semanticEquals(e)).map(_._1.plan.output.head).getOrElse(e)
639605
}
640-
case f @ Filter(condition, child) =>
641-
val subqueries = ArrayBuffer.empty[ScalarSubquery]
642-
val newCondition = extractCorrelatedScalarSubqueries(condition, subqueries)
643-
if (subqueries.nonEmpty) {
644-
val (newChild, rewriteMap) = constructLeftJoins(child, subqueries)
645-
rewritePlanMap ++= rewriteMap
646-
Project(f.output, Filter(newCondition, newChild))
647-
} else {
648-
f
649-
}
650-
}
651-
652-
if (rewritePlanMap.nonEmpty) {
653-
assert(!plan.fastEquals(newPlan))
654-
Analyzer.rewritePlan(newPlan, rewritePlanMap.toMap)._1
655-
} else {
656-
newPlan
657-
}
606+
val newAgg = Aggregate(newGrouping, newExpressions, constructLeftJoins(child, subqueries))
607+
val attrMapping = a.output.zip(newAgg.output)
608+
newAgg -> attrMapping
609+
} else {
610+
a -> Nil
611+
}
612+
case p @ Project(expressions, child) =>
613+
val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)]
614+
val newExpressions = expressions.map(extractCorrelatedScalarSubqueries(_, subqueries))
615+
if (subqueries.nonEmpty) {
616+
val newProj = Project(newExpressions, constructLeftJoins(child, subqueries))
617+
val attrMapping = p.output.zip(newProj.output)
618+
newProj -> attrMapping
619+
} else {
620+
p -> Nil
621+
}
622+
case f @ Filter(condition, child) =>
623+
val subqueries = ArrayBuffer.empty[(ScalarSubquery, ExprId)]
624+
val newCondition = extractCorrelatedScalarSubqueries(condition, subqueries)
625+
if (subqueries.nonEmpty) {
626+
val newProj = Project(f.output, Filter(newCondition, constructLeftJoins(child, subqueries)))
627+
val attrMapping = f.output.zip(newProj.output)
628+
newProj -> attrMapping
629+
} else {
630+
f -> Nil
631+
}
658632
}
659633
}

0 commit comments

Comments
 (0)