@@ -113,38 +113,50 @@ abstract class QueryStage extends UnaryExecNode {
113113 val queryStageInputs : Seq [ShuffleQueryStageInput ] = child.collect {
114114 case input : ShuffleQueryStageInput if ! input.isLocalShuffle => input
115115 }
116- val childMapOutputStatistics = queryStageInputs.map(_.childStage.mapOutputStatistics)
117- .filter(_ != null ).toArray
118- // Right now, Adaptive execution only support HashPartitionings.
119- val supportAdaptive = queryStageInputs.forall{
116+
117+ val skewedShuffleQueryStageInputs : Seq [SkewedShuffleQueryStageInput ] = child.collect {
118+ case input : SkewedShuffleQueryStageInput => input
119+ }
120+
121+ val leafNodes = child.collect {
122+ case s : SparkPlan if s.children.isEmpty => s
123+ }
124+
125+ // Ensure all leaf nodes are shuffle query stages
126+ if (leafNodes.length == queryStageInputs.length + skewedShuffleQueryStageInputs.length) {
127+ val childMapOutputStatistics = queryStageInputs.map(_.childStage.mapOutputStatistics)
128+ .filter(_ != null ).toArray
129+ // Right now, adaptive execution only supports HashPartitionings.
130+ val supportAdaptive = queryStageInputs.forall {
120131 _.outputPartitioning match {
121132 case hash : HashPartitioning => true
122133 case collection : PartitioningCollection =>
123134 collection.partitionings.forall(_.isInstanceOf [HashPartitioning ])
124135 case _ => false
125136 }
126- }
137+ }
127138
128- if (childMapOutputStatistics.length > 0 && supportAdaptive) {
129- val exchangeCoordinator = new ExchangeCoordinator (
130- conf.targetPostShuffleInputSize,
131- conf.adaptiveTargetPostShuffleRowCount,
132- conf.minNumPostShufflePartitions)
133-
134- if (queryStageInputs.length == 2 && queryStageInputs.forall(_.skewedPartitions.isDefined)) {
135- // If a skewed join is detected and optimized, we will omit the skewed partitions when
136- // estimate the partition start and end indices.
137- val (partitionStartIndices, partitionEndIndices) =
138- exchangeCoordinator.estimatePartitionStartEndIndices(
139- childMapOutputStatistics, queryStageInputs(0 ).skewedPartitions.get)
140- queryStageInputs.foreach { i =>
141- i.partitionStartIndices = Some (partitionStartIndices)
142- i.partitionEndIndices = Some (partitionEndIndices)
139+ if (childMapOutputStatistics.length > 0 && supportAdaptive) {
140+ val exchangeCoordinator = new ExchangeCoordinator (
141+ conf.targetPostShuffleInputSize,
142+ conf.adaptiveTargetPostShuffleRowCount,
143+ conf.minNumPostShufflePartitions)
144+
145+ if (queryStageInputs.length == 2 && queryStageInputs.forall(_.skewedPartitions.isDefined)) {
146+ // If a skewed join is detected and optimized, we will omit the skewed partitions when
147+ // estimating the partition start and end indices.
148+ val (partitionStartIndices, partitionEndIndices) =
149+ exchangeCoordinator.estimatePartitionStartEndIndices(
150+ childMapOutputStatistics, queryStageInputs(0 ).skewedPartitions.get)
151+ queryStageInputs.foreach { i =>
152+ i.partitionStartIndices = Some (partitionStartIndices)
153+ i.partitionEndIndices = Some (partitionEndIndices)
154+ }
155+ } else {
156+ val partitionStartIndices =
157+ exchangeCoordinator.estimatePartitionStartIndices(childMapOutputStatistics)
158+ queryStageInputs.foreach(_.partitionStartIndices = Some (partitionStartIndices))
143159 }
144- } else {
145- val partitionStartIndices =
146- exchangeCoordinator.estimatePartitionStartIndices(childMapOutputStatistics)
147- queryStageInputs.foreach(_.partitionStartIndices = Some (partitionStartIndices))
148160 }
149161 }
150162
0 commit comments