
Commit c56a012

maryannxue authored and cloud-fan committed
[SPARK-29060][SQL] Add tree traversal helper for adaptive spark plans
### What changes were proposed in this pull request?

This PR adds a utility class `AdaptiveSparkPlanHelper` which provides methods related to tree traversal of an `AdaptiveSparkPlanExec` plan. Unlike their counterparts in `TreeNode` or `QueryPlan`, these methods traverse down leaf nodes of adaptive plans, i.e., `AdaptiveSparkPlanExec` and `QueryStageExec`.

### Why are the changes needed?

This utility class can greatly simplify tree traversal code for adaptive Spark plans.

### Does this PR introduce any user-facing change?

No.

### How was this patch tested?

Refined `AdaptiveQueryExecSuite` with the help of the new utility methods.

Closes #25764 from maryannxue/aqe-utils.

Authored-by: maryannxue <maryannxue@apache.org>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
1 parent: 8e9fafb

2 files changed: +147 additions, −21 deletions
sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/AdaptiveSparkPlanHelper.scala (new file)

Lines changed: 130 additions & 0 deletions

@@ -0,0 +1,130 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.execution.adaptive

import org.apache.spark.sql.execution.SparkPlan

/**
 * This class provides utility methods related to tree traversal of an [[AdaptiveSparkPlanExec]]
 * plan. Unlike their counterparts in [[org.apache.spark.sql.catalyst.trees.TreeNode]] or
 * [[org.apache.spark.sql.catalyst.plans.QueryPlan]], these methods traverse down leaf nodes of
 * adaptive plans, i.e., [[AdaptiveSparkPlanExec]] and [[QueryStageExec]].
 */
trait AdaptiveSparkPlanHelper {

  /**
   * Finds the first [[SparkPlan]] that satisfies the condition specified by `f`.
   * The condition is recursively applied to this node and all of its children (pre-order).
   */
  def find(p: SparkPlan)(f: SparkPlan => Boolean): Option[SparkPlan] = if (f(p)) {
    Some(p)
  } else {
    allChildren(p).foldLeft(Option.empty[SparkPlan]) { (l, r) => l.orElse(find(r)(f)) }
  }

  /**
   * Runs the given function on this node and then recursively on children.
   * @param f the function to be applied to each node in the tree.
   */
  def foreach(p: SparkPlan)(f: SparkPlan => Unit): Unit = {
    f(p)
    allChildren(p).foreach(foreach(_)(f))
  }

  /**
   * Runs the given function recursively on children, then on this node.
   * @param f the function to be applied to each node in the tree.
   */
  def foreachUp(p: SparkPlan)(f: SparkPlan => Unit): Unit = {
    allChildren(p).foreach(foreachUp(_)(f))
    f(p)
  }

  /**
   * Returns a Seq containing the result of applying the given function to each
   * node in this tree in a pre-order traversal.
   * @param f the function to be applied.
   */
  def map[A](p: SparkPlan)(f: SparkPlan => A): Seq[A] = {
    val ret = new collection.mutable.ArrayBuffer[A]()
    foreach(p)(ret += f(_))
    ret
  }

  /**
   * Returns a Seq by applying a function to all nodes in this tree and using the elements of the
   * resulting collections.
   */
  def flatMap[A](p: SparkPlan)(f: SparkPlan => TraversableOnce[A]): Seq[A] = {
    val ret = new collection.mutable.ArrayBuffer[A]()
    foreach(p)(ret ++= f(_))
    ret
  }

  /**
   * Returns a Seq containing the result of applying a partial function to all elements in this
   * tree on which the function is defined.
   */
  def collect[B](p: SparkPlan)(pf: PartialFunction[SparkPlan, B]): Seq[B] = {
    val ret = new collection.mutable.ArrayBuffer[B]()
    val lifted = pf.lift
    foreach(p)(node => lifted(node).foreach(ret.+=))
    ret
  }

  /**
   * Returns a Seq containing the leaves in this tree.
   */
  def collectLeaves(p: SparkPlan): Seq[SparkPlan] = {
    collect(p) { case plan if allChildren(plan).isEmpty => plan }
  }

  /**
   * Finds and returns the first [[SparkPlan]] of the tree for which the given partial function
   * is defined (pre-order), and applies the partial function to it.
   */
  def collectFirst[B](p: SparkPlan)(pf: PartialFunction[SparkPlan, B]): Option[B] = {
    val lifted = pf.lift
    lifted(p).orElse {
      allChildren(p).foldLeft(Option.empty[B]) { (l, r) => l.orElse(collectFirst(r)(pf)) }
    }
  }

  /**
   * Returns a sequence containing the result of applying a partial function to all elements in
   * this plan, also considering all the plans in its (nested) subqueries.
   */
  def collectInPlanAndSubqueries[B](p: SparkPlan)(f: PartialFunction[SparkPlan, B]): Seq[B] = {
    (p +: subqueriesAll(p)).flatMap(collect(_)(f))
  }

  /**
   * Returns a sequence containing the subqueries in this plan, also including the (nested)
   * subqueries in its children.
   */
  def subqueriesAll(p: SparkPlan): Seq[SparkPlan] = {
    val subqueries = flatMap(p)(_.subqueries)
    subqueries ++ subqueries.flatMap(subqueriesAll)
  }

  private def allChildren(p: SparkPlan): Seq[SparkPlan] = p match {
    case a: AdaptiveSparkPlanExec => Seq(a.executedPlan)
    case s: QueryStageExec => Seq(s.plan)
    case _ => p.children
  }
}
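
For orientation, here is a minimal usage sketch, not part of this commit: the `PlanInspector` object and its method are hypothetical names, and it assumes the classes above are on the classpath.

import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
import org.apache.spark.sql.execution.joins.SortMergeJoinExec

// Mixing in the trait brings the traversal methods into scope, the same way
// AdaptiveQueryExecSuite does in the diff below.
object PlanInspector extends AdaptiveSparkPlanHelper {

  // Collects every sort-merge join in an adaptive plan, including joins
  // nested inside query stages, which plain TreeNode.collect treats as leaves.
  def findSortMergeJoins(plan: SparkPlan): Seq[SortMergeJoinExec] =
    collect(plan) { case j: SortMergeJoinExec => j }
}

The private `allChildren` method is what makes this work: `AdaptiveSparkPlanExec` and `QueryStageExec` are leaf nodes from `TreeNode`'s point of view, so the helper recurses into their `executedPlan` and `plan` fields instead of stopping at the stage boundary.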

sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala

Lines changed: 17 additions & 21 deletions
@@ -20,13 +20,16 @@ package org.apache.spark.sql.execution.adaptive
 import org.apache.spark.sql.QueryTest
 import org.apache.spark.sql.execution.{ReusedSubqueryExec, SparkPlan}
 import org.apache.spark.sql.execution.adaptive.rule.CoalescedShuffleReaderExec
-import org.apache.spark.sql.execution.columnar.InMemoryTableScanExec
 import org.apache.spark.sql.execution.exchange.Exchange
 import org.apache.spark.sql.execution.joins.{BroadcastHashJoinExec, BuildRight, SortMergeJoinExec}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession

-class AdaptiveQueryExecSuite extends QueryTest with SharedSparkSession {
+class AdaptiveQueryExecSuite
+  extends QueryTest
+  with SharedSparkSession
+  with AdaptiveSparkPlanHelper {
+
   import testImplicits._

   setupTestData()
@@ -51,34 +54,27 @@ class AdaptiveQueryExecSuite extends QueryTest with SharedSparkSession {
   }

   private def findTopLevelBroadcastHashJoin(plan: SparkPlan): Seq[BroadcastHashJoinExec] = {
-    plan.collect {
-      case j: BroadcastHashJoinExec => Seq(j)
-      case s: QueryStageExec => findTopLevelBroadcastHashJoin(s.plan)
-    }.flatten
+    collect(plan) {
+      case j: BroadcastHashJoinExec => j
+    }
   }

   private def findTopLevelSortMergeJoin(plan: SparkPlan): Seq[SortMergeJoinExec] = {
-    plan.collect {
-      case j: SortMergeJoinExec => Seq(j)
-      case s: QueryStageExec => findTopLevelSortMergeJoin(s.plan)
-    }.flatten
+    collect(plan) {
+      case j: SortMergeJoinExec => j
+    }
   }

   private def findReusedExchange(plan: SparkPlan): Seq[ReusedQueryStageExec] = {
-    plan.collect {
-      case e: ReusedQueryStageExec => Seq(e)
-      case a: AdaptiveSparkPlanExec => findReusedExchange(a.executedPlan)
-      case s: QueryStageExec => findReusedExchange(s.plan)
-      case p: SparkPlan => p.subqueries.flatMap(findReusedExchange)
-    }.flatten
+    collectInPlanAndSubqueries(plan) {
+      case e: ReusedQueryStageExec => e
+    }
   }

   private def findReusedSubquery(plan: SparkPlan): Seq[ReusedSubqueryExec] = {
-    plan.collect {
-      case e: ReusedSubqueryExec => Seq(e)
-      case s: QueryStageExec => findReusedSubquery(s.plan)
-      case p: SparkPlan => p.subqueries.flatMap(findReusedSubquery)
-    }.flatten
+    collectInPlanAndSubqueries(plan) {
+      case e: ReusedSubqueryExec => e
+    }
   }

   test("Change merge join to broadcast join") {
