create partial shuffle reader #3

Merged: 1 commit, Jan 14, 2020

@@ -90,7 +90,7 @@ case class AdaptiveSparkPlanExec(
// Here the 'OptimizeSkewedPartitions' rule should be executed
// before 'ReduceNumShufflePartitions', as the skewed partitions handled
// by the 'OptimizeSkewedPartitions' rule should be omitted in 'ReduceNumShufflePartitions'.
OptimizeSkewedPartitions(conf),
OptimizeSkewedJoin(conf),
ReduceNumShufflePartitions(conf),
// The 'OptimizeLocalShuffleReader' rule needs to make use of the 'partitionStartIndices'
// from the 'ReduceNumShufflePartitions' rule, so it must run after 'ReduceNumShufflePartitions'.
@@ -165,9 +165,9 @@ case class LocalShuffleReaderExec(
// before shuffle.
if (partitionStartIndicesPerMapper.forall(_.length == 1)) {
child match {
case ShuffleQueryStageExec(_, s: ShuffleExchangeExec, _) =>
case ShuffleQueryStageExec(_, s: ShuffleExchangeExec) =>
s.child.outputPartitioning
case ShuffleQueryStageExec(_, r @ ReusedExchangeExec(_, s: ShuffleExchangeExec), _) =>
case ShuffleQueryStageExec(_, r @ ReusedExchangeExec(_, s: ShuffleExchangeExec)) =>
s.child.outputPartitioning match {
case e: Expression => r.updateAttr(e).asInstanceOf[Partitioning]
case other => other
@@ -28,10 +28,11 @@ import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnknownPartitioning}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec
import org.apache.spark.sql.execution.joins.SortMergeJoinExec
import org.apache.spark.sql.internal.SQLConf

case class OptimizeSkewedPartitions(conf: SQLConf) extends Rule[SparkPlan] {
case class OptimizeSkewedJoin(conf: SQLConf) extends Rule[SparkPlan] {

private val supportedJoinTypes =
Inner :: Cross :: LeftSemi :: LeftAnti :: LeftOuter :: RightOuter :: Nil
@@ -115,8 +116,8 @@ case class OptimizeSkewedPartitions(conf: SQLConf) extends Rule[SparkPlan] {

def handleSkewJoin(plan: SparkPlan): SparkPlan = plan.transformUp {
case smj @ SortMergeJoinExec(leftKeys, rightKeys, joinType, condition,
SortExec(_, _, left: ShuffleQueryStageExec, _),
SortExec(_, _, right: ShuffleQueryStageExec, _))
s1 @ SortExec(_, _, left: ShuffleQueryStageExec, _),
s2 @ SortExec(_, _, right: ShuffleQueryStageExec, _))
if supportedJoinTypes.contains(joinType) =>
val leftStats = getStatistics(left)
val rightStats = getStatistics(right)
@@ -166,26 +167,20 @@ case class OptimizeSkewedPartitions(conf: SQLConf) extends Rule[SparkPlan] {
}
// TODO: we may be able to optimize the sort merge join to a broadcast join after
// obtaining the raw data size of each partition.
val leftSkewedReader = SkewedShufflePartitionReader(
val leftSkewedReader = SkewedPartitionReaderExec(
left, partitionId, leftMapIdStartIndices(i), leftEndMapId)
val leftSort = smj.left.asInstanceOf[SortExec].copy(child = leftSkewedReader)

val rightSkewedReader = SkewedShufflePartitionReader(right, partitionId,
rightMapIdStartIndices(j), rightEndMapId)
val rightSort = smj.right.asInstanceOf[SortExec].copy(child = rightSkewedReader)
subJoins += SortMergeJoinExec(leftKeys, rightKeys, joinType, condition,
leftSort, rightSort)
val rightSkewedReader = SkewedPartitionReaderExec(right, partitionId,
rightMapIdStartIndices(j), rightEndMapId)
subJoins += SortMergeJoinExec(leftKeys, rightKeys, joinType, condition,
s1.copy(child = leftSkewedReader), s2.copy(child = rightSkewedReader))
}
}
}
logDebug(s"number of skewed partitions is ${skewedPartitions.size}")
if (skewedPartitions.nonEmpty) {
val optimizedSmj = smj.transformDown {
case sort @ SortExec(_, _, shuffleStage: ShuffleQueryStageExec, _) =>
val newStage = shuffleStage.copy(
excludedPartitions = skewedPartitions.toSet)
newStage.resultOption = shuffleStage.resultOption
sort.copy(child = newStage)
sort.copy(child = PartialShuffleReaderExec(shuffleStage, skewedPartitions.toSet))
}
subJoins += optimizedSmj
UnionExec(subJoins)
@@ -221,15 +216,15 @@ case class OptimizeSkewedPartitions(conf: SQLConf) extends Rule[SparkPlan] {
/**
* A wrapper of shuffle query stage, which submits one reduce task to read a single
* shuffle partition 'partitionIndex' produced by the mappers in range [startMapIndex, endMapIndex).
* This is used to handle the skewed partitions.
* This is used to increase the parallelism when reading skewed partitions.
*
* @param child It's usually `ShuffleQueryStageExec`, but can be the shuffle exchange
* node during canonicalization.
* @param partitionIndex The pre-shuffle partition index.
* @param startMapIndex The start map index.
* @param endMapIndex The end map index.
*/
case class SkewedShufflePartitionReader(
case class SkewedPartitionReaderExec(
child: QueryStageExec,
partitionIndex: Int,
startMapIndex: Int,
@@ -242,10 +237,6 @@
}
private var cachedSkewedShuffleRDD: SkewedShuffledRowRDD = null

override def nodeName: String = s"SkewedShuffleReader SkewedShuffleQueryStage: ${child}" +
s" SkewedPartition: ${partitionIndex} startMapIndex: ${startMapIndex}" +
s" endMapIndex: ${endMapIndex}"

override def doExecute(): RDD[InternalRow] = {
if (cachedSkewedShuffleRDD == null) {
cachedSkewedShuffleRDD = child match {
@@ -258,3 +249,45 @@
cachedSkewedShuffleRDD
}
}

/**
* A wrapper of shuffle query stage, which skips some partitions when reading the shuffle blocks.
*
* @param child It's usually `ShuffleQueryStageExec`, but can be the shuffle exchange node during
* canonicalization.
* @param excludedPartitions The partitions to skip when reading.
*/
case class PartialShuffleReaderExec(
child: QueryStageExec,
excludedPartitions: Set[Int]) extends UnaryExecNode {

override def output: Seq[Attribute] = child.output

override def outputPartitioning: Partitioning = {
UnknownPartitioning(1)
}

private def shuffleExchange(): ShuffleExchangeExec = child match {
case stage: ShuffleQueryStageExec => stage.shuffle
case _ =>
throw new IllegalStateException("operating on canonicalization plan")
}

private def getPartitionIndexRanges(): Array[(Int, Int)] = {
val length = shuffleExchange().shuffleDependency.partitioner.numPartitions
(0 until length).filterNot(excludedPartitions.contains).map(i => (i, i + 1)).toArray
}

private var cachedShuffleRDD: RDD[InternalRow] = null

override def doExecute(): RDD[InternalRow] = {
if (cachedShuffleRDD == null) {
cachedShuffleRDD = if (excludedPartitions.isEmpty) {
child.execute()
} else {
shuffleExchange().createShuffledRDD(Some(getPartitionIndexRanges()))
}
}
cachedShuffleRDD
}
}
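
For context, a minimal standalone sketch of the index arithmetic used by PartialShuffleReaderExec. The object and method names below are illustrative only; the range computation mirrors getPartitionIndexRanges above, and the excluded partitions are the ones handed to SkewedPartitionReaderExec per mapper range.

// Runnable sketch (not part of the diff): every partition except the excluded (skewed)
// ones becomes a single-partition range (i, i + 1), so the partial reader skips exactly
// the partitions that the per-mapper-range skewed readers handle instead.
object PartialReadSketch {
  def partitionIndexRanges(numPartitions: Int, excluded: Set[Int]): Array[(Int, Int)] =
    (0 until numPartitions).filterNot(excluded.contains).map(i => (i, i + 1)).toArray

  def main(args: Array[String]): Unit = {
    // With 5 reduce partitions and partition 2 detected as skewed:
    println(partitionIndexRanges(5, Set(2)).mkString(", ")) // (0,1), (1,2), (3,4), (4,5)
  }
}
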
@@ -17,7 +17,6 @@

package org.apache.spark.sql.execution.adaptive

import scala.collection.mutable.ArrayBuffer
import scala.concurrent.Future

import org.apache.spark.{FutureAction, MapOutputStatistics}
@@ -135,8 +134,7 @@ abstract class QueryStageExec extends LeafExecNode {
*/
case class ShuffleQueryStageExec(
override val id: Int,
override val plan: SparkPlan,
val excludedPartitions: Set[Int] = Set.empty) extends QueryStageExec {
override val plan: SparkPlan) extends QueryStageExec {

@transient val shuffle = plan match {
case s: ShuffleExchangeExec => s
@@ -163,26 +161,6 @@ case class ShuffleQueryStageExec(
case _ =>
}
}

private def getPartitionIndexRanges(): Array[(Int, Int)] = {
val length = shuffle.shuffleDependency.partitioner.numPartitions
(0 until length).filterNot(excludedPartitions.contains).map(i => (i, i + 1)).toArray
}

private var cachedShuffleRDD: RDD[InternalRow] = null

override def doExecute(): RDD[InternalRow] = {
if (cachedShuffleRDD == null) {
cachedShuffleRDD = excludedPartitions match {
case e if e.isEmpty =>
plan.execute()
case _ =>
shuffle.createShuffledRDD(
Some(getPartitionIndexRanges()))
}
}
cachedShuffleRDD
}
}

/**
@@ -17,7 +17,7 @@

package org.apache.spark.sql.execution.adaptive

import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.{ArrayBuffer, HashSet}

import org.apache.spark.MapOutputStatistics
import org.apache.spark.rdd.RDD
@@ -54,22 +54,28 @@ case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] {
if (!conf.reducePostShufflePartitionsEnabled) {
return plan
}
// we need skip the leaf node of 'SkewedShufflePartitionReader'
val leafNodes = plan.collectLeaves().filter(!_.isInstanceOf[SkewedShufflePartitionReader])
// 'SkewedPartitionReaderExec' is added by us, so it's safe to ignore it when changing the
// number of reducers.
val leafNodes = plan.collectLeaves().filter(!_.isInstanceOf[SkewedPartitionReaderExec])
if (!leafNodes.forall(_.isInstanceOf[QueryStageExec])) {
// If not all leaf nodes are query stages, it's not safe to reduce the number of
// shuffle partitions, because we may break the assumption that all children of a Spark plan
// have the same number of output partitions.
return plan
}

def collectShuffleStages(plan: SparkPlan): Seq[ShuffleQueryStageExec] = plan match {
def collectShuffles(plan: SparkPlan): Seq[SparkPlan] = plan match {
case _: LocalShuffleReaderExec => Nil
case p: PartialShuffleReaderExec => Seq(p)
case stage: ShuffleQueryStageExec => Seq(stage)
case _ => plan.children.flatMap(collectShuffleStages)
case _ => plan.children.flatMap(collectShuffles)
}

val shuffleStages = collectShuffleStages(plan)
val shuffles = collectShuffles(plan)
val shuffleStages = shuffles.map {
case PartialShuffleReaderExec(s: ShuffleQueryStageExec, _) => s
case s: ShuffleQueryStageExec => s
}
// ShuffleExchanges introduced by repartition do not support changing the number of partitions.
// We change the number of partitions in the stage only if all the ShuffleExchanges support it.
if (!shuffleStages.forall(_.shuffle.canChangeNumPartitions)) {
@@ -88,18 +94,31 @@ case class ReduceNumShufflePartitions(conf: SQLConf) extends Rule[SparkPlan] {
// partition) and a result of a SortMergeJoin (multiple partitions).
val distinctNumPreShufflePartitions =
validMetrics.map(stats => stats.bytesByPartitionId.length).distinct
val distinctExcludedPartitions = shuffleStages.map(_.excludedPartitions).distinct
val distinctExcludedPartitions = shuffles.map {
case PartialShuffleReaderExec(_, excludedPartitions) => excludedPartitions
case _: ShuffleQueryStageExec => Set.empty[Int]
}.distinct
if (validMetrics.nonEmpty && distinctNumPreShufflePartitions.length == 1
&& distinctExcludedPartitions.length == 1) {
val excludedPartitions = shuffleStages.head.excludedPartitions
val excludedPartitions = distinctExcludedPartitions.head
val partitionIndices = estimatePartitionStartAndEndIndices(
validMetrics.toArray, excludedPartitions)
// This transformation adds new nodes, so we must use `transformUp` here.
plan.transformUp {
// even for shuffle exchange whose input RDD has 0 partition, we should still update its
// `partitionStartIndices`, so that all the leaf shuffles in a stage have the same
// number of output partitions.
case stage: ShuffleQueryStageExec =>
// Even for a shuffle exchange whose input RDD has 0 partitions, we should still update its
// `partitionStartIndices`, so that all the leaf shuffles in a stage have the same
// number of output partitions.
val visitedStages = HashSet.empty[Int]
plan.transformDown {
// Replace `PartialShuffleReaderExec` with `CoalescedShuffleReaderExec`, which keeps the
// "excludedPartition" requirement and also merges some partitions.
case PartialShuffleReaderExec(stage: ShuffleQueryStageExec, _) =>
visitedStages.add(stage.id)
CoalescedShuffleReaderExec(stage, partitionIndices)

// We are doing `transformDown`, so the `ShuffleQueryStageExec` may already be optimized
// and wrapped by `CoalescedShuffleReaderExec`.
case stage: ShuffleQueryStageExec if !visitedStages.contains(stage.id) =>
visitedStages.add(stage.id)
CoalescedShuffleReaderExec(stage, partitionIndices)
}
} else {
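
The estimatePartitionStartAndEndIndices helper called above is not shown in this diff. As a rough, hypothetical stand-in (the name coalesceRanges and the sizing policy are assumptions, not taken from the PR), the sketch below illustrates how excluded skewed partitions interact with coalescing: consecutive non-excluded partitions are merged up to an advisory target size, and an excluded partition always closes the current range.

import scala.collection.mutable.ArrayBuffer

object CoalesceSketch {
  // Merge consecutive non-excluded partitions until the accumulated size would exceed
  // targetSize; an excluded (skewed) partition never appears inside a range.
  def coalesceRanges(
      bytesByPartition: Array[Long],
      excluded: Set[Int],
      targetSize: Long): Array[(Int, Int)] = {
    val ranges = ArrayBuffer.empty[(Int, Int)]
    var start = -1 // start index of the range being built, -1 if none
    var acc = 0L   // bytes accumulated in the current range
    for (i <- bytesByPartition.indices) {
      if (excluded.contains(i)) {
        if (start >= 0) { ranges += ((start, i)); start = -1 }
      } else if (start < 0) {
        start = i; acc = bytesByPartition(i)
      } else if (acc + bytesByPartition(i) > targetSize) {
        ranges += ((start, i)); start = i; acc = bytesByPartition(i)
      } else {
        acc += bytesByPartition(i)
      }
    }
    if (start >= 0) ranges += ((start, bytesByPartition.length))
    ranges.toArray
  }

  def main(args: Array[String]): Unit = {
    println(coalesceRanges(Array(10L, 10L, 100L, 10L, 10L), Set(2), 25L).mkString(", "))
    // prints (0,2), (3,5): partitions 0-1 and 3-4 are merged, while the skewed
    // partition 2 is left to the sub-joins built by OptimizeSkewedJoin.
  }
}
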
@@ -533,7 +533,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA
val finalPlan = resultDf.queryExecution.executedPlan
.asInstanceOf[AdaptiveSparkPlanExec].executedPlan
assert(finalPlan.collect {
case ShuffleQueryStageExec(_, r: ReusedExchangeExec, _) => r
case ShuffleQueryStageExec(_, r: ReusedExchangeExec) => r
}.length == 2)
assert(finalPlan.collect { case p: CoalescedShuffleReaderExec => p }.length == 3)

@@ -566,7 +566,7 @@ class ReduceNumShufflePartitionsSuite extends SparkFunSuite with BeforeAndAfterA

val reusedStages = level1Stages.flatMap { stage =>
stage.plan.collect {
case ShuffleQueryStageExec(_, r: ReusedExchangeExec, _) => r
case ShuffleQueryStageExec(_, r: ReusedExchangeExec) => r
}
}
assert(reusedStages.length == 1)
@@ -92,7 +92,7 @@ class AdaptiveQueryExecSuite

private def findReusedExchange(plan: SparkPlan): Seq[ReusedExchangeExec] = {
collectInPlanAndSubqueries(plan) {
case ShuffleQueryStageExec(_, e: ReusedExchangeExec, _) => e
case ShuffleQueryStageExec(_, e: ReusedExchangeExec) => e
case BroadcastQueryStageExec(_, e: ReusedExchangeExec) => e
}
}
Expand Down