Commit 0f72e4f

Davies Liu authored and committed
[SPARK-16958] [SQL] Reuse subqueries within the same query
## What changes were proposed in this pull request?

There could be multiple subqueries that generate the same results; we can reuse the result instead of running the subquery multiple times. This PR also cleans up how we run subqueries.

For the SQL query

```sql
select id, (select avg(id) from t) from t where id > (select avg(id) from t)
```

the explain output is

```
== Physical Plan ==
*Project [id#15L, Subquery subquery29 AS scalarsubquery()#35]
:  +- Subquery subquery29
:     +- *HashAggregate(keys=[], functions=[avg(id#15L)])
:        +- Exchange SinglePartition
:           +- *HashAggregate(keys=[], functions=[partial_avg(id#15L)])
:              +- *Range (0, 1000, splits=4)
+- *Filter (cast(id#15L as double) > Subquery subquery29)
   :  +- Subquery subquery29
   :     +- *HashAggregate(keys=[], functions=[avg(id#15L)])
   :        +- Exchange SinglePartition
   :           +- *HashAggregate(keys=[], functions=[partial_avg(id#15L)])
   :              +- *Range (0, 1000, splits=4)
   +- *Range (0, 1000, splits=4)
```

The visualized plan:

![reuse-subquery](https://cloud.githubusercontent.com/assets/40902/17573229/e578d93c-5f0d-11e6-8a3c-0150d81d3aed.png)

## How was this patch tested?

Existing tests.

Author: Davies Liu <davies@databricks.com>

Closes #14548 from davies/subq.
1 parent 4d49680 commit 0f72e4f
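
A quick way to see the behavior described above (editor's sketch, not part of the patch): register a temporary view named `t` backed by `spark.range`, run the query from the description, and compare the two scalar-subquery occurrences in the explain output. This assumes a Spark 2.x `SparkSession` named `spark`, e.g. in spark-shell.

```scala
// Editor's reproduction sketch (assumes a SparkSession named `spark`).
spark.range(1000).createOrReplaceTempView("t")
spark.sql(
  "select id, (select avg(id) from t) from t where id > (select avg(id) from t)"
).explain()
// With this patch, both subquery occurrences should print with the same id
// (e.g. `Subquery subquery29`), indicating the result is computed once and reused.
```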

File tree

7 files changed: +215 -49 lines changed


sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/subquery.scala

Lines changed: 7 additions & 0 deletions
```diff
@@ -102,6 +102,13 @@ case class PredicateSubquery(
   override def nullable: Boolean = nullAware
   override def plan: LogicalPlan = SubqueryAlias(toString, query)
   override def withNewPlan(plan: LogicalPlan): PredicateSubquery = copy(query = plan)
+  override def semanticEquals(o: Expression): Boolean = o match {
+    case p: PredicateSubquery =>
+      query.sameResult(p.query) && nullAware == p.nullAware &&
+        children.length == p.children.length &&
+        children.zip(p.children).forall(p => p._1.semanticEquals(p._2))
+    case _ => false
+  }
   override def toString: String = s"predicate-subquery#${exprId.id} $conditionString"
 }
```

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala

Lines changed: 2 additions & 2 deletions
```diff
@@ -538,9 +538,9 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] extends Product {

     if (innerChildren.nonEmpty) {
       innerChildren.init.foreach(_.generateTreeString(
-        depth + 2, lastChildren :+ false :+ false, builder, verbose))
+        depth + 2, lastChildren :+ children.isEmpty :+ false, builder, verbose))
       innerChildren.last.generateTreeString(
-        depth + 2, lastChildren :+ false :+ true, builder, verbose)
+        depth + 2, lastChildren :+ children.isEmpty :+ true, builder, verbose)
     }

     if (children.nonEmpty) {
```

sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala

Lines changed: 2 additions & 1 deletion
```diff
@@ -101,7 +101,8 @@ class QueryExecution(val sparkSession: SparkSession, val logical: LogicalPlan) {
     PlanSubqueries(sparkSession),
     EnsureRequirements(sparkSession.sessionState.conf),
     CollapseCodegenStages(sparkSession.sessionState.conf),
-    ReuseExchange(sparkSession.sessionState.conf))
+    ReuseExchange(sparkSession.sessionState.conf),
+    ReuseSubquery(sparkSession.sessionState.conf))

   protected def stringOrError[A](f: => A): String =
     try f.toString catch { case e: Throwable => e.toString }
```

sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala

Lines changed: 10 additions & 24 deletions
```diff
@@ -142,21 +142,18 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
    * This list is populated by [[prepareSubqueries]], which is called in [[prepare]].
    */
   @transient
-  private val subqueryResults = new ArrayBuffer[(ScalarSubquery, Future[Array[InternalRow]])]
+  private val runningSubqueries = new ArrayBuffer[ExecSubqueryExpression]

   /**
    * Finds scalar subquery expressions in this plan node and starts evaluating them.
-   * The list of subqueries are added to [[subqueryResults]].
    */
   protected def prepareSubqueries(): Unit = {
-    val allSubqueries = expressions.flatMap(_.collect {case e: ScalarSubquery => e})
-    allSubqueries.asInstanceOf[Seq[ScalarSubquery]].foreach { e =>
-      val futureResult = Future {
-        // Each subquery should return only one row (and one column). We take two here and throws
-        // an exception later if the number of rows is greater than one.
-        e.executedPlan.executeTake(2)
-      }(SparkPlan.subqueryExecutionContext)
-      subqueryResults += e -> futureResult
+    expressions.foreach {
+      _.collect {
+        case e: ExecSubqueryExpression =>
+          e.plan.prepare()
+          runningSubqueries += e
+      }
     }
   }

@@ -165,21 +162,10 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
    */
   protected def waitForSubqueries(): Unit = synchronized {
     // fill in the result of subqueries
-    subqueryResults.foreach { case (e, futureResult) =>
-      val rows = ThreadUtils.awaitResult(futureResult, Duration.Inf)
-      if (rows.length > 1) {
-        sys.error(s"more than one row returned by a subquery used as an expression:\n${e.plan}")
-      }
-      if (rows.length == 1) {
-        assert(rows(0).numFields == 1,
-          s"Expects 1 field, but got ${rows(0).numFields}; something went wrong in analysis")
-        e.updateResult(rows(0).get(0, e.dataType))
-      } else {
-        // If there is no rows returned, the result should be null.
-        e.updateResult(null)
-      }
+    runningSubqueries.foreach { sub =>
+      sub.updateResult()
     }
-    subqueryResults.clear()
+    runningSubqueries.clear()
   }

   /**
```

sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala

Lines changed: 59 additions & 4 deletions
```diff
@@ -17,13 +17,19 @@

 package org.apache.spark.sql.execution

+import scala.concurrent.{ExecutionContext, Future}
+import scala.concurrent.duration.Duration
+
+import org.apache.spark.SparkException
 import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.types.{LongType, StructField, StructType}
+import org.apache.spark.sql.execution.ui.SparkListenerDriverAccumUpdates
+import org.apache.spark.sql.types.LongType
+import org.apache.spark.util.ThreadUtils
 import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}

 /** Physical plan for Project. */
@@ -502,15 +508,64 @@ case class OutputFakerExec(output: Seq[Attribute], child: SparkPlan) extends Spa

 /**
  * Physical plan for a subquery.
- *
- * This is used to generate tree string for SparkScalarSubquery.
  */
 case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode {
+
+  override lazy val metrics = Map(
+    "dataSize" -> SQLMetrics.createMetric(sparkContext, "data size (bytes)"),
+    "collectTime" -> SQLMetrics.createMetric(sparkContext, "time to collect (ms)"))
+
   override def output: Seq[Attribute] = child.output
   override def outputPartitioning: Partitioning = child.outputPartitioning
   override def outputOrdering: Seq[SortOrder] = child.outputOrdering

+  override def sameResult(o: SparkPlan): Boolean = o match {
+    case s: SubqueryExec => child.sameResult(s.child)
+    case _ => false
+  }
+
+  @transient
+  private lazy val relationFuture: Future[Array[InternalRow]] = {
+    // relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
+    val executionId = sparkContext.getLocalProperty(SQLExecution.EXECUTION_ID_KEY)
+    Future {
+      // This will run in another thread. Set the execution id so that we can connect these jobs
+      // with the correct execution.
+      SQLExecution.withExecutionId(sparkContext, executionId) {
+        val beforeCollect = System.nanoTime()
+        // Note that we use .executeCollect() because we don't want to convert data to Scala types
+        val rows: Array[InternalRow] = child.executeCollect()
+        val beforeBuild = System.nanoTime()
+        longMetric("collectTime") += (beforeBuild - beforeCollect) / 1000000
+        val dataSize = rows.map(_.asInstanceOf[UnsafeRow].getSizeInBytes.toLong).sum
+        longMetric("dataSize") += dataSize
+
+        // There are some cases we don't care about the metrics and call `SparkPlan.doExecute`
+        // directly without setting an execution id. We should be tolerant to it.
+        if (executionId != null) {
+          sparkContext.listenerBus.post(SparkListenerDriverAccumUpdates(
+            executionId.toLong, metrics.values.map(m => m.id -> m.value).toSeq))
+        }
+
+        rows
+      }
+    }(SubqueryExec.executionContext)
+  }
+
+  protected override def doPrepare(): Unit = {
+    relationFuture
+  }
+
   protected override def doExecute(): RDD[InternalRow] = {
-    throw new UnsupportedOperationException
+    child.execute()
   }
+
+  override def executeCollect(): Array[InternalRow] = {
+    ThreadUtils.awaitResult(relationFuture, Duration.Inf)
+  }
+}
+
+object SubqueryExec {
+  private[execution] val executionContext = ExecutionContext.fromExecutorService(
+    ThreadUtils.newDaemonCachedThreadPool("subquery", 16))
 }
```
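
The new `SubqueryExec` node above follows a prepare-then-await pattern: `doPrepare` forces `relationFuture`, which starts collecting the subquery result on a dedicated thread pool, and `executeCollect` later blocks on the same future, so the work runs at most once even when the node is shared by several consumers. A minimal standalone sketch of that pattern (editor's illustration with made-up names, not Spark APIs):

```scala
import java.util.concurrent.Executors

import scala.concurrent.{Await, ExecutionContext, Future}
import scala.concurrent.duration.Duration

// Editor's sketch of the prepare-then-await pattern: `prepare()` forces the lazy
// future (starting the computation on a shared pool); `collect()` blocks on the
// same future, so the computation runs at most once regardless of callers.
object PrepareThenAwait {
  private val pool = ExecutionContext.fromExecutorService(Executors.newCachedThreadPool())

  final class Node(compute: () => Array[Long]) {
    private lazy val future: Future[Array[Long]] = Future(compute())(pool)
    def prepare(): Unit = { future }  // kick off the work eagerly, discard the handle
    def collect(): Array[Long] = Await.result(future, Duration.Inf)
  }
}
```

Sharing one such node between consumers is what makes the pattern pay off; arranging that sharing for `SubqueryExec` is the job of the `ReuseSubquery` rule in the next file.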

sql/core/src/main/scala/org/apache/spark/sql/execution/subquery.scala

Lines changed: 129 additions & 16 deletions
```diff
@@ -17,42 +17,78 @@

 package org.apache.spark.sql.execution

+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.catalyst.expressions
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Expression, ExprId, Literal, SubqueryExpression}
+import org.apache.spark.sql.catalyst.{expressions, InternalRow}
+import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.types.DataType
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{BooleanType, DataType, StructType}
+
+/**
+ * The base class for subquery that is used in SparkPlan.
+ */
+trait ExecSubqueryExpression extends SubqueryExpression {
+
+  val executedPlan: SubqueryExec
+  def withExecutedPlan(plan: SubqueryExec): ExecSubqueryExpression
+
+  // does not have logical plan
+  override def query: LogicalPlan = throw new UnsupportedOperationException
+  override def withNewPlan(plan: LogicalPlan): SubqueryExpression =
+    throw new UnsupportedOperationException
+
+  override def plan: SparkPlan = executedPlan
+
+  /**
+   * Fill the expression with collected result from executed plan.
+   */
+  def updateResult(): Unit
+}

 /**
  * A subquery that will return only one row and one column.
  *
  * This is the physical copy of ScalarSubquery to be used inside SparkPlan.
  */
 case class ScalarSubquery(
-    executedPlan: SparkPlan,
+    executedPlan: SubqueryExec,
     exprId: ExprId)
-  extends SubqueryExpression {
-
-  override def query: LogicalPlan = throw new UnsupportedOperationException
-  override def withNewPlan(plan: LogicalPlan): SubqueryExpression = {
-    throw new UnsupportedOperationException
-  }
-  override def plan: SparkPlan = SubqueryExec(simpleString, executedPlan)
+  extends ExecSubqueryExpression {

   override def dataType: DataType = executedPlan.schema.fields.head.dataType
   override def children: Seq[Expression] = Nil
   override def nullable: Boolean = true
-  override def toString: String = s"subquery#${exprId.id}"
+  override def toString: String = executedPlan.simpleString
+
+  def withExecutedPlan(plan: SubqueryExec): ExecSubqueryExpression = copy(executedPlan = plan)
+
+  override def semanticEquals(other: Expression): Boolean = other match {
+    case s: ScalarSubquery => executedPlan.sameResult(executedPlan)
+    case _ => false
+  }

   // the first column in first row from `query`.
   @volatile private var result: Any = null
   @volatile private var updated: Boolean = false

-  def updateResult(v: Any): Unit = {
-    result = v
+  def updateResult(): Unit = {
+    val rows = plan.executeCollect()
+    if (rows.length > 1) {
+      sys.error(s"more than one row returned by a subquery used as an expression:\n${plan}")
+    }
+    if (rows.length == 1) {
+      assert(rows(0).numFields == 1,
+        s"Expects 1 field, but got ${rows(0).numFields}; something went wrong in analysis")
+      result = rows(0).get(0, dataType)
+    } else {
+      // If there is no rows returned, the result should be null.
+      result = null
+    }
     updated = true
   }

@@ -67,6 +103,51 @@ case class ScalarSubquery(
   }
 }

+/**
+ * A subquery that will check the value of `child` whether is in the result of a query or not.
+ */
+case class InSubquery(
+    child: Expression,
+    executedPlan: SubqueryExec,
+    exprId: ExprId,
+    private var result: Array[Any] = null,
+    private var updated: Boolean = false) extends ExecSubqueryExpression {
+
+  override def dataType: DataType = BooleanType
+  override def children: Seq[Expression] = child :: Nil
+  override def nullable: Boolean = child.nullable
+  override def toString: String = s"$child IN ${executedPlan.name}"
+
+  def withExecutedPlan(plan: SubqueryExec): ExecSubqueryExpression = copy(executedPlan = plan)
+
+  override def semanticEquals(other: Expression): Boolean = other match {
+    case in: InSubquery => child.semanticEquals(in.child) &&
+      executedPlan.sameResult(in.executedPlan)
+    case _ => false
+  }
+
+  def updateResult(): Unit = {
+    val rows = plan.executeCollect()
+    result = rows.map(_.get(0, child.dataType)).asInstanceOf[Array[Any]]
+    updated = true
+  }
+
+  override def eval(input: InternalRow): Any = {
+    require(updated, s"$this has not finished")
+    val v = child.eval(input)
+    if (v == null) {
+      null
+    } else {
+      result.contains(v)
+    }
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    require(updated, s"$this has not finished")
+    InSet(child, result.toSet).doGenCode(ctx, ev)
+  }
+}
+
 /**
  * Plans scalar subqueries from that are present in the given [[SparkPlan]].
  */
@@ -75,7 +156,39 @@ case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] {
     plan.transformAllExpressions {
       case subquery: expressions.ScalarSubquery =>
         val executedPlan = new QueryExecution(sparkSession, subquery.plan).executedPlan
-        ScalarSubquery(executedPlan, subquery.exprId)
+        ScalarSubquery(
+          SubqueryExec(s"subquery${subquery.exprId.id}", executedPlan),
+          subquery.exprId)
+      case expressions.PredicateSubquery(plan, Seq(e: Expression), _, exprId) =>
+        val executedPlan = new QueryExecution(sparkSession, plan).executedPlan
+        InSubquery(e, SubqueryExec(s"subquery${exprId.id}", executedPlan), exprId)
+    }
+  }
+}
+
+
+/**
+ * Find out duplicated exchanges in the spark plan, then use the same exchange for all the
+ * references.
+ */
+case class ReuseSubquery(conf: SQLConf) extends Rule[SparkPlan] {
+
+  def apply(plan: SparkPlan): SparkPlan = {
+    if (!conf.exchangeReuseEnabled) {
+      return plan
+    }
+    // Build a hash map using schema of exchanges to avoid O(N*N) sameResult calls.
+    val subqueries = mutable.HashMap[StructType, ArrayBuffer[SubqueryExec]]()
+    plan transformAllExpressions {
+      case sub: ExecSubqueryExpression =>
+        val sameSchema = subqueries.getOrElseUpdate(sub.plan.schema, ArrayBuffer[SubqueryExec]())
+        val sameResult = sameSchema.find(_.sameResult(sub.plan))
+        if (sameResult.isDefined) {
+          sub.withExecutedPlan(sameResult.get)
+        } else {
+          sameSchema += sub.executedPlan
+          sub
+        }
     }
   }
 }
```
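
`ReuseSubquery` above avoids O(N²) `sameResult` comparisons by bucketing candidate plans by their schema and only running the expensive check within a bucket, the same strategy `ReuseExchange` uses for exchanges. The idea in isolation, as an editor's generic sketch (plain Scala, not Spark code):

```scala
import scala.collection.mutable

// Editor's sketch of the bucketed dedup strategy: `key` stands in for the plan
// schema (cheap to hash), `same` stands in for sameResult (expensive), and each
// item is replaced by the first equivalent item already seen in its bucket.
def dedupBy[A, K](items: Seq[A])(key: A => K)(same: (A, A) => Boolean): Seq[A] = {
  val buckets = mutable.HashMap.empty[K, mutable.ArrayBuffer[A]]
  items.map { item =>
    val bucket = buckets.getOrElseUpdate(key(item), mutable.ArrayBuffer.empty[A])
    bucket.find(same(_, item)) match {
      case Some(existing) => existing          // reuse the previously seen equivalent
      case None => bucket += item; item        // first of its kind, keep it
    }
  }
}
```

Bucketing keeps the common case cheap: only plans that already agree on schema ever reach the costly pairwise comparison.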

sql/core/src/main/scala/org/apache/spark/sql/execution/ui/SparkPlanGraph.scala

Lines changed: 6 additions & 2 deletions
```diff
@@ -99,7 +99,11 @@ object SparkPlanGraph {
       case "Subquery" if subgraph != null =>
         // Subquery should not be included in WholeStageCodegen
         buildSparkPlanGraphNode(planInfo, nodeIdGenerator, nodes, edges, parent, null, exchanges)
-      case "ReusedExchange" =>
+      case "Subquery" if exchanges.contains(planInfo) =>
+        // Point to the re-used subquery
+        val node = exchanges(planInfo)
+        edges += SparkPlanGraphEdge(node.id, parent.id)
+      case "ReusedExchange" if exchanges.contains(planInfo.children.head) =>
         // Point to the re-used exchange
         val node = exchanges(planInfo.children.head)
         edges += SparkPlanGraphEdge(node.id, parent.id)
@@ -115,7 +119,7 @@ object SparkPlanGraph {
         } else {
           subgraph.nodes += node
         }
-        if (name.contains("Exchange")) {
+        if (name.contains("Exchange") || name == "Subquery") {
          exchanges += planInfo -> node
         }
```
