
Commit 3eade74

AngersZhuuuu authored and cloud-fan committed
[SPARK-29800][SQL] Rewrite non-correlated EXISTS subqueries to use ScalarSubquery to optimize perf
### What changes were proposed in this pull request?

Catalyst currently rewrites a non-correlated EXISTS subquery into a BroadcastNestedLoopJoin, whose performance is poor. This change rewrites a non-correlated EXISTS subquery into a ScalarSubquery instead. We rewrite

```sql
WHERE EXISTS (SELECT A FROM TABLE B WHERE COL1 > 10)
```

to

```sql
WHERE (SELECT 1 FROM (SELECT A FROM TABLE B WHERE COL1 > 10) LIMIT 1) IS NOT NULL
```

so that no join needs to be built to evaluate the EXISTS expression.

### Why are the changes needed?

Optimize EXISTS performance.

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

Manually tested.

Closes #26437 from AngersZhuuuu/SPARK-29800.

Lead-authored-by: angerszhu <angers.zhu@gmail.com>
Co-authored-by: AngersZhuuuu <angers.zhu@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent bc16bb1 · commit 3eade74
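
The rewrite described in the commit message can be sanity-checked end to end. Below is a minimal sketch (not part of this commit), assuming a local SparkSession and two hypothetical temp views `a` and `b`: with the new rule in place, the optimized plan of a non-correlated EXISTS query should contain no join node.

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.plans.logical.Join

// Hypothetical session and views, for illustration only.
val spark = SparkSession.builder().master("local[*]").appName("exists-rewrite-check").getOrCreate()
spark.range(100).toDF("id").createOrReplaceTempView("a")
spark.range(100).toDF("col1").createOrReplaceTempView("b")

val df = spark.sql("SELECT * FROM a WHERE EXISTS (SELECT col1 FROM b WHERE col1 > 10)")

// With this patch, the non-correlated EXISTS should be rewritten to an IS NOT NULL check over a
// scalar subquery, so the main optimized plan should contain no join node.
assert(df.queryExecution.optimizedPlan.collect { case j: Join => j }.isEmpty)
df.show()
```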

6 files changed: +52 −10 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala

Lines changed: 20 additions & 4 deletions

```diff
@@ -62,11 +62,13 @@ abstract class SubqueryExpression(

 object SubqueryExpression {
   /**
-   * Returns true when an expression contains an IN or EXISTS subquery and false otherwise.
+   * Returns true when an expression contains an IN or correlated EXISTS subquery
+   * and false otherwise.
    */
-  def hasInOrExistsSubquery(e: Expression): Boolean = {
+  def hasInOrCorrelatedExistsSubquery(e: Expression): Boolean = {
     e.find {
-      case _: ListQuery | _: Exists => true
+      case _: ListQuery => true
+      case _: Exists if e.children.nonEmpty => true
       case _ => false
     }.isDefined
   }
@@ -302,7 +304,10 @@ case class ListQuery(
 }

 /**
- * The [[Exists]] expression checks if a row exists in a subquery given some correlated condition.
+ * The [[Exists]] expression checks if a row exists in a subquery given some correlated condition
+ * or some uncorrelated condition.
+ *
+ * 1. correlated condition:
  *
  * For example (SQL):
  * {{{
@@ -312,6 +317,17 @@ case class ListQuery(
 *                   FROM b
 *                   WHERE b.id = a.id)
 * }}}
+ *
+ * 2. uncorrelated condition example:
+ *
+ * For example (SQL):
+ * {{{
+ *   SELECT *
+ *   FROM a
+ *   WHERE EXISTS (SELECT *
+ *                 FROM b
+ *                 WHERE b.id > 10)
+ * }}}
  */
 case class Exists(
     plan: LogicalPlan,
```
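
For context on what "correlated" means here: the analyzer records a subquery's outer references as the `Exists` expression's children, so a non-correlated EXISTS has no children, and the renamed helper no longer reports it. A minimal sketch (not part of this commit), using the Catalyst test DSL, with outer references filled in by hand to stand in for what the analyzer would normally populate:

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.{Exists, SubqueryExpression}
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

val id = 'id.int
val b = LocalRelation(id)

// Non-correlated: no outer-reference children, so the renamed helper ignores it and
// RewritePredicateSubquery leaves it for RewriteNonCorrelatedExists instead.
val uncorrelated = Exists(b.where(id > 10))
assert(uncorrelated.children.isEmpty)
assert(!SubqueryExpression.hasInOrCorrelatedExistsSubquery(uncorrelated))

// Correlated: the outer reference appears among the children (hand-filled here for
// illustration), so the helper still reports it and the join-based rewrite applies.
val outerId = 'outer_id.int
val correlated = Exists(b.where(id === outerId), Seq(outerId))
assert(SubqueryExpression.hasInOrCorrelatedExistsSubquery(correlated))
```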

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 1 addition & 0 deletions

```diff
@@ -128,6 +128,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
       EliminateSubqueryAliases,
       EliminateView,
       ReplaceExpressions,
+      RewriteNonCorrelatedExists,
       ComputeCurrentTime,
       GetCurrentDatabase(catalogManager),
       RewriteDistinctAggregates,
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/finishAnalysis.scala

Lines changed: 15 additions & 0 deletions

```diff
@@ -52,6 +52,21 @@ object ReplaceExpressions extends Rule[LogicalPlan] {
   }
 }

+/**
+ * Rewrite non correlated exists subquery to use ScalarSubquery
+ *   WHERE EXISTS (SELECT A FROM TABLE B WHERE COL1 > 10)
+ * will be rewritten to
+ *   WHERE (SELECT 1 FROM (SELECT A FROM TABLE B WHERE COL1 > 10) LIMIT 1) IS NOT NULL
+ */
+object RewriteNonCorrelatedExists extends Rule[LogicalPlan] {
+  override def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
+    case exists: Exists if exists.children.isEmpty =>
+      IsNotNull(
+        ScalarSubquery(
+          plan = Limit(Literal(1), Project(Seq(Alias(Literal(1), "col")()), exists.plan)),
+          exprId = exists.exprId))
+  }
+}

 /**
  * Computes the current date and time to make sure we return the same result in a single query.
```
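
To see the shape the new rule produces, here is a small sketch (not part of this commit) that applies `RewriteNonCorrelatedExists` to a hand-built plan via the Catalyst test DSL; relation and column names are made up for illustration:

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.expressions.Exists
import org.apache.spark.sql.catalyst.optimizer.RewriteNonCorrelatedExists
import org.apache.spark.sql.catalyst.plans.logical.LocalRelation

val col1 = 'col1.int
val outer = LocalRelation('id.int)
val inner = LocalRelation('a.int, col1).where(col1 > 10)

// A non-correlated EXISTS predicate over the inner plan (no outer-reference children).
val plan = outer.where(Exists(inner))

// The rule replaces the Exists with IsNotNull(ScalarSubquery(Limit(1, Project(1 AS col, inner)))),
// i.e. "does at least one row exist?" expressed without a join.
val rewritten = RewriteNonCorrelatedExists(plan)
println(rewritten)
```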

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala

Lines changed: 2 additions & 1 deletion

```diff
@@ -96,7 +96,8 @@ object RewritePredicateSubquery extends Rule[LogicalPlan] with PredicateHelper {
   def apply(plan: LogicalPlan): LogicalPlan = plan transform {
     case Filter(condition, child) =>
       val (withSubquery, withoutSubquery) =
-        splitConjunctivePredicates(condition).partition(SubqueryExpression.hasInOrExistsSubquery)
+        splitConjunctivePredicates(condition)
+          .partition(SubqueryExpression.hasInOrCorrelatedExistsSubquery)

       // Construct the pruned filter condition.
       val newFilter: LogicalPlan = withoutSubquery match {
```

sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala

Lines changed: 12 additions & 3 deletions

```diff
@@ -28,7 +28,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier
 import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
 import org.apache.spark.sql.catalyst.plans.logical.{BROADCAST, Join, JoinStrategyHint, SHUFFLE_HASH}
 import org.apache.spark.sql.catalyst.util.DateTimeConstants
-import org.apache.spark.sql.execution.{RDDScanExec, SparkPlan}
+import org.apache.spark.sql.execution.{ExecSubqueryExpression, RDDScanExec, SparkPlan}
 import org.apache.spark.sql.execution.columnar._
 import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
 import org.apache.spark.sql.functions._
@@ -89,10 +89,19 @@ class CachedTableSuite extends QueryTest with SQLTestUtils with SharedSparkSession
     sum
   }

+  private def getNumInMemoryTablesInSubquery(plan: SparkPlan): Int = {
+    plan.expressions.flatMap(_.collect {
+      case sub: ExecSubqueryExpression => getNumInMemoryTablesRecursively(sub.plan)
+    }).sum
+  }
+
   private def getNumInMemoryTablesRecursively(plan: SparkPlan): Int = {
     plan.collect {
-      case InMemoryTableScanExec(_, _, relation) =>
-        getNumInMemoryTablesRecursively(relation.cachedPlan) + 1
+      case inMemoryTable @ InMemoryTableScanExec(_, _, relation) =>
+        getNumInMemoryTablesRecursively(relation.cachedPlan) +
+          getNumInMemoryTablesInSubquery(inMemoryTable) + 1
+      case p =>
+        getNumInMemoryTablesInSubquery(p)
     }.sum
   }

```
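The helper change above is needed because, with this commit, a cached relation referenced only from a non-correlated EXISTS no longer appears as a join child; it should sit inside a subquery expression of the physical plan instead. A rough sketch of that situation (not part of this commit), assuming a local SparkSession and a hypothetical cached view `t_inner`:

```scala
import org.apache.spark.sql.SparkSession

// Hypothetical session and view names, for illustration only.
val spark = SparkSession.builder().master("local[*]").appName("cached-exists-sketch").getOrCreate()
spark.range(10).toDF("col1").createOrReplaceTempView("t_inner")
spark.catalog.cacheTable("t_inner")

// The InMemoryTableScan for t_inner should now live under the scalar subquery produced by the
// EXISTS rewrite, so a counter that only walks the main plan tree would miss it.
val df = spark.sql("SELECT * FROM range(5) WHERE EXISTS (SELECT 1 FROM t_inner WHERE col1 > 3)")
df.collect()
```
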

sql/core/src/test/scala/org/apache/spark/sql/SubquerySuite.scala

Lines changed: 2 additions & 2 deletions

```diff
@@ -891,9 +891,9 @@ class SubquerySuite extends QueryTest with SharedSparkSession {

     val sqlText =
       """
-        |SELECT * FROM t1
+        |SELECT * FROM t1 a
         |WHERE
-        |NOT EXISTS (SELECT * FROM t1)
+        |NOT EXISTS (SELECT * FROM t1 b WHERE a.i = b.i)
       """.stripMargin
     val optimizedPlan = sql(sqlText).queryExecution.optimizedPlan
     val join = optimizedPlan.collectFirst { case j: Join => j }.get
```
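
The test query is made correlated because, after this commit, a non-correlated NOT EXISTS no longer plans a join at all. A sketch of a companion check one could write (not part of the commit), assuming the suite's `t1` temp view, its `sql` helper, and `Join` in scope:

```scala
// With RewriteNonCorrelatedExists in place, the original non-correlated form should compile
// down to a null check on a scalar subquery, leaving no Join node in the optimized plan.
val nonCorrelatedPlan = sql(
  """
    |SELECT * FROM t1
    |WHERE
    |NOT EXISTS (SELECT * FROM t1)
  """.stripMargin).queryExecution.optimizedPlan
assert(nonCorrelatedPlan.collectFirst { case j: Join => j }.isEmpty)
```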
