Skip to content

Commit de8e4be

Browse files
allisonwang-db authored and cloud-fan committed
[SPARK-36063][SQL] Optimize OneRowRelation subqueries
### What changes were proposed in this pull request?

This PR adds an optimization for scalar and lateral subqueries with OneRowRelation as leaf nodes. It inlines such subqueries before decorrelation to avoid rewriting them as left outer joins. It also introduces a flag to turn this optimization on or off: `spark.sql.optimizer.optimizeOneRowRelationSubquery` (default: `true`).

For example:

```sql
select (select c1) from t
```

Analyzed plan:

```
Project [scalar-subquery#17 [c1#18] AS scalarsubquery(c1)#22]
:  +- Project [outer(c1#18)]
:     +- OneRowRelation
+- LocalRelation [c1#18, c2#19]
```

Optimized plan before this PR:

```
Project [c1#18#25 AS scalarsubquery(c1)#22]
+- Join LeftOuter, (c1#24 <=> c1#18)
   :- LocalRelation [c1#18]
   +- Aggregate [c1#18], [c1#18 AS c1#18#25, c1#18 AS c1#24]
      +- LocalRelation [c1#18]
```

Optimized plan after this PR:

```
LocalRelation [scalarsubquery(c1)#22]
```

### Why are the changes needed?

To optimize query plans.

### Does this PR introduce _any_ user-facing change?

No.

### How was this patch tested?

Added new unit tests.

Closes #33284 from allisonwang-db/spark-36063-optimize-subquery-one-row-relation.

Authored-by: allisonwang-db <allison.wang@databricks.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent dcb7db5 commit de8e4be

File tree

10 files changed

+283
-16
lines changed

10 files changed

+283
-16
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -390,6 +390,13 @@ package object dsl {
390390
condition: Option[Expression] = None): LogicalPlan =
391391
Join(logicalPlan, otherPlan, joinType, condition, JoinHint.NONE)
392392

393+
/**
 * Builds a [[LateralJoin]] whose left side is this DSL plan and whose right side is
 * `otherPlan` wrapped in a [[LateralSubquery]].
 *
 * @param otherPlan the plan to use as the lateral subquery
 * @param joinType the join type, `Inner` unless specified
 * @param condition an optional join condition
 */
