
 package org.apache.spark.sql.execution.adaptive.rule

-import scala.collection.mutable.ArrayBuffer
 import scala.concurrent.duration.Duration

 import org.apache.spark.MapOutputStatistics
+import org.apache.spark.SparkException
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.Attribute
 import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution.{ShuffledRowRDD, SparkPlan, UnaryExecNode}
-import org.apache.spark.sql.execution.adaptive.{QueryFragmentExec, ShuffleQueryFragmentExec}
+import org.apache.spark.sql.execution.adaptive.{QueryFragmentExec, ReusedQueryFragmentExec, ShuffleQueryFragmentExec}
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.util.ThreadUtils

@@ -61,7 +61,9 @@ case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] {
       ThreadUtils.awaitResult(metricsFuture, Duration.Zero)
     }

-    val allFragmentLeaves = plan.collectLeaves().forall(_.isInstanceOf[QueryFragmentExec])
+    val allFragmentLeaves = plan.collectLeaves().forall { node =>
+      node.isInstanceOf[QueryFragmentExec] || node.isInstanceOf[ReusedQueryFragmentExec]
+    }

     if (allFragmentLeaves) {
       // ShuffleQueryFragment gives null mapOutputStatistics when the input RDD has 0 partitions,
@@ -76,6 +78,8 @@ case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] {
         // number of output partitions.
         case fragment: ShuffleQueryFragmentExec =>
           CoalescedShuffleReaderExec(fragment, partitionStartIndices)
+        case r @ ReusedQueryFragmentExec(fragment: ShuffleQueryFragmentExec, output) =>
+          CoalescedShuffleReaderExec(r, partitionStartIndices)
       }
     } else {
       plan
@@ -152,7 +156,9 @@ case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] {
         partitionStartIndices += i
         // reset postShuffleInputSize.
         postShuffleInputSize = nextShuffleInputSize
-      } else postShuffleInputSize += nextShuffleInputSize
+      } else {
+        postShuffleInputSize += nextShuffleInputSize
+      }

       i += 1
     }
@@ -162,7 +168,7 @@ case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] {
 }

 case class CoalescedShuffleReaderExec(
-    child: ShuffleQueryFragmentExec,
+    child: SparkPlan,
     partitionStartIndices: Array[Int]) extends UnaryExecNode {

   override def output: Seq[Attribute] = child.output
@@ -175,7 +181,13 @@ case class CoalescedShuffleReaderExec(

   override protected def doExecute(): RDD[InternalRow] = {
     if (cachedShuffleRDD == null) {
-      cachedShuffleRDD = child.plan.createShuffledRDD(Some(partitionStartIndices))
+      cachedShuffleRDD = child match {
+        case fragment: ShuffleQueryFragmentExec =>
+          fragment.plan.createShuffledRDD(Some(partitionStartIndices))
+        case ReusedQueryFragmentExec(fragment: ShuffleQueryFragmentExec, _) =>
+          fragment.plan.createShuffledRDD(Some(partitionStartIndices))
+        case _ => throw new SparkException("Invalid child for CoalescedShuffleReaderExec")
+      }
     }
     cachedShuffleRDD
   }
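
For context on the `-152,7 +156,9` hunk: the loop it touches packs adjacent post-shuffle partitions into one coalesced partition until a target size would be exceeded. Below is a minimal standalone sketch of that estimation, assuming a single Array[Long] of partition sizes and a targetSize parameter; the actual rule aggregates sizes across all MapOutputStatistics and reads the target from SQLConf.

import scala.collection.mutable.ArrayBuffer

// Simplified illustration of the partition-coalescing loop; not the patched file itself.
def estimatePartitionStartIndicesSketch(partitionSizes: Array[Long], targetSize: Long): Array[Int] = {
  val partitionStartIndices = ArrayBuffer[Int](0) // the first coalesced partition starts at index 0
  var postShuffleInputSize = 0L
  var i = 0
  while (i < partitionSizes.length) {
    val nextShuffleInputSize = partitionSizes(i)
    // Once adding the next pre-shuffle partition would exceed the target,
    // close the current coalesced partition and start a new one at index i.
    if (i > 0 && postShuffleInputSize + nextShuffleInputSize > targetSize) {
      partitionStartIndices += i
      // reset postShuffleInputSize for the new coalesced partition.
      postShuffleInputSize = nextShuffleInputSize
    } else {
      postShuffleInputSize += nextShuffleInputSize
    }
    i += 1
  }
  partitionStartIndices.toArray
}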