Skip to content

Commit 1f2f75c

Browse files
committed
drop MultiScalarSubquery, use ScalarSubquery(CreateStruct()) instead
1 parent 2828345 commit 1f2f75c

File tree

13 files changed

+268
-374
lines changed

13 files changed

+268
-374
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala

Lines changed: 1 addition & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,8 @@ import scala.collection.mutable.ArrayBuffer
2222
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
2323
import org.apache.spark.sql.catalyst.plans.QueryPlan
2424
import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
25-
import org.apache.spark.sql.catalyst.trees.LeafLike
2625
import org.apache.spark.sql.catalyst.trees.TreePattern.{EXISTS_SUBQUERY, LIST_SUBQUERY,
27-
MULTI_SCALAR_SUBQUERY, PLAN_EXPRESSION, SCALAR_SUBQUERY, TreePattern}
26+
PLAN_EXPRESSION, SCALAR_SUBQUERY, TreePattern}
2827
import org.apache.spark.sql.types._
2928
import org.apache.spark.util.collection.BitSet
3029

@@ -268,33 +267,6 @@ object ScalarSubquery {
268267
}
269268
}
270269

271-
/**
272-
* A subquery that is capable of returning multiple scalar values.
273-
*/
274-
case class MultiScalarSubquery(
275-
plan: LogicalPlan,
276-
exprId: ExprId = NamedExpression.newExprId)
277-
extends SubqueryExpression(plan, Seq.empty, exprId) with LeafLike[Expression] with Unevaluable {
278-
override def dataType: DataType = {
279-
assert(plan.schema.nonEmpty, "Multi-column scalar subquery should have columns")
280-
plan.schema
281-
}
282-
283-
override def nullable: Boolean = true
284-
285-
override def withNewPlan(plan: LogicalPlan): MultiScalarSubquery = copy(plan = plan)
286-
287-
override def toString: String = s"multi-scalar-subquery#${exprId.id}"
288-
289-
override lazy val canonicalized: Expression = {
290-
MultiScalarSubquery(
291-
plan.canonicalized,
292-
ExprId(0))
293-
}
294-
295-
final override def nodePatternsInternal: Seq[TreePattern] = Seq(MULTI_SCALAR_SUBQUERY)
296-
}
297-
298270
/**
299271
* A [[ListQuery]] expression defines the query which we want to search in an IN subquery
300272
* expression. It should and can only be used in conjunction with an IN expression.

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueries.scala

Lines changed: 31 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ import org.apache.spark.sql.catalyst.expressions._
2323
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
2424
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LeafNode, LogicalPlan, Project}
2525
import org.apache.spark.sql.catalyst.rules.Rule
26-
import org.apache.spark.sql.catalyst.trees.TreePattern.{MULTI_SCALAR_SUBQUERY, SCALAR_SUBQUERY}
26+
import org.apache.spark.sql.catalyst.trees.TreePattern.SCALAR_SUBQUERY
2727

2828
/**
2929
* This rule tries to merge multiple non-correlated [[ScalarSubquery]]s into a
@@ -72,46 +72,54 @@ import org.apache.spark.sql.catalyst.trees.TreePattern.{MULTI_SCALAR_SUBQUERY, S
7272
object MergeScalarSubqueries extends Rule[LogicalPlan] with PredicateHelper {
7373
def apply(plan: LogicalPlan): LogicalPlan = {
7474
if (conf.scalarSubqueryMergeEabled && conf.subqueryReuseEnabled) {
75-
val mergedSubqueries = ArrayBuffer.empty[LogicalPlan]
76-
removeReferences(mergeAndInsertReferences(plan, mergedSubqueries), mergedSubqueries)
75+
val mergedSubqueries = ArrayBuffer.empty[Project]
76+
removeReferences(mergeAndInsertReferences(plan, mergedSubqueries))
7777
} else {
7878
plan
7979
}
8080
}
8181

8282
private def mergeAndInsertReferences(
8383
plan: LogicalPlan,
84-
mergedSubqueries: ArrayBuffer[LogicalPlan]): LogicalPlan = {
85-
plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY), ruleId) {
86-
case s: ScalarSubquery if s.children.isEmpty =>
87-
val (mergedPlan, ordinal) = mergeAndGetReference(s.plan, mergedSubqueries)
88-
GetStructField(MultiScalarSubquery(mergedPlan, s.exprId), ordinal)
84+
mergedSubqueries: ArrayBuffer[Project]): LogicalPlan = {
85+
plan.transformWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY), ruleId) {
86+
case o => o.transformExpressionsUpWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY), ruleId) {
87+
case s: ScalarSubquery if s.children.isEmpty =>
88+
val (mergedPlan, ordinal) = mergeAndGetReference(s.plan, mergedSubqueries)
89+
GetStructField(s.copy(plan = mergedPlan), ordinal)
90+
}
8991
}
9092
}
9193

9294
case class SubqueryReference(
9395
index: Int,
94-
mergedSubqueries: ArrayBuffer[LogicalPlan]) extends LeafNode {
96+
mergedSubqueries: ArrayBuffer[Project]) extends LeafNode {
9597
override def stringArgs: Iterator[Any] = Iterator(index)
9698

9799
override def output: Seq[Attribute] = mergedSubqueries(index).output
98100
}
99101

100102
private def mergeAndGetReference(
101103
plan: LogicalPlan,
102-
mergedSubqueries: ArrayBuffer[LogicalPlan]): (SubqueryReference, Int) = {
104+
mergedSubqueries: ArrayBuffer[Project]): (SubqueryReference, Int) = {
103105
mergedSubqueries.zipWithIndex.collectFirst {
104-
Function.unlift { case (s, i) => tryMergePlans(plan, s).map(_ -> i) }
105-
}.map { case ((mergedPlan, outputMap), i) =>
106-
mergedSubqueries(i) = mergedPlan
107-
SubqueryReference(i, mergedSubqueries) ->
108-
mergedPlan.output.indexOf(outputMap(plan.output.head))
106+
Function.unlift { case (header, i) => tryMergePlans(plan, header.child).map((header, _, i)) }
107+
}.map { case (header, (mergedPlan, outputMap), i) =>
108+
if (mergedPlan.output.size > header.child.output.size) {
109+
mergedSubqueries(i) = createHeader(mergedPlan)
110+
}
111+
val ordinal = mergedPlan.output.indexOf(outputMap(plan.output.head))
112+
SubqueryReference(i, mergedSubqueries) -> ordinal
109113
}.getOrElse {
110-
mergedSubqueries += plan
114+
mergedSubqueries += createHeader(plan)
111115
SubqueryReference(mergedSubqueries.length - 1, mergedSubqueries) -> 0
112116
}
113117
}
114118

119+
private def createHeader(plan: LogicalPlan) = {
120+
Project(Seq(Alias(CreateStruct(plan.output), "mergedValue")()), plan)
121+
}
122+
115123
private def tryMergePlans(
116124
newPlan: LogicalPlan,
117125
existingPlan: LogicalPlan): Option[(LogicalPlan, AttributeMap[Attribute])] = {
@@ -191,16 +199,14 @@ object MergeScalarSubqueries extends Rule[LogicalPlan] with PredicateHelper {
191199
}
192200
}
193201

194-
private def removeReferences(
195-
plan: LogicalPlan,
196-
mergedSubqueries: ArrayBuffer[LogicalPlan]): LogicalPlan = {
197-
plan.transformAllExpressionsWithPruning(_.containsAnyPattern(MULTI_SCALAR_SUBQUERY), ruleId) {
198-
case gsf @ GetStructField(mss @ MultiScalarSubquery(sr: SubqueryReference, _), _, _) =>
199-
val dereferencedPlan = removeReferences(mergedSubqueries(sr.index), mergedSubqueries)
200-
if (dereferencedPlan.outputSet.size > 1) {
201-
gsf.copy(child = mss.copy(plan = dereferencedPlan))
202+
private def removeReferences(plan: LogicalPlan): LogicalPlan = {
203+
plan.transformAllExpressionsWithPruning(_.containsAnyPattern(SCALAR_SUBQUERY), ruleId) {
204+
case gsf @ GetStructField(ss @ ScalarSubquery(sr: SubqueryReference, _, _), _, _) =>
205+
val header = sr.mergedSubqueries(sr.index)
206+
if (header.child.output.size > 1) {
207+
gsf.copy(child = ss.copy(plan = header))
202208
} else {
203-
ScalarSubquery(dereferencedPlan, exprId = mss.exprId)
209+
ss.copy(plan = header.child)
204210
}
205211
}
206212
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreePatterns.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ object TreePattern extends Enumeration {
4848
val LIST_SUBQUERY: Value = Value
4949
val LITERAL: Value = Value
5050
val MAP_OBJECTS: Value = Value
51-
val MULTI_SCALAR_SUBQUERY: Value = Value
5251
val NOT: Value = Value
5352
val NULL_CHECK: Value = Value
5453
val NULL_LITERAL: Value = Value

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/MergeScalarSubqueriesSuite.scala

Lines changed: 35 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -19,14 +19,13 @@ package org.apache.spark.sql.catalyst.optimizer
1919

2020
import org.apache.spark.sql.catalyst.dsl.expressions._
2121
import org.apache.spark.sql.catalyst.dsl.plans._
22-
import org.apache.spark.sql.catalyst.expressions.{GetStructField, MultiScalarSubquery, ScalarSubquery}
22+
import org.apache.spark.sql.catalyst.expressions.{CreateStruct, GetStructField, ScalarSubquery}
2323
import org.apache.spark.sql.catalyst.expressions.aggregate.{CollectList, CollectSet}
2424
import org.apache.spark.sql.catalyst.plans._
2525
import org.apache.spark.sql.catalyst.plans.logical._
2626
import org.apache.spark.sql.catalyst.rules._
2727

2828
class MergeScalarSubqueriesSuite extends PlanTest {
29-
3029
private object Optimize extends RuleExecutor[LogicalPlan] {
3130
val batches = Batch("MergeScalarSubqueries", Once, MergeScalarSubqueries) :: Nil
3231
}
@@ -35,82 +34,81 @@ class MergeScalarSubqueriesSuite extends PlanTest {
3534

3635
test("Simple non-correlated scalar subquery merge") {
3736
val subquery1 = testRelation
38-
.groupBy('b)(max('a))
37+
.groupBy('b)(max('a).as("max_a"))
3938
val subquery2 = testRelation
40-
.groupBy('b)(sum('a))
39+
.groupBy('b)(sum('a).as("sum_a"))
4140
val originalQuery = testRelation
4241
.select(ScalarSubquery(subquery1), ScalarSubquery(subquery2))
4342

4443
val multiSubquery = testRelation
45-
.groupBy('b)(max('a), sum('a)).analyze
44+
.groupBy('b)(max('a).as("max_a"), sum('a).as("sum_a"))
45+
.select(CreateStruct(Seq('max_a, 'sum_a)).as("mergedValue"))
4646
val correctAnswer = testRelation
47-
.select(GetStructField(MultiScalarSubquery(multiSubquery), 0).as("scalarsubquery()"),
48-
GetStructField(MultiScalarSubquery(multiSubquery), 1).as("scalarsubquery()"))
47+
.select(GetStructField(ScalarSubquery(multiSubquery), 0).as("scalarsubquery()"),
48+
GetStructField(ScalarSubquery(multiSubquery), 1).as("scalarsubquery()"))
4949

50-
// checkAnalysis is disabled because `Analyzer` is not prepared for `MultiScalarSubquery` nodes
51-
// as only `Optimizer` can insert such a node to the plan
52-
comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer, false)
50+
comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
5351
}
5452

5553
test("Aggregate and group expression merge") {
5654
val subquery1 = testRelation
57-
.groupBy('b)(max('a))
55+
.groupBy('b)(max('a).as("max_a"))
5856
val subquery2 = testRelation
5957
.groupBy('b)('b)
6058
val originalQuery = testRelation
6159
.select(ScalarSubquery(subquery1), ScalarSubquery(subquery2))
6260

6361
val multiSubquery = testRelation
64-
.groupBy('b)(max('a), 'b).analyze
62+
.groupBy('b)(max('a).as("max_a"), 'b)
63+
.select(CreateStruct(Seq('max_a, 'b)).as("mergedValue"))
6564
val correctAnswer = testRelation
66-
.select(GetStructField(MultiScalarSubquery(multiSubquery), 0).as("scalarsubquery()"),
67-
GetStructField(MultiScalarSubquery(multiSubquery), 1).as("scalarsubquery()"))
65+
.select(GetStructField(ScalarSubquery(multiSubquery), 0).as("scalarsubquery()"),
66+
GetStructField(ScalarSubquery(multiSubquery), 1).as("scalarsubquery()"))
6867

69-
// checkAnalysis is disabled because `Analyzer` is not prepared for `MultiScalarSubquery` nodes
70-
// as only `Optimizer` can insert such a node to the plan
71-
comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer, false)
68+
comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
7269
}
7370

7471
test("Do not merge different aggregate implementations") {
7572
// supports HashAggregate
7673
val subquery1 = testRelation
77-
.groupBy('b)(max('a))
74+
.groupBy('b)(max('a).as("max_a"))
7875
val subquery2 = testRelation
79-
.groupBy('b)(min('a))
76+
.groupBy('b)(min('a).as("min_a"))
8077

8178
// supports ObjectHashAggregate
8279
val subquery3 = testRelation
83-
.groupBy('b)(CollectList('a).toAggregateExpression(isDistinct = false))
80+
.groupBy('b)(CollectList('a).toAggregateExpression(isDistinct = false).as("collectlist_a"))
8481
val subquery4 = testRelation
85-
.groupBy('b)(CollectSet('a).toAggregateExpression(isDistinct = false))
82+
.groupBy('b)(CollectSet('a).toAggregateExpression(isDistinct = false).as("collectset_a"))
8683

8784
// supports SortAggregate
8885
val subquery5 = testRelation
89-
.groupBy('b)(max('c))
86+
.groupBy('b)(max('c).as("max_c"))
9087
val subquery6 = testRelation
91-
.groupBy('b)(min('c))
88+
.groupBy('b)(min('c).as("min_c"))
9289

9390
val originalQuery = testRelation
9491
.select(ScalarSubquery(subquery1), ScalarSubquery(subquery2), ScalarSubquery(subquery3),
9592
ScalarSubquery(subquery4), ScalarSubquery(subquery5), ScalarSubquery(subquery6))
9693

9794
val hashAggregates = testRelation
98-
.groupBy('b)(max('a), min('a)).analyze
95+
.groupBy('b)(max('a).as("max_a"), min('a).as("min_a"))
96+
.select(CreateStruct(Seq('max_a, 'min_a)).as("mergedValue"))
9997
val objectHashAggregates = testRelation
100-
.groupBy('b)(CollectList('a).toAggregateExpression(isDistinct = false),
101-
CollectSet('a).toAggregateExpression(isDistinct = false)).analyze
98+
.groupBy('b)(CollectList('a).toAggregateExpression(isDistinct = false).as("collectlist_a"),
99+
CollectSet('a).toAggregateExpression(isDistinct = false).as("collectset_a"))
100+
.select(CreateStruct(Seq('collectlist_a, 'collectset_a)).as("mergedValue"))
102101
val sortAggregates = testRelation
103-
.groupBy('b)(max('c), min('c)).analyze
102+
.groupBy('b)(max('c).as("max_c"), min('c).as("min_c"))
103+
.select(CreateStruct(Seq('max_c, 'min_c)).as("mergedValue"))
104104
val correctAnswer = testRelation
105-
.select(GetStructField(MultiScalarSubquery(hashAggregates), 0).as("scalarsubquery()"),
106-
GetStructField(MultiScalarSubquery(hashAggregates), 1).as("scalarsubquery()"),
107-
GetStructField(MultiScalarSubquery(objectHashAggregates), 0).as("scalarsubquery()"),
108-
GetStructField(MultiScalarSubquery(objectHashAggregates), 1).as("scalarsubquery()"),
109-
GetStructField(MultiScalarSubquery(sortAggregates), 0).as("scalarsubquery()"),
110-
GetStructField(MultiScalarSubquery(sortAggregates), 1).as("scalarsubquery()"))
111-
112-
// checkAnalysis is disabled because `Analyzer` is not prepared for `MultiScalarSubquery` nodes
113-
// as only `Optimizer` can insert such a node to the plan
114-
comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer, false)
105+
.select(GetStructField(ScalarSubquery(hashAggregates), 0).as("scalarsubquery()"),
106+
GetStructField(ScalarSubquery(hashAggregates), 1).as("scalarsubquery()"),
107+
GetStructField(ScalarSubquery(objectHashAggregates), 0).as("scalarsubquery()"),
108+
GetStructField(ScalarSubquery(objectHashAggregates), 1).as("scalarsubquery()"),
109+
GetStructField(ScalarSubquery(sortAggregates), 0).as("scalarsubquery()"),
110+
GetStructField(ScalarSubquery(sortAggregates), 1).as("scalarsubquery()"))
111+
112+
comparePlans(Optimize.execute(originalQuery.analyze), correctAnswer.analyze)
115113
}
116114
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/plans/PlanTest.scala

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -73,8 +73,6 @@ trait PlanTestBase extends PredicateHelper with SQLHelper with SQLConfHelper { s
7373
plan transformAllExpressions {
7474
case s: ScalarSubquery =>
7575
s.copy(plan = normalizeExprIds(s.plan), exprId = ExprId(0))
76-
case s: MultiScalarSubquery =>
77-
s.copy(plan = normalizeExprIds(s.plan), exprId = ExprId(0))
7876
case e: Exists =>
7977
e.copy(exprId = ExprId(0))
8078
case l: ListQuery =>

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/InsertAdaptiveSparkPlan.scala

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,7 @@ import org.apache.spark.sql.catalyst.expressions.{ListQuery, SubqueryExpression}
2424
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
2525
import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution
2626
import org.apache.spark.sql.catalyst.rules.Rule
27-
import org.apache.spark.sql.catalyst.trees.TreePattern.{DYNAMIC_PRUNING_SUBQUERY, IN_SUBQUERY,
28-
MULTI_SCALAR_SUBQUERY, SCALAR_SUBQUERY}
27+
import org.apache.spark.sql.catalyst.trees.TreePattern.{DYNAMIC_PRUNING_SUBQUERY, IN_SUBQUERY, SCALAR_SUBQUERY}
2928
import org.apache.spark.sql.execution._
3029
import org.apache.spark.sql.execution.command.{DataWritingCommandExec, ExecutedCommandExec}
3130
import org.apache.spark.sql.execution.datasources.v2.V2CommandExec
@@ -115,8 +114,7 @@ case class InsertAdaptiveSparkPlan(
115114
*/
116115
private def buildSubqueryMap(plan: SparkPlan): Map[Long, BaseSubqueryExec] = {
117116
val subqueryMap = mutable.HashMap.empty[Long, BaseSubqueryExec]
118-
if (!plan.containsAnyPattern(SCALAR_SUBQUERY, MULTI_SCALAR_SUBQUERY, IN_SUBQUERY,
119-
DYNAMIC_PRUNING_SUBQUERY)) {
117+
if (!plan.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY)) {
120118
return subqueryMap.toMap
121119
}
122120
plan.foreach(_.expressions.foreach(_.foreach {
@@ -127,13 +125,6 @@ case class InsertAdaptiveSparkPlan(
127125
val subquery = SubqueryExec.createForScalarSubquery(
128126
s"subquery#${exprId.id}", executedPlan)
129127
subqueryMap.put(exprId.id, subquery)
130-
case expressions.MultiScalarSubquery(p, exprId)
131-
if !subqueryMap.contains(exprId.id) =>
132-
val executedPlan = compileSubquery(p)
133-
verifyAdaptivePlan(executedPlan, p)
134-
val subquery = SubqueryExec.createForScalarSubquery(
135-
s"subquery#${exprId.id}", executedPlan)
136-
subqueryMap.put(exprId.id, subquery)
137128
case expressions.InSubquery(_, ListQuery(query, _, exprId, _))
138129
if !subqueryMap.contains(exprId.id) =>
139130
val executedPlan = compileSubquery(query)

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/PlanAdaptiveSubqueries.scala

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,7 @@ package org.apache.spark.sql.execution.adaptive
2020
import org.apache.spark.sql.catalyst.expressions
2121
import org.apache.spark.sql.catalyst.expressions.{CreateNamedStruct, DynamicPruningExpression, ListQuery, Literal}
2222
import org.apache.spark.sql.catalyst.rules.Rule
23-
import org.apache.spark.sql.catalyst.trees.TreePattern.{DYNAMIC_PRUNING_SUBQUERY, IN_SUBQUERY,
24-
MULTI_SCALAR_SUBQUERY, SCALAR_SUBQUERY}
23+
import org.apache.spark.sql.catalyst.trees.TreePattern.{DYNAMIC_PRUNING_SUBQUERY, IN_SUBQUERY, SCALAR_SUBQUERY}
2524
import org.apache.spark.sql.execution
2625
import org.apache.spark.sql.execution.{BaseSubqueryExec, InSubqueryExec, SparkPlan}
2726

@@ -30,12 +29,9 @@ case class PlanAdaptiveSubqueries(
3029

3130
def apply(plan: SparkPlan): SparkPlan = {
3231
plan.transformAllExpressionsWithPruning(
33-
_.containsAnyPattern(SCALAR_SUBQUERY, MULTI_SCALAR_SUBQUERY, IN_SUBQUERY,
34-
DYNAMIC_PRUNING_SUBQUERY)) {
32+
_.containsAnyPattern(SCALAR_SUBQUERY, IN_SUBQUERY, DYNAMIC_PRUNING_SUBQUERY)) {
3533
case expressions.ScalarSubquery(_, _, exprId) =>
3634
execution.ScalarSubquery(subqueryMap(exprId.id), exprId)
37-
case expressions.MultiScalarSubquery(_, exprId) =>
38-
execution.MultiScalarSubqueryExec(subqueryMap(exprId.id), exprId)
3935
case expressions.InSubquery(values, ListQuery(_, _, exprId, _)) =>
4036
val expr = if (values.length == 1) {
4137
values.head

0 commit comments

Comments
 (0)