def lateralJoin(
    otherPlan: LogicalPlan,
    joinType: JoinType = Inner,
    condition: Option[Expression] = None): LogicalPlan = {
  val lateralSubquery = LateralSubquery(otherPlan)
  LateralJoin(logicalPlan, lateralSubquery, joinType, condition)
}
399+
393400
def cogroup[Key: Encoder, Left: Encoder, Right: Encoder, Result: Encoder](
394401
otherPlan: LogicalPlan,
395402
func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,15 @@ object SubExprUtils extends PredicateHelper {
126126
/**
127127
* Returns an expression after removing the OuterReference shell.
128128
*/
129-
def stripOuterReference(e: Expression): Expression = e.transform { case OuterReference(r) => r }
129+
def stripOuterReference[E <: Expression](e: E): E = {
  // Unwrap every OuterReference node, keeping the expression it wraps.
  val unwrapped = e.transform { case OuterReference(inner) => inner }
  // `transform` yields a plain Expression, so cast the result back to the caller's static
  // type E.
  unwrapped.asInstanceOf[E]
}
130132

131133
/**
132134
* Returns the list of expressions after removing the OuterReference shell from each of
133135
* the expression.
134136
*/
135-
def stripOuterReferences(e: Seq[Expression]): Seq[Expression] = e.map(stripOuterReference)
137+
def stripOuterReferences[E <: Expression](e: Seq[E]): Seq[E] = {
  // Element-wise application; the type parameter preserves the caller's element type.
  e.map(expr => stripOuterReference(expr))
}
136138

137139
/**
138140
* Returns the logical plan after removing the OuterReference shell from all the expressions

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,23 @@ object DecorrelateInnerQuery extends PredicateHelper {
156156
expressions.map(replaceOuterReference(_, outerReferenceMap))
157157
}
158158

159+
/**
 * Substitutes outer references in each named expression using `outerReferenceMap`, while
 * making sure every rewritten expression still produces the same output attribute as the
 * expression it replaces.
 *
 * @param expressions the named expressions to rewrite
 * @param outerReferenceMap the mapping passed to `replaceOuterReference`
 * @return the rewritten expressions, in the same order as the input
 */
private def replaceOuterInNamedExpressions(
    expressions: Seq[NamedExpression],
    outerReferenceMap: AttributeMap[Attribute]): Seq[NamedExpression] = {
  expressions.map { original =>
    val rewritten = replaceOuterReference(original, outerReferenceMap)
    if (rewritten.toAttribute.semanticEquals(original.toAttribute)) {
      // The output attribute is unaffected by the substitution; use the result as-is.
      rewritten
    } else {
      // The substitution changed the output attribute: re-alias the result under the
      // original name and expression ID.
      Alias(rewritten, original.name)(original.exprId)
    }
  }
}
175+
159176
/**
160177
* Return all references that are presented in the join conditions but not in the output
161178
* of the given named expressions.
@@ -429,8 +446,9 @@ object DecorrelateInnerQuery extends PredicateHelper {
429446
val newOuterReferences = parentOuterReferences ++ outerReferences
430447
val (newChild, joinCond, outerReferenceMap) =
431448
decorrelate(child, newOuterReferences, aggregated)
432-
// Replace all outer references in the original project list.
433-
val newProjectList = replaceOuterReferences(projectList, outerReferenceMap)
449+
// Replace all outer references in the original project list and keep the output
450+
// attributes unchanged.
451+
val newProjectList = replaceOuterInNamedExpressions(projectList, outerReferenceMap)
434452
// Preserve required domain attributes in the join condition by adding the missing
435453
// references to the new project list.
436454
val referencesToAdd = missingReferences(newProjectList, joinCond)
@@ -442,9 +460,10 @@ object DecorrelateInnerQuery extends PredicateHelper {
442460
val newOuterReferences = parentOuterReferences ++ outerReferences
443461
val (newChild, joinCond, outerReferenceMap) =
444462
decorrelate(child, newOuterReferences, aggregated = true)
445-
// Replace all outer references in grouping and aggregate expressions.
463+
// Replace all outer references in grouping and aggregate expressions, and keep
464+
// the output attributes unchanged.
446465
val newGroupingExpr = replaceOuterReferences(groupingExpressions, outerReferenceMap)
447-
val newAggExpr = replaceOuterReferences(aggregateExpressions, outerReferenceMap)
466+
val newAggExpr = replaceOuterInNamedExpressions(aggregateExpressions, outerReferenceMap)
448467
// Add all required domain attributes to both grouping and aggregate expressions.
449468
val referencesToAdd = missingReferences(newAggExpr, joinCond)
450469
val newAggregate = a.copy(

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ abstract class Optimizer(catalogManager: CatalogManager)
179179
// non-nullable when an empty relation child of a Union is removed
180180
UpdateAttributeNullability) ::
181181
Batch("Pullup Correlated Expressions", Once,
182+
OptimizeOneRowRelationSubquery,
182183
PullupCorrelatedPredicates) ::
183184
// Subquery batch applies the optimizer rules recursively. Therefore, it makes no sense
184185
// to enforce idempotence on it and we change this batch from Once to FixedPoint(1).

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/subquery.scala

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer
1919

2020
import scala.collection.mutable.ArrayBuffer
2121

22+
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
2223
import org.apache.spark.sql.catalyst.expressions._
2324
import org.apache.spark.sql.catalyst.expressions.ScalarSubquery._
2425
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
@@ -711,3 +712,47 @@ object RewriteLateralSubquery extends Rule[LogicalPlan] {
711712
Join(left, newRight, joinType, newCond, JoinHint.NONE)
712713
}
713714
}
715+
716+
/**
 * Inlines scalar and lateral subqueries whose plan is a projection over a [[OneRowRelation]].
 * The rule is a no-op when `SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY` is false.
 */
object OptimizeOneRowRelationSubquery extends Rule[LogicalPlan] {

  /**
   * Extractor for subquery plans that, once subquery aliases are eliminated and projects are
   * collapsed, consist of a single Project on top of a OneRowRelation. Yields the project list
   * with the OuterReference wrappers stripped.
   */
  object OneRowSubquery {
    def unapply(plan: LogicalPlan): Option[Seq[NamedExpression]] = {
      val collapsed = CollapseProject(EliminateSubqueryAliases(plan))
      collapsed match {
        case Project(exprs, _: OneRowRelation) => Some(stripOuterReferences(exprs))
        case _ => None
      }
    }
  }

  /** Returns true if any node of `plan` has an expression containing a correlated subquery. */
  private def hasCorrelatedSubquery(plan: LogicalPlan): Boolean = {
    plan.find(node => node.expressions.exists(SubqueryExpression.hasCorrelatedSubquery)).nonEmpty
  }

  /**
   * Replaces eligible lateral joins and scalar subqueries with the expressions they project.
   * A subquery is only inlined when its plan contains no correlated subquery and it carries no
   * join conditions.
   */
  private def rewrite(plan: LogicalPlan): LogicalPlan = plan.transformUpWithSubqueries {
    // Lateral join without a join condition: project the subquery's expressions directly on
    // top of the left child.
    case LateralJoin(outer, sub @ LateralSubquery(OneRowSubquery(exprs), _, _, _), _, None)
        if sub.joinCond.isEmpty && !hasCorrelatedSubquery(sub.plan) =>
      Project(outer.output ++ exprs, outer)
    // Any other node: substitute each eligible scalar subquery with its single projected
    // expression.
    case node: LogicalPlan =>
      node.transformExpressionsUpWithPruning(_.containsPattern(SCALAR_SUBQUERY)) {
        case scalar @ ScalarSubquery(OneRowSubquery(exprs), _, _, _)
            if scalar.joinCond.isEmpty && !hasCorrelatedSubquery(scalar.plan) =>
          assert(exprs.size == 1)
          exprs.head
      }
  }

  def apply(plan: LogicalPlan): LogicalPlan = {
    if (conf.getConf(SQLConf.OPTIMIZE_ONE_ROW_RELATION_SUBQUERY)) rewrite(plan) else plan
  }
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -435,6 +435,23 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]]
435435
subqueries ++ subqueries.flatMap(_.subqueriesAll)
436436
}
437437

