Commit e4bfc22

fix test

1 parent 666bf76 commit e4bfc22

5 files changed: +26 -11 lines changed


sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlanInfo.scala

Lines changed: 1 addition & 1 deletion
@@ -54,7 +54,7 @@ private[execution] object SparkPlanInfo {
     val children = plan match {
       case ReusedExchangeExec(_, child) => child :: Nil
       case a: AdaptiveSparkPlanExec => a.finalPlan.plan :: Nil
-      case stage: QueryFragmentExec => stage.plan :: Nil
+      case fragment: QueryFragmentExec => fragment.plan :: Nil
       case _ => plan.children ++ plan.subqueries
     }
     val metrics = plan.metrics.toSeq.map { case (key, metric) =>

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/QueryFragmentExec.scala

Lines changed: 1 addition & 1 deletion
@@ -139,7 +139,7 @@ case class BroadcastQueryFragmentExec(id: Int, plan: BroadcastExchangeExec)
  * A wrapper of QueryFragment to indicate that it's reused. Note that this is not a query fragment.
  */
 case class ReusedQueryFragmentExec(child: QueryFragmentExec, output: Seq[Attribute])
-  extends UnaryExecNode {
+  extends LeafExecNode {
 
   // Ignore this wrapper for canonicalizing.
   override def doCanonicalize(): SparkPlan = child.canonicalized
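
Why the base-class change matters: Spark's plan transforms recurse through a node's children, and a LeafExecNode reports none, so rules no longer descend into the already-optimized plan wrapped by ReusedQueryFragmentExec. A simplified sketch of what the two traits provide (illustrative; Spark's actual trait definitions carry more members):

trait UnaryExecNode extends SparkPlan {
  def child: SparkPlan
  // Transforms and collectLeaves() recurse into the single child.
  override final def children: Seq[SparkPlan] = child :: Nil
}

trait LeafExecNode extends SparkPlan {
  // No children: transforms stop here, and the node itself is a leaf.
  override final def children: Seq[SparkPlan] = Nil
}

As a consequence, collectLeaves() now returns the ReusedQueryFragmentExec wrapper itself, which is what the ReduceNumShufflePartitions changes below account for.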

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/rule/AssertChildFragmentsMaterialized.scala

Lines changed: 2 additions & 2 deletions
@@ -21,8 +21,8 @@ import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.execution.adaptive.QueryFragmentExec
 
-// A sanity check rule to make sure we are running query stage optimizer rules on a sub-tree of
-// query plan with all input stages materialized.
+// A sanity check rule to make sure we are running query fragment optimizer rules on a sub-tree of
+// query plan with all input fragments materialized.
 object AssertChildFragmentsMaterialized extends Rule[SparkPlan] {
   override def apply(plan: SparkPlan): SparkPlan = plan.transform {
     case q: QueryFragmentExec if !q.materialize().isCompleted =>

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/rule/ReduceNumShufflePartitions.scala

Lines changed: 18 additions & 6 deletions
@@ -17,17 +17,17 @@
 
 package org.apache.spark.sql.execution.adaptive.rule
 
-import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.duration.Duration
 
 import org.apache.spark.MapOutputStatistics
+import org.apache.spark.SparkException
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan, UnaryExecNode}
-import org.apache.spark.sql.execution.adaptive.{QueryFragmentExec, ShuffleQueryFragmentExec}
+import org.apache.spark.sql.execution.adaptive.{QueryFragmentExec, ReusedQueryFragmentExec, ShuffleQueryFragmentExec}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.ThreadUtils
 
@@ -61,7 +61,9 @@ case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] {
       ThreadUtils.awaitResult(metricsFuture, Duration.Zero)
     }
 
-    val allFragmentLeaves = plan.collectLeaves().forall(_.isInstanceOf[QueryFragmentExec])
+    val allFragmentLeaves = plan.collectLeaves().forall { node =>
+      node.isInstanceOf[QueryFragmentExec] || node.isInstanceOf[ReusedQueryFragmentExec]
+    }
 
     if (allFragmentLeaves) {
       // ShuffleQueryFragment gives null mapOutputStatistics when the input RDD has 0 partitions,
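
Because the wrapper is now a leaf, plan.collectLeaves() yields ReusedQueryFragmentExec nodes directly rather than looking through them, so the check must accept both types. An equivalent pattern-match form of the same test (an illustrative rewrite, not part of the commit):

val allFragmentLeaves = plan.collectLeaves().forall {
  case _: QueryFragmentExec | _: ReusedQueryFragmentExec => true
  case _ => false
}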
@@ -76,6 +78,8 @@ case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] {
         // number of output partitions.
         case fragment: ShuffleQueryFragmentExec =>
           CoalescedShuffleReaderExec(fragment, partitionStartIndices)
+        case r@ReusedQueryFragmentExec(fragment: ShuffleQueryFragmentExec, output) =>
+          CoalescedShuffleReaderExec(r, partitionStartIndices)
       }
     } else {
       plan
@@ -152,7 +156,9 @@ case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] {
        partitionStartIndices += i
        // reset postShuffleInputSize.
        postShuffleInputSize = nextShuffleInputSize
-      } else postShuffleInputSize += nextShuffleInputSize
+      } else {
+        postShuffleInputSize += nextShuffleInputSize
+      }
 
       i += 1
     }
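
The loop above is the greedy coalescing step: walk the map-output sizes partition by partition and start a new post-shuffle partition whenever adding the next one would push the running total past the target. A self-contained, single-shuffle sketch of the idea (the name, signature, and single-array input are illustrative; the real rule aggregates sizes across all participating shuffles):

import scala.collection.mutable.ArrayBuffer

def estimateStartIndices(partitionSizes: Array[Long], targetSize: Long): Array[Int] = {
  val startIndices = ArrayBuffer(0)  // the first coalesced partition starts at index 0
  var postShuffleInputSize = 0L
  var i = 0
  while (i < partitionSizes.length) {
    val nextShuffleInputSize = partitionSizes(i)
    if (i > 0 && postShuffleInputSize + nextShuffleInputSize > targetSize) {
      startIndices += i                            // close the current coalesced partition
      postShuffleInputSize = nextShuffleInputSize  // reset the running total
    } else {
      postShuffleInputSize += nextShuffleInputSize
    }
    i += 1
  }
  startIndices.toArray
}

For example, sizes Array(10L, 30L, 50L, 40L) with targetSize = 64 yield start indices Array(0, 2, 3): the first two partitions coalesce to 40 bytes, while the 50- and 40-byte partitions each stand alone.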
@@ -162,7 +168,7 @@ case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] {
 }
 
 case class CoalescedShuffleReaderExec(
-    child: ShuffleQueryFragmentExec,
+    child: SparkPlan,
     partitionStartIndices: Array[Int]) extends UnaryExecNode {
 
   override def output: Seq[Attribute] = child.output
@@ -175,7 +181,13 @@ case class CoalescedShuffleReaderExec(
 
   override protected def doExecute(): RDD[InternalRow] = {
     if (cachedShuffleRDD == null) {
-      cachedShuffleRDD = child.plan.createShuffledRDD(Some(partitionStartIndices))
+      cachedShuffleRDD = child match {
+        case fragment: ShuffleQueryFragmentExec =>
+          fragment.plan.createShuffledRDD(Some(partitionStartIndices))
+        case ReusedQueryFragmentExec(fragment: ShuffleQueryFragmentExec, _) =>
+          fragment.plan.createShuffledRDD(Some(partitionStartIndices))
+        case _ => throw new SparkException("Invalid child for CoalescedShuffleReaderExec")
+      }
     }
     cachedShuffleRDD
   }
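
The two success branches in doExecute differ only in unwrapping the reuse wrapper before reaching the underlying shuffle. A hypothetical extractor (not part of this commit) could fold them into one case:

// Hypothetical helper: normalize away the ReusedQueryFragmentExec wrapper.
object AsShuffleFragment {
  def unapply(plan: SparkPlan): Option[ShuffleQueryFragmentExec] = plan match {
    case fragment: ShuffleQueryFragmentExec => Some(fragment)
    case ReusedQueryFragmentExec(fragment: ShuffleQueryFragmentExec, _) => Some(fragment)
    case _ => None
  }
}

// doExecute could then read:
//   cachedShuffleRDD = child match {
//     case AsShuffleFragment(fragment) =>
//       fragment.plan.createShuffledRDD(Some(partitionStartIndices))
//     case _ => throw new SparkException("Invalid child for CoalescedShuffleReaderExec")
//   }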

sql/core/src/test/scala/org/apache/spark/sql/execution/ReduceNumShufflePartitionsSuite.scala

Lines changed: 4 additions & 1 deletion
@@ -558,7 +558,10 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
 
     val leafFragments = level1Fragments.flatMap { fragment =>
       // All of the child fragments of result fragment have only one child fragment.
-      val children = fragment.plan.collect { case q: QueryFragmentExec => q }
+      val children = fragment.plan.collect {
+        case q: QueryFragmentExec => q
+        case r: ReusedQueryFragmentExec => r.child
+      }
       assert(children.length == 1)
       children
     }
