17
17
18
18
package org .apache .spark .sql .execution
19
19
20
+ import scala .collection .mutable
21
+ import scala .collection .mutable .ArrayBuffer
22
+
20
23
import org .apache .spark .sql .SparkSession
21
- import org .apache .spark .sql .catalyst .expressions
22
- import org .apache .spark .sql .catalyst .InternalRow
23
- import org .apache .spark .sql .catalyst .expressions .{Expression , ExprId , Literal , SubqueryExpression }
24
+ import org .apache .spark .sql .catalyst .{expressions , InternalRow }
25
+ import org .apache .spark .sql .catalyst .expressions ._
24
26
import org .apache .spark .sql .catalyst .expressions .codegen .{CodegenContext , ExprCode }
25
27
import org .apache .spark .sql .catalyst .plans .logical .LogicalPlan
26
28
import org .apache .spark .sql .catalyst .rules .Rule
27
- import org .apache .spark .sql .types .DataType
29
+ import org .apache .spark .sql .internal .SQLConf
30
+ import org .apache .spark .sql .types .{BooleanType , DataType , StructType }
31
+
32
/**
 * Common parent of the subquery expressions that live inside a SparkPlan.
 *
 * Such an expression is backed by a physical [[SubqueryExec]] instead of a logical plan,
 * so the logical-plan accessors inherited from [[SubqueryExpression]] are unsupported.
 */
trait ExecSubqueryExpression extends SubqueryExpression {

  /** The physical subquery that backs this expression. */
  val executedPlan: SubqueryExec

  /** Builds an expression of this kind that uses `plan` as its executed plan. */
  def withExecutedPlan(plan: SubqueryExec): ExecSubqueryExpression

  /** The physical plan of this subquery, i.e. [[executedPlan]]. */
  override def plan: SparkPlan = executedPlan

  // There is no logical plan at this stage, so both logical-plan members throw.
  override def query: LogicalPlan = {
    throw new UnsupportedOperationException
  }

  override def withNewPlan(plan: LogicalPlan): SubqueryExpression = {
    throw new UnsupportedOperationException
  }

  /**
   * Fill the expression with the result collected from the executed plan.
   */
  def updateResult(): Unit
}
28
52
29
53
/**
30
54
* A subquery that will return only one row and one column.
31
55
*
32
56
* This is the physical copy of ScalarSubquery to be used inside SparkPlan.
33
57
*/
34
58
case class ScalarSubquery (
35
- executedPlan : SparkPlan ,
59
+ executedPlan : SubqueryExec ,
36
60
exprId : ExprId )
37
- extends SubqueryExpression {
38
-
39
- override def query : LogicalPlan = throw new UnsupportedOperationException
40
- override def withNewPlan (plan : LogicalPlan ): SubqueryExpression = {
41
- throw new UnsupportedOperationException
42
- }
43
- override def plan : SparkPlan = SubqueryExec (simpleString, executedPlan)
61
+ extends ExecSubqueryExpression {
44
62
45
63
override def dataType : DataType = executedPlan.schema.fields.head.dataType
46
64
override def children : Seq [Expression ] = Nil
47
65
override def nullable : Boolean = true
48
- override def toString : String = s " subquery# ${exprId.id}"
66
+ override def toString : String = executedPlan.simpleString
67
+
68
+ def withExecutedPlan (plan : SubqueryExec ): ExecSubqueryExpression = copy(executedPlan = plan)
69
+
70
/**
 * Two scalar subqueries are semantically equal when their executed plans
 * produce the same result.
 *
 * @param other the expression to compare with
 * @return true only if `other` is a [[ScalarSubquery]] whose executed plan has the
 *         same result as this one's
 */
override def semanticEquals(other: Expression): Boolean = other match {
  // Compare against the other subquery's plan, not this plan against itself.
  case s: ScalarSubquery => executedPlan.sameResult(s.executedPlan)
  case _ => false
}
49
74
50
75
// the first column in first row from `query`.
51
76
@ volatile private var result : Any = null
52
77
@ volatile private var updated : Boolean = false
53
78
54
- def updateResult (v : Any ): Unit = {
55
- result = v
79
/**
 * Runs the subquery plan and caches its single value in `result`.
 *
 * No rows yields a null result; exactly one row yields its first (and only) column;
 * more than one row is reported as an error. `updated` is set once a result is stored.
 */
def updateResult(): Unit = {
  val collected = plan.executeCollect()
  result = collected.length match {
    case 0 =>
      // If there is no rows returned, the result should be null.
      null
    case 1 =>
      assert(collected(0).numFields == 1,
        s"Expects 1 field, but got ${collected(0).numFields}; something went wrong in analysis")
      collected(0).get(0, dataType)
    case _ =>
      sys.error(s"more than one row returned by a subquery used as an expression:\n${plan}")
  }
  updated = true
}
58
94
@@ -67,6 +103,51 @@ case class ScalarSubquery(
67
103
}
68
104
}
69
105
106
/**
 * A subquery expression that checks whether the value of `child` is contained in the
 * result of a query.
 *
 * The values are taken from the first column of the rows collected by [[updateResult]];
 * evaluation and code generation both require that [[updateResult]] has already run.
 */
case class InSubquery(
    child: Expression,
    executedPlan: SubqueryExec,
    exprId: ExprId,
    private var result: Array[Any] = null,
    private var updated: Boolean = false) extends ExecSubqueryExpression {

  override def dataType: DataType = BooleanType

  override def children: Seq[Expression] = child :: Nil

  override def nullable: Boolean = child.nullable

  override def toString: String = s"$child IN ${executedPlan.name}"

  def withExecutedPlan(plan: SubqueryExec): ExecSubqueryExpression = {
    copy(executedPlan = plan)
  }

  /**
   * Equal only to another [[InSubquery]] with a semantically equal child and an executed
   * plan that produces the same result.
   */
  override def semanticEquals(other: Expression): Boolean = other match {
    case that: InSubquery if child.semanticEquals(that.child) =>
      executedPlan.sameResult(that.executedPlan)
    case _ =>
      false
  }

  /** Collects the rows of the plan and stores the first column of each as the IN list. */
  def updateResult(): Unit = {
    val collected = plan.executeCollect()
    val firstColumn = collected.map(row => row.get(0, child.dataType))
    result = firstColumn.asInstanceOf[Array[Any]]
    updated = true
  }

  /** Returns null when `child` evaluates to null, otherwise whether the value was collected. */
  override def eval(input: InternalRow): Any = {
    require(updated, s"$this has not finished")
    child.eval(input) match {
      case null => null
      case value => result.contains(value)
    }
  }

  /** Generates code by delegating to an [[InSet]] built from the collected values. */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    require(updated, s"$this has not finished")
    val collectedValues = result.toSet
    InSet(child, collectedValues).doGenCode(ctx, ev)
  }
}
150
+
70
151
/**
71
152
* Plans scalar subqueries that are present in the given [[SparkPlan]].
72
153
*/
@@ -75,7 +156,39 @@ case class PlanSubqueries(sparkSession: SparkSession) extends Rule[SparkPlan] {
75
156
plan.transformAllExpressions {
76
157
case subquery : expressions.ScalarSubquery =>
77
158
val executedPlan = new QueryExecution (sparkSession, subquery.plan).executedPlan
78
- ScalarSubquery (executedPlan, subquery.exprId)
159
+ ScalarSubquery (
160
+ SubqueryExec (s " subquery ${subquery.exprId.id}" , executedPlan),
161
+ subquery.exprId)
162
+ case expressions.PredicateSubquery (plan, Seq (e : Expression ), _, exprId) =>
163
+ val executedPlan = new QueryExecution (sparkSession, plan).executedPlan
164
+ InSubquery (e, SubqueryExec (s " subquery ${exprId.id}" , executedPlan), exprId)
165
+ }
166
+ }
167
+ }
168
+
169
+
170
/**
 * Finds duplicated subqueries in the spark plan, then uses the same [[SubqueryExec]] for
 * all of the references.
 *
 * The rule is a no-op unless `conf.exchangeReuseEnabled` is set.
 */
case class ReuseSubquery(conf: SQLConf) extends Rule[SparkPlan] {

  /**
   * @param plan the physical plan whose subquery expressions should be deduplicated
   * @return `plan` with every subquery expression whose plan has the same result as an
   *         earlier one rewritten to reference that earlier [[SubqueryExec]]
   */
  def apply(plan: SparkPlan): SparkPlan = {
    if (!conf.exchangeReuseEnabled) {
      plan
    } else {
      // Bucket subqueries by schema so that sameResult is only called within a bucket,
      // avoiding O(N*N) sameResult calls.
      val subqueries = mutable.HashMap[StructType, ArrayBuffer[SubqueryExec]]()
      plan transformAllExpressions {
        case sub: ExecSubqueryExpression =>
          val sameSchema =
            subqueries.getOrElseUpdate(sub.plan.schema, ArrayBuffer[SubqueryExec]())
          sameSchema.find(_.sameResult(sub.plan)) match {
            case Some(existing) =>
              sub.withExecutedPlan(existing)
            case None =>
              // First subquery with this result: remember it for later references.
              sameSchema += sub.executedPlan
              sub
          }
      }
    }
  }
}
0 commit comments