@@ -23,8 +23,9 @@ import scala.collection.mutable.ArrayBuffer
 import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.expressions.{Expression, PlanExpression}
 import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.execution.adaptive.{AdaptiveSparkPlanExec, AdaptiveSparkPlanHelper, QueryStageExec}
 
-object ExplainUtils {
+object ExplainUtils extends AdaptiveSparkPlanHelper {
   /**
    * Given a input physical plan, performs the following tasks.
    *   1. Computes the operator id for current operator and records it in the operator
@@ -144,15 +145,26 @@ object ExplainUtils {
       case p: WholeStageCodegenExec =>
       case p: InputAdapter =>
       case other: QueryPlan[_] =>
-        if (!other.getTagValue(QueryPlan.OP_ID_TAG).isDefined) {
+
+        def setOpId(): Unit = if (other.getTagValue(QueryPlan.OP_ID_TAG).isEmpty) {
           currentOperationID += 1
           other.setTagValue(QueryPlan.OP_ID_TAG, currentOperationID)
           operatorIDs += ((currentOperationID, other))
         }
-        other.innerChildren.foreach { plan =>
-          currentOperationID = generateOperatorIDs(plan,
-            currentOperationID,
-            operatorIDs)
+
+        other match {
+          case p: AdaptiveSparkPlanExec =>
+            currentOperationID =
+              generateOperatorIDs(p.executedPlan, currentOperationID, operatorIDs)
+            setOpId()
+          case p: QueryStageExec =>
+            currentOperationID = generateOperatorIDs(p.plan, currentOperationID, operatorIDs)
+            setOpId()
+          case _ =>
+            setOpId()
+            other.innerChildren.foldLeft(currentOperationID) {
+              (curId, plan) => generateOperatorIDs(plan, curId, operatorIDs)
+            }
         }
     }
     currentOperationID
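
A minimal, self-contained sketch of the ordering the new branches encode, using toy classes rather than Spark's (the names Op, Wrapper, and assignIds are illustrative only): a wrapper node standing in for AdaptiveSparkPlanExec or QueryStageExec receives its operator id only after the plan it wraps has been numbered, while ordinary operators keep top-down numbering.

object OperatorIdSketch {
  // Toy plan model; not Spark classes.
  sealed trait Plan
  case class Op(name: String, nested: Seq[Plan] = Nil) extends Plan    // an ordinary operator
  case class Wrapper(name: String, wrapped: Plan) extends Plan         // stands in for AdaptiveSparkPlanExec / QueryStageExec

  import scala.collection.mutable.ArrayBuffer

  // Number the wrapped plan first, then the wrapper itself; ordinary operators
  // are numbered before their nested plans. Returns the last id assigned.
  def assignIds(p: Plan, startId: Int, out: ArrayBuffer[(Int, Plan)]): Int = p match {
    case w: Wrapper =>
      val next = assignIds(w.wrapped, startId, out)
      out += ((next + 1, w))
      next + 1
    case o: Op =>
      out += ((startId + 1, o))
      o.nested.foldLeft(startId + 1)((id, child) => assignIds(child, id, out))
  }

  def main(args: Array[String]): Unit = {
    val out = ArrayBuffer.empty[(Int, Plan)]
    assignIds(Wrapper("AdaptiveSparkPlan", Wrapper("QueryStage", Op("HashAggregate", Seq(Op("Scan"))))), 0, out)
    out.foreach { case (id, plan) => println(s"$id: $plan") }
    // Prints: 1: HashAggregate, 2: Scan, 3: QueryStage wrapper, 4: AdaptiveSparkPlan wrapper
  }
}
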
@@ -163,21 +175,25 @@ object ExplainUtils {
    * whole stage code gen id in the plan via setting a tag.
    */
   private def generateWholeStageCodegenIds(plan: QueryPlan[_]): Unit = {
+    var currentCodegenId = -1
+
+    def setCodegenId(p: QueryPlan[_], children: Seq[QueryPlan[_]]): Unit = {
+      if (currentCodegenId != -1) {
+        p.setTagValue(QueryPlan.CODEGEN_ID_TAG, currentCodegenId)
+      }
+      children.foreach(generateWholeStageCodegenIds)
+    }
+
     // Skip the subqueries as they are not printed as part of main query block.
     if (plan.isInstanceOf[BaseSubqueryExec]) {
       return
     }
-    var currentCodegenId = -1
     plan.foreach {
       case p: WholeStageCodegenExec => currentCodegenId = p.codegenStageId
       case _: InputAdapter => currentCodegenId = -1
-      case other: QueryPlan[_] =>
-        if (currentCodegenId != -1) {
-          other.setTagValue(QueryPlan.CODEGEN_ID_TAG, currentCodegenId)
-        }
-        other.innerChildren.foreach { plan =>
-          generateWholeStageCodegenIds(plan)
-        }
+      case p: AdaptiveSparkPlanExec => setCodegenId(p, Seq(p.executedPlan))
+      case p: QueryStageExec => setCodegenId(p, Seq(p.plan))
+      case other: QueryPlan[_] => setCodegenId(other, other.innerChildren)
     }
   }
 
@@ -232,13 +248,16 @@ object ExplainUtils {
   }
 
   def removeTags(plan: QueryPlan[_]): Unit = {
+    def remove(p: QueryPlan[_], children: Seq[QueryPlan[_]]): Unit = {
+      p.unsetTagValue(QueryPlan.OP_ID_TAG)
+      p.unsetTagValue(QueryPlan.CODEGEN_ID_TAG)
+      children.foreach(removeTags)
+    }
+
     plan foreach {
-      case plan: QueryPlan[_] =>
-        plan.unsetTagValue(QueryPlan.OP_ID_TAG)
-        plan.unsetTagValue(QueryPlan.CODEGEN_ID_TAG)
-        plan.innerChildren.foreach { p =>
-          removeTags(p)
-        }
+      case p: AdaptiveSparkPlanExec => remove(p, Seq(p.executedPlan))
+      case p: QueryStageExec => remove(p, Seq(p.plan))
+      case plan: QueryPlan[_] => remove(plan, plan.innerChildren)
     }
   }
 }
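
For context (not part of the patch): with adaptive query execution enabled, the formatted explain output should now descend into the adaptive plan and its query stages instead of stopping at the AdaptiveSparkPlan node. A quick way to exercise these code paths, assuming an active SparkSession named `spark`:

// Hedged usage sketch; the query and session setup are illustrative.
spark.conf.set("spark.sql.adaptive.enabled", "true")
val df = spark.range(1000).groupBy("id").count()
df.explain("formatted")  // operators under the AdaptiveSparkPlan wrapper should now receive operator and codegen ids
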