Skip to content

Commit 7f83c9c

Browse files
author
luzhonghao
committed
auto calculate the initial partition number with ae (apache#61)
1 parent 84394e8 commit 7f83c9c

File tree

3 files changed

+10
-9
lines changed

3 files changed

+10
-9
lines changed

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/HandleSkewedJoin.scala

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ case class HandleSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
9191
partitionId: Int,
9292
medianSize: Long,
9393
medianRowCount: Long): Array[Int] = {
94-
val stats = queryStageInput.childStage.stats
94+
val stats = queryStageInput.childStage.statsPlan
9595
val size = stats.bytesByPartitionId.get(partitionId)
9696
val rowCount = stats.recordStatistics.get.recordsByPartitionId(partitionId)
9797
val factor = Math.max(size / medianSize, rowCount / medianRowCount)
@@ -110,8 +110,8 @@ case class HandleSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
110110
left: QueryStageInput,
111111
right: QueryStageInput): Boolean = {
112112
supportedJoinTypes.contains(joinType) &&
113-
left.childStage.stats.getPartitionStatistics.isDefined &&
114-
right.childStage.stats.getPartitionStatistics.isDefined
113+
left.childStage.statsPlan.getPartitionStatistics.isDefined &&
114+
right.childStage.statsPlan.getPartitionStatistics.isDefined
115115
}
116116

117117
private def supportSplitOnLeftPartition(joinType: JoinType) = joinType != RightOuter
@@ -128,8 +128,8 @@ case class HandleSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {
128128
SortExec(_, _, right: ShuffleQueryStageInput, _))
129129
if supportOptimization(joinType, left, right) =>
130130

131-
val leftStats = left.childStage.stats.getPartitionStatistics.get
132-
val rightStats = right.childStage.stats.getPartitionStatistics.get
131+
val leftStats = left.childStage.statsPlan.getPartitionStatistics.get
132+
val rightStats = right.childStage.statsPlan.getPartitionStatistics.get
133133
val numPartitions = leftStats.bytesByPartitionId.length
134134
val (leftMedSize, leftMedRowCount) = medianSizeAndRowCount(leftStats)
135135
val (rightMedSize, rightMedRowCount) = medianSizeAndRowCount(rightStats)

sql/core/src/main/scala/org/apache/spark/sql/execution/adaptive/OptimizeJoin.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,8 @@ case class OptimizeJoin(conf: SQLConf) extends Rule[SparkPlan] {
4040
}
4141

4242
private def canBroadcast(plan: SparkPlan): Boolean = {
43-
plan.stats.sizeInBytes >= 0 && plan.stats.sizeInBytes <= conf.adaptiveBroadcastJoinThreshold
43+
val ret = plan.statsPlan.sizeInBytes >= 0
44+
ret && plan.statsPlan.sizeInBytes <= conf.adaptiveBroadcastJoinThreshold
4445
}
4546

4647
private def removeSort(plan: SparkPlan): SparkPlan = {
@@ -85,9 +86,9 @@ case class OptimizeJoin(conf: SQLConf) extends Rule[SparkPlan] {
8586
// reading them in local shuffle read.
8687
broadcastSidePlan match {
8788
case broadcast: ShuffleQueryStageInput
88-
if broadcast.childStage.stats.bytesByPartitionId.isDefined =>
89+
if broadcast.childStage.statsPlan.bytesByPartitionId.isDefined =>
8990
val (startIndicies, endIndicies) = calculatePartitionStartEndIndices(broadcast.childStage
90-
.stats.bytesByPartitionId.get)
91+
.statsPlan.bytesByPartitionId.get)
9192
childrenPlans.foreach {
9293
case input: ShuffleQueryStageInput =>
9394
input.partitionStartIndices = Some(startIndicies)

sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -351,7 +351,7 @@ class PlannerSuite extends SharedSQLContext {
351351
)
352352
withSQLConf(SQLConf.SHUFFLE_TARGET_POSTSHUFFLE_INPUT_SIZE.key -> "1") {
353353

354-
val totalInputFileSize = inputPlan.collectLeaves().map(_.stats.sizeInBytes).sum
354+
val totalInputFileSize = inputPlan.collectLeaves().map(_.statsPlan.sizeInBytes).sum
355355
val expectedNum = Math.ceil(
356356
totalInputFileSize.toLong * 1.0 / conf.targetPostShuffleInputSize).toInt
357357

0 commit comments

Comments
 (0)