438+
/**
 * Returns a copy of this plan where the given partial function has been applied bottom-up
 * (via `transformUp`) to every node, including the plans nested inside subquery expressions.
 *
 * For each node visited, the plans held by the [[PlanExpression]]s in that node's expressions
 * are rewritten first (recursively, with this same method), and `f` is then applied to the
 * resulting node. When the partial function does not apply to a given node, it is left
 * unchanged.
 *
 * @param f the rewrite to apply to each plan node
 * @return the rewritten plan
 */
def transformUpWithSubqueries(f: PartialFunction[PlanType, PlanType]): PlanType = {
  transformUp { case plan =>
    val transformed = plan transformExpressionsUp {
      // PlanType is erased at runtime, so the pattern cannot check the type argument;
      // `@unchecked` states that explicitly and avoids the compiler's unchecked warning.
      case planExpression: PlanExpression[PlanType @unchecked] =>
        val newPlan = planExpression.plan.transformUpWithSubqueries(f)
        planExpression.withNewPlan(newPlan)
    }
    f.applyOrElse[PlanType, PlanType](transformed, identity)
  }
}
454+
438455
/**
439456
* A variant of `collect`. This method not only apply the given function to all elements in this
440457
* plan, also considering all the plans in its (nested) subqueries

sql/catalyst/src/main/scala/org/apache/spark/sql/internal/SQLConf.scala

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2613,6 +2613,14 @@ object SQLConf {
26132613
.booleanConf
26142614
.createWithDefault(true)
26152615

2616+
// Internal boolean flag, enabled by default. OptimizeOneRowRelationSubquery.apply reads it
// and returns the plan unchanged when it is false.
val OPTIMIZE_ONE_ROW_RELATION_SUBQUERY =
2617+
buildConf("spark.sql.optimizer.optimizeOneRowRelationSubquery")
2618+
.internal()
2619+
.doc("When true, the optimizer will inline subqueries with OneRowRelation as leaf nodes.")
2620+
.version("3.2.0")
2621+
.booleanConf
2622+
.createWithDefault(true)
2623+
26162624
val TOP_K_SORT_FALLBACK_THRESHOLD =
26172625
buildConf("spark.sql.execution.topKSortFallbackThreshold")
26182626
.internal()

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuerySuite.scala

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ class DecorrelateInnerQuerySuite extends PlanTest {
3232
val x = AttributeReference("x", IntegerType)()
3333
val y = AttributeReference("y", IntegerType)()
3434
val z = AttributeReference("z", IntegerType)()
35+
val t0 = OneRowRelation()
3536
val testRelation = LocalRelation(a, b, c)
3637
val testRelation2 = LocalRelation(x, y, z)
3738

@@ -203,23 +204,24 @@ class DecorrelateInnerQuerySuite extends PlanTest {
203204

204205
test("correlated values in project") {
205206
val outerPlan = testRelation2
206-
val innerPlan = Project(Seq(OuterReference(x), OuterReference(y)), OneRowRelation())
207-
val correctAnswer = Project(Seq(x, y), DomainJoin(Seq(x, y), OneRowRelation()))
207+
val innerPlan = Project(Seq(OuterReference(x).as("x1"), OuterReference(y).as("y1")), t0)
208+
val correctAnswer = Project(
209+
Seq(x.as("x1"), y.as("y1"), x, y), DomainJoin(Seq(x, y), t0))
208210
check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x, y <=> y))
209211
}
210212

211213
test("correlated values in project with alias") {
212214
val outerPlan = testRelation2
213215
val innerPlan =
214-
Project(Seq(OuterReference(x), 'y1, 'sum),
216+
Project(Seq(OuterReference(x).as("x1"), 'y1, 'sum),
215217
Project(Seq(
216218
OuterReference(x),
217219
OuterReference(y).as("y1"),
218220
Add(OuterReference(x), OuterReference(y)).as("sum")),
219221
testRelation)).analyze
220222
val correctAnswer =
221-
Project(Seq(x, 'y1, 'sum, y),
222-
Project(Seq(x, y.as("y1"), (x + y).as("sum"), y),
223+
Project(Seq(x.as("x1"), 'y1, 'sum, x, y),
224+
Project(Seq(x.as(x.name), y.as("y1"), (x + y).as("sum"), x, y),
223225
DomainJoin(Seq(x, y), testRelation))).analyze
224226
check(innerPlan, outerPlan, correctAnswer, Seq(x <=> x, y <=> y))
225227
}
@@ -228,28 +230,28 @@ class DecorrelateInnerQuerySuite extends PlanTest {
228230
val outerPlan = testRelation2
229231
val innerPlan =
230232
Project(
231-
Seq(OuterReference(x)),
233+
Seq(OuterReference(x).as("x1")),
232234
Filter(
233235
And(OuterReference(x) === a, And(OuterReference(x) + OuterReference(y) === c, b === 1)),
234236
testRelation
235237
)
236238
)
237-
val correctAnswer = Project(Seq(a, c), Filter(b === 1, testRelation))
239+
val correctAnswer = Project(Seq(a.as("x1"), a, c), Filter(b === 1, testRelation))
238240
check(innerPlan, outerPlan, correctAnswer, Seq(x === a, x + y === c))
239241
}
240242

241243
test("correlated values in project without correlated equality conditions in filter") {
242244
val outerPlan = testRelation2
243245
val innerPlan =
244246
Project(
245-
Seq(OuterReference(y)),
247+
Seq(OuterReference(y).as("y1")),
246248
Filter(
247249
And(OuterReference(x) === a, And(OuterReference(x) + OuterReference(y) === c, b === 1)),
248250
testRelation
249251
)
250252
)
251253
val correctAnswer =
252-
Project(Seq(y, a, c),
254+
Project(Seq(y.as("y1"), y, a, c),
253255
Filter(b === 1,
254256
DomainJoin(Seq(y), testRelation)
255257
)

0 commit comments

Comments
 (0)