Skip to content

Commit 0b0c8b9

Browse files
hvanhovell authored and davies committed
[SPARK-17106] [SQL] Simplify the SubqueryExpression interface
## What changes were proposed in this pull request?

The current subquery expression interface contains a little bit of technical debt in the form of a few different access paths to get and set the query contained by the expression. This is confusing to anyone who goes over this code. This PR unifies these access paths.

## How was this patch tested?

(Existing tests)

Author: Herman van Hovell <hvanhovell@databricks.com>

Closes #14685 from hvanhovell/SPARK-17106.
1 parent 56d8674 commit 0b0c8b9

File tree

8 files changed

+56
-74
lines changed

8 files changed

+56
-74
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,7 @@ class Analyzer(
146146
// This cannot be done in ResolveSubquery because ResolveSubquery does not know the CTE.
147147
other transformExpressions {
148148
case e: SubqueryExpression =>
149-
e.withNewPlan(substituteCTE(e.query, cteRelations))
149+
e.withNewPlan(substituteCTE(e.plan, cteRelations))
150150
}
151151
}
152152
}
@@ -1091,7 +1091,7 @@ class Analyzer(
10911091
f: (LogicalPlan, Seq[Expression]) => SubqueryExpression): SubqueryExpression = {
10921092
// Step 1: Resolve the outer expressions.
10931093
var previous: LogicalPlan = null
1094-
var current = e.query
1094+
var current = e.plan
10951095
do {
10961096
// Try to resolve the subquery plan using the regular analyzer.
10971097
previous = current

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala

Lines changed: 28 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -17,33 +17,33 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20-
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
2120
import org.apache.spark.sql.catalyst.plans.QueryPlan
22-
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, SubqueryAlias}
21+
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
2322
import org.apache.spark.sql.types._
2423

2524
/**
26-
* An interface for subquery that is used in expressions.
25+
* An interface for expressions that contain a [[QueryPlan]].
2726
*/
28-
abstract class SubqueryExpression extends Expression {
27+
abstract class PlanExpression[T <: QueryPlan[_]] extends Expression {
2928
/** The id of the subquery expression. */
3029
def exprId: ExprId
3130

32-
/** The logical plan of the query. */
33-
def query: LogicalPlan
31+
/** The plan being wrapped in the query. */
32+
def plan: T
3433

35-
/**
36-
* Either a logical plan or a physical plan. The generated tree string (explain output) uses this
37-
* field to explain the subquery.
38-
*/
39-
def plan: QueryPlan[_]
40-
41-
/** Updates the query with new logical plan. */
42-
def withNewPlan(plan: LogicalPlan): SubqueryExpression
34+
/** Updates the expression with a new plan. */
35+
def withNewPlan(plan: T): PlanExpression[T]
4336

4437
protected def conditionString: String = children.mkString("[", " && ", "]")
4538
}
4639

40+
/**
41+
* A base interface for expressions that contain a [[LogicalPlan]].
42+
*/
43+
abstract class SubqueryExpression extends PlanExpression[LogicalPlan] {
44+
override def withNewPlan(plan: LogicalPlan): SubqueryExpression
45+
}
46+
4747
object SubqueryExpression {
4848
def hasCorrelatedSubquery(e: Expression): Boolean = {
4949
e.find {
@@ -60,20 +60,19 @@ object SubqueryExpression {
6060
* Note: `exprId` is used to have a unique name in explain string output.
6161
*/
6262
case class ScalarSubquery(
63-
query: LogicalPlan,
63+
plan: LogicalPlan,
6464
children: Seq[Expression] = Seq.empty,
6565
exprId: ExprId = NamedExpression.newExprId)
6666
extends SubqueryExpression with Unevaluable {
67-
override lazy val resolved: Boolean = childrenResolved && query.resolved
67+
override lazy val resolved: Boolean = childrenResolved && plan.resolved
6868
override lazy val references: AttributeSet = {
69-
if (query.resolved) super.references -- query.outputSet
69+
if (plan.resolved) super.references -- plan.outputSet
7070
else super.references
7171
}
72-
override def dataType: DataType = query.schema.fields.head.dataType
72+
override def dataType: DataType = plan.schema.fields.head.dataType
7373
override def foldable: Boolean = false
7474
override def nullable: Boolean = true
75-
override def plan: LogicalPlan = SubqueryAlias(toString, query, None)
76-
override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(query = plan)
75+
override def withNewPlan(plan: LogicalPlan): ScalarSubquery = copy(plan = plan)
7776
override def toString: String = s"scalar-subquery#${exprId.id} $conditionString"
7877
}
7978

@@ -92,19 +91,18 @@ object ScalarSubquery {
9291
* be rewritten into a left semi/anti join during analysis.
9392
*/
9493
case class PredicateSubquery(
95-
query: LogicalPlan,
94+
plan: LogicalPlan,
9695
children: Seq[Expression] = Seq.empty,
9796
nullAware: Boolean = false,
9897
exprId: ExprId = NamedExpression.newExprId)
9998
extends SubqueryExpression with Predicate with Unevaluable {
100-
override lazy val resolved = childrenResolved && query.resolved
101-
override lazy val references: AttributeSet = super.references -- query.outputSet
99+
override lazy val resolved = childrenResolved && plan.resolved
100+
override lazy val references: AttributeSet = super.references -- plan.outputSet
102101
override def nullable: Boolean = nullAware
103-
override def plan: LogicalPlan = SubqueryAlias(toString, query, None)
104-
override def withNewPlan(plan: LogicalPlan): PredicateSubquery = copy(query = plan)
102+
override def withNewPlan(plan: LogicalPlan): PredicateSubquery = copy(plan = plan)
105103
override def semanticEquals(o: Expression): Boolean = o match {
106104
case p: PredicateSubquery =>
107-
query.sameResult(p.query) && nullAware == p.nullAware &&
105+
plan.sameResult(p.plan) && nullAware == p.nullAware &&
108106
children.length == p.children.length &&
109107
children.zip(p.children).forall(p => p._1.semanticEquals(p._2))
110108
case _ => false
@@ -146,14 +144,13 @@ object PredicateSubquery {
146144
* FROM b)
147145
* }}}
148146
*/
149-
case class ListQuery(query: LogicalPlan, exprId: ExprId = NamedExpression.newExprId)
147+
case class ListQuery(plan: LogicalPlan, exprId: ExprId = NamedExpression.newExprId)
150148
extends SubqueryExpression with Unevaluable {
151149
override lazy val resolved = false
152150
override def children: Seq[Expression] = Seq.empty
153151
override def dataType: DataType = ArrayType(NullType)
154152
override def nullable: Boolean = false
155-
override def withNewPlan(plan: LogicalPlan): ListQuery = copy(query = plan)
156-
override def plan: LogicalPlan = SubqueryAlias(toString, query, None)
153+
override def withNewPlan(plan: LogicalPlan): ListQuery = copy(plan = plan)
157154
override def toString: String = s"list#${exprId.id}"
158155
}
159156

@@ -168,12 +165,11 @@ case class ListQuery(query: LogicalPlan, exprId: ExprId = NamedExpression.newExp
168165
* WHERE b.id = a.id)
169166
* }}}
170167
*/
171-
case class Exists(query: LogicalPlan, exprId: ExprId = NamedExpression.newExprId)
168+
case class Exists(plan: LogicalPlan, exprId: ExprId = NamedExpression.newExprId)
172169
extends SubqueryExpression with Predicate with Unevaluable {
173170
override lazy val resolved = false
174171
override def children: Seq[Expression] = Seq.empty
175172
override def nullable: Boolean = false
176-
override def withNewPlan(plan: LogicalPlan): Exists = copy(query = plan)
177-
override def plan: LogicalPlan = SubqueryAlias(toString, query, None)
173+
override def withNewPlan(plan: LogicalPlan): Exists = copy(plan = plan)
178174
override def toString: String = s"exists#${exprId.id}"
179175
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,7 @@ abstract class Optimizer(sessionCatalog: SessionCatalog, conf: CatalystConf)
127127
object OptimizeSubqueries extends Rule[LogicalPlan] {
128128
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
129129
case s: SubqueryExpression =>
130-
s.withNewPlan(Optimizer.this.execute(s.query))
130+
s.withNewPlan(Optimizer.this.execute(s.plan))
131131
}
132132
}
133133
}
@@ -1814,7 +1814,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
18141814
val newExpression = expression transform {
18151815
case s: ScalarSubquery if s.children.nonEmpty =>
18161816
subqueries += s
1817-
s.query.output.head
1817+
s.plan.output.head
18181818
}
18191819
newExpression.asInstanceOf[E]
18201820
}
@@ -2029,7 +2029,7 @@ object RewriteCorrelatedScalarSubquery extends Rule[LogicalPlan] {
20292029
// grouping expressions. As a result we need to replace all the scalar subqueries in the
20302030
// grouping expressions by their result.
20312031
val newGrouping = grouping.map { e =>
2032-
subqueries.find(_.semanticEquals(e)).map(_.query.output.head).getOrElse(e)
2032+
subqueries.find(_.semanticEquals(e)).map(_.plan.output.head).getOrElse(e)
20332033
}
20342034
Aggregate(newGrouping, newExpressions, constructLeftJoins(child, subqueries))
20352035
} else {

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,7 +263,9 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
263263
* All the subqueries of current plan.
264264
*/
265265
def subqueries: Seq[PlanType] = {
266-
expressions.flatMap(_.collect {case e: SubqueryExpression => e.plan.asInstanceOf[PlanType]})
266+
expressions.flatMap(_.collect {
267+
case e: PlanExpression[_] => e.plan.asInstanceOf[PlanType]
268+
})
267269
}
268270

269271
override protected def innerChildren: Seq[QueryPlan[_]] = subqueries

sql/core/src/main/scala/org/apache/spark/sql/catalyst/SQLBuilder.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ class SQLBuilder private (
8080
try {
8181
val replaced = finalPlan.transformAllExpressions {
8282
case s: SubqueryExpression =>
83-
val query = new SQLBuilder(s.query, nextSubqueryId, nextGenAttrId, exprIdMap).toSQL
83+
val query = new SQLBuilder(s.plan, nextSubqueryId, nextGenAttrId, exprIdMap).toSQL
8484
val sql = s match {
8585
case _: ListQuery => query
8686
case _: Exists => s"EXISTS($query)"

sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala

Lines changed: 17 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -22,28 +22,16 @@ import scala.collection.mutable.ArrayBuffer
2222

2323
import org.apache.spark.sql.SparkSession
2424
import org.apache.spark.sql.catalyst.{expressions, InternalRow}
25-
import org.apache.spark.sql.catalyst.expressions._
25+
import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId, InSet, Literal, PlanExpression}
2626
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
27-
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
2827
import org.apache.spark.sql.catalyst.rules.Rule
2928
import org.apache.spark.sql.internal.SQLConf
3029
import org.apache.spark.sql.types.{BooleanType, DataType, StructType}
3130

3231
/**
3332
* The base class for subquery that is used in SparkPlan.
3433
*/
35-
trait ExecSubqueryExpression extends SubqueryExpression {
36-
37-
val executedPlan: SubqueryExec
38-
def withExecutedPlan(plan: SubqueryExec): ExecSubqueryExpression
39-
40-
// does not have logical plan
41-
override def query: LogicalPlan = throw new UnsupportedOperationException
42-
override def withNewPlan(plan: LogicalPlan): SubqueryExpression =
43-
throw new UnsupportedOperationException
44-
45-
override def plan: SparkPlan = executedPlan
46-
34+
abstract class ExecSubqueryExpression extends PlanExpression[SubqueryExec] {
4735
/**
4836
* Fill the expression with collected result from executed plan.
4937
*/
@@ -56,30 +44,29 @@ trait ExecSubqueryExpression extends SubqueryExpression {
5644
* This is the physical copy of ScalarSubquery to be used inside SparkPlan.
5745
*/
5846
case class ScalarSubquery(
59-
executedPlan: SubqueryExec,
47+
plan: SubqueryExec,
6048
exprId: ExprId)
6149
extends ExecSubqueryExpression {
6250

63-
override def dataType: DataType = executedPlan.schema.fields.head.dataType
51+
override def dataType: DataType = plan.schema.fields.head.dataType
6452
override def children: Seq[Expression] = Nil
6553
override def nullable: Boolean = true
66-
override def toString: String = executedPlan.simpleString
67-
68-
def withExecutedPlan(plan: SubqueryExec): ExecSubqueryExpression = copy(executedPlan = plan)
54+
override def toString: String = plan.simpleString
55+
override def withNewPlan(query: SubqueryExec): ScalarSubquery = copy(plan = query)
6956

7057
override def semanticEquals(other: Expression): Boolean = other match {
71-
case s: ScalarSubquery => executedPlan.sameResult(executedPlan)
58+
case s: ScalarSubquery => plan.sameResult(s.plan)
7259
case _ => false
7360
}
7461

7562
// the first column in first row from `query`.
76-
@volatile private var result: Any = null
63+
@volatile private var result: Any = _
7764
@volatile private var updated: Boolean = false
7865

7966
def updateResult(): Unit = {
8067
val rows = plan.executeCollect()
8168
if (rows.length > 1) {
82-
sys.error(s"more than one row returned by a subquery used as an expression:\n${plan}")
69+
sys.error(s"more than one row returned by a subquery used as an expression:\n$plan")
8370
}
8471
if (rows.length == 1) {
8572
assert(rows(0).numFields == 1,
@@ -108,21 +95,19 @@ case class ScalarSubquery(
10895
*/
10996
case class InSubquery(
11097
child: Expression,
111-
executedPlan: SubqueryExec,
98+
plan: SubqueryExec,
11299
exprId: ExprId,
113100
private var result: Array[Any] = null,
114101
private var updated: Boolean = false) extends ExecSubqueryExpression {
115102

116103
override def dataType: DataType = BooleanType
117104
override def children: Seq[Expression] = child :: Nil
118105
override def nullable: Boolean = child.nullable
119-
override def toString: String = s"$child IN ${executedPlan.name}"
120-
121-
def withExecutedPlan(plan: SubqueryExec): ExecSubqueryExpression = copy(executedPlan = plan)
106+
override def toString: String = s"$child IN ${plan.name}"
107+
override def withNewPlan(plan: SubqueryExec): InSubquery = copy(plan = plan)
122108

123109
override def semanticEquals(other: Expression): Boolean = other match {
124-
case in: InSubquery => child.semanticEquals(in.child) &&
125-
executedPlan.sameResult(in.executedPlan)
110+
case in: InSubquery => child.semanticEquals(in.child) && plan.sameResult(in.plan)
126111
case _ => false
127112
}
128113

@@ -159,8 +144,8 @@ case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] {
159144
ScalarSubquery(
160145
SubqueryExec(s"subquery${subquery.exprId.id}", executedPlan),
161146
subquery.exprId)
162-
case expressions.PredicateSubquery(plan, Seq(e: Expression), _, exprId) =>
163-
val executedPlan = new QueryExecution(sparkSession, plan).executedPlan
147+
case expressions.PredicateSubquery(query, Seq(e: Expression), _, exprId) =>
148+
val executedPlan = new QueryExecution(sparkSession, query).executedPlan
164149
InSubquery(e, SubqueryExec(s"subquery${exprId.id}", executedPlan), exprId)
165150
}
166151
}
@@ -184,9 +169,9 @@ case class ReuseSubquery(conf: SQLConf) extends Rule[SparkPlan] {
184169
val sameSchema = subqueries.getOrElseUpdate(sub.plan.schema, ArrayBuffer[SubqueryExec]())
185170
val sameResult = sameSchema.find(_.sameResult(sub.plan))
186171
if (sameResult.isDefined) {
187-
sub.withExecutedPlan(sameResult.get)
172+
sub.withNewPlan(sameResult.get)
188173
} else {
189-
sameSchema += sub.executedPlan
174+
sameSchema += sub.plan
190175
sub
191176
}
192177
}

sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -292,7 +292,7 @@ abstract class QueryTest extends PlanTest {
292292
p.expressions.foreach {
293293
_.foreach {
294294
case s: SubqueryExpression =>
295-
s.query.foreach(collectData)
295+
s.plan.foreach(collectData)
296296
case _ =>
297297
}
298298
}
@@ -334,7 +334,7 @@ abstract class QueryTest extends PlanTest {
334334
case p =>
335335
p.transformExpressions {
336336
case s: SubqueryExpression =>
337-
s.withNewPlan(s.query.transformDown(renormalize))
337+
s.withNewPlan(s.plan.transformDown(renormalize))
338338
}
339339
}
340340
val normalized2 = jsonBackPlan.transformDown(renormalize)

sql/core/src/test/scala/org/apache/spark/sql/execution/benchmark/TPCDSQueryBenchmark.scala

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
2626
import org.apache.spark.sql.catalyst.expressions.SubqueryExpression
2727
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
2828
import org.apache.spark.sql.catalyst.util._
29-
import org.apache.spark.sql.internal.SQLConf
3029
import org.apache.spark.util.Benchmark
3130

3231
/**

0 commit comments

Comments (0)