[SPARK-52024][SQL] Support cancel ShuffleQueryStage when propagate empty relations #50814

Open · wants to merge 1 commit into base: master
@@ -133,6 +133,11 @@ object SparkException {
}
}

/**
* Exception indicating that a query stage should be cancelled.
*/
private[spark] class SparkAQEStageCancelException extends RuntimeException

/**
* Exception thrown when execution of some user code in the driver process fails, e.g.
* accumulator update fails or failure in takeOrdered (user supplies an Ordering implementation
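As context for the new marker class, here is a self-contained sketch (not the actual Spark sources; `onStageResult` is a hypothetical helper) of how a result handler can use such an exception to tell deliberate AQE cancellation apart from a real failure:

```scala
import scala.util.{Failure, Success, Try}

// Marker exception, mirroring the class added above.
class SparkAQEStageCancelException extends RuntimeException

// Hypothetical handler: swallow the cancel marker, surface real errors.
def onStageResult[T](res: Try[T]): Unit = res match {
  case Success(v) => println(s"stage materialized: $v")
  case Failure(_: SparkAQEStageCancelException) => () // deliberate cancel: nothing to do
  case Failure(err) => println(s"stage failed: ${err.getMessage}")
}

onStageResult(Failure(new SparkAQEStageCancelException)) // prints nothing
```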
15 changes: 8 additions & 7 deletions core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -41,7 +41,7 @@ trait FutureAction[T] extends Future[T] {
/**
* Cancels the execution of this action with an optional reason.
*/
def cancel(reason: Option[String]): Unit
def cancel(reason: Option[String], quiet: Boolean = false): Unit

/**
* Cancels the execution of this action.
@@ -119,9 +119,9 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc:

@volatile private var _cancelled: Boolean = false

override def cancel(reason: Option[String]): Unit = {
override def cancel(reason: Option[String] = None, quiet: Boolean = false): Unit = {
_cancelled = true
jobWaiter.cancel(reason)
jobWaiter.cancel(reason, quiet)
}

override def ready(atMost: Duration)(implicit permit: CanAwait): SimpleFutureAction.this.type = {
@@ -193,10 +193,11 @@ class ComplexFutureAction[T](run : JobSubmitter => Future[T])
// A promise used to signal the future.
private val p = Promise[T]().completeWith(run(jobSubmitter))

override def cancel(reason: Option[String]): Unit = synchronized {
_cancelled = true
p.tryFailure(new SparkException("Action has been cancelled"))
subActions.foreach(_.cancel(reason))
override def cancel(reason: Option[String] = None, quiet: Boolean = false): Unit =
synchronized {
_cancelled = true
p.tryFailure(new SparkException("Action has been cancelled"))
subActions.foreach(_.cancel(reason, quiet))
}

private def jobSubmitter = new JobSubmitter {
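A short sketch of the extended `cancel` API from the caller's side, assuming a local `SparkContext`; the reason string is illustrative:

```scala
import org.apache.spark.{SparkConf, SparkContext}

// Sketch of the new quiet-cancel API (per the signature above): submit an
// async count, then cancel it without the cancellation surfacing as a
// job failure to the caller.
val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("quiet-cancel"))
val action = sc.parallelize(1 to 10000000, 100).countAsync()
action.cancel(Some("result no longer needed"), quiet = true)
sc.stop()
```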
28 changes: 20 additions & 8 deletions core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -1099,9 +1099,9 @@ private[spark] class DAGScheduler(
/**
* Cancel a job that is running or waiting in the queue.
*/
def cancelJob(jobId: Int, reason: Option[String]): Unit = {
def cancelJob(jobId: Int, reason: Option[String], quiet: Boolean = false): Unit = {
logInfo(log"Asked to cancel job ${MDC(JOB_ID, jobId)}")
eventProcessLoop.post(JobCancelled(jobId, reason))
eventProcessLoop.post(JobCancelled(jobId, reason, quiet))
}

@@ -2820,13 +2820,20 @@
}
}

private[scheduler] def handleJobCancellation(jobId: Int, reason: Option[String]): Unit = {
private[scheduler] def handleJobCancellation(
jobId: Int, reason: Option[String], quiet: Boolean = false): Unit = {
if (!jobIdToStageIds.contains(jobId)) {
logDebug("Trying to cancel unregistered job " + jobId)
} else {
val error = if (quiet) {
new SparkException("Job %d cancelled %s".format(jobId, reason.getOrElse("")))
} else {
SparkCoreErrors.sparkJobCancelled(jobId, reason.getOrElse(""), null)
}
failJobAndIndependentStages(
job = jobIdToActiveJob(jobId),
error = SparkCoreErrors.sparkJobCancelled(jobId, reason.getOrElse(""), null)
error = error,
quiet = quiet
)
}
}
@@ -2960,12 +2967,17 @@
/** Fails a job and all stages that are only used by that job, and cleans up relevant state. */
private def failJobAndIndependentStages(
job: ActiveJob,
error: Exception): Unit = {
error: Exception,
quiet: Boolean = false): Unit = {
if (cancelRunningIndependentStages(job, error.getMessage)) {
// SPARK-15783 important to cleanup state first, just for tests where we have some asserts
// against the state. Otherwise we have a *little* bit of flakiness in the tests.
cleanupStateForJobAndIndependentStages(job)
job.listener.jobFailed(error)
if (quiet) {
job.listener.jobCancel(error)
} else {
job.listener.jobFailed(error)
}
listenerBus.post(SparkListenerJobEnd(job.jobId, clock.getTimeMillis(), JobFailed(error)))
}
}
@@ -3120,8 +3132,8 @@ private[scheduler] class DAGSchedulerEventProcessLoop(dagScheduler: DAGScheduler
case StageCancelled(stageId, reason) =>
dagScheduler.handleStageCancellation(stageId, reason)

case JobCancelled(jobId, reason) =>
dagScheduler.handleJobCancellation(jobId, reason)
case JobCancelled(jobId, reason, quiet) =>
dagScheduler.handleJobCancellation(jobId, reason, quiet)

case JobGroupCancelled(groupId, cancelFutureJobs, reason) =>
dagScheduler.handleJobGroupCancelled(groupId, cancelFutureJobs, reason)
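A condensed, self-contained model of the routing added above (all names here are hypothetical stand-ins, not Spark's classes): quiet cancellations go to the `jobCancel` hook, loud ones keep the user-facing error and go to `jobFailed`:

```scala
// Hypothetical stand-in for JobListener's two callbacks.
trait MiniListener {
  def jobFailed(e: Exception): Unit
  def jobCancel(e: Exception): Unit = () // default no-op, as in the patch
}

// Quiet cancellations take the jobCancel hook; loud ones keep jobFailed.
def failJob(listener: MiniListener, jobId: Int, reason: String, quiet: Boolean): Unit = {
  val error = new Exception(s"Job $jobId cancelled $reason")
  if (quiet) listener.jobCancel(error) else listener.jobFailed(error)
}

failJob(new MiniListener {
  def jobFailed(e: Exception): Unit = println(s"loud: ${e.getMessage}")
  override def jobCancel(e: Exception): Unit = println(s"quiet: ${e.getMessage}")
}, jobId = 7, reason = "because the stage became an empty relation", quiet = true)
```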
@@ -62,7 +62,8 @@ private[scheduler] case class StageCancelled(

private[scheduler] case class JobCancelled(
jobId: Int,
reason: Option[String])
reason: Option[String],
quiet: Boolean = false)
extends DAGSchedulerEvent

private[scheduler] case class JobGroupCancelled(
@@ -25,4 +25,6 @@ package org.apache.spark.scheduler
private[spark] trait JobListener {
def taskSucceeded(index: Int, result: Any): Unit
def jobFailed(exception: Exception): Unit

def jobCancel(exception: Exception): Unit = {}
}
11 changes: 9 additions & 2 deletions core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
@@ -21,6 +21,7 @@ import java.util.concurrent.atomic.AtomicInteger

import scala.concurrent.{Future, Promise}

import org.apache.spark.SparkAQEStageCancelException
import org.apache.spark.internal.Logging

/**
@@ -49,8 +50,8 @@ private[spark] class JobWaiter[T](
* cancellation itself is handled asynchronously. After the low level scheduler cancels
* all the tasks belonging to this job, it will fail this job with a SparkException.
*/
def cancel(reason: Option[String]): Unit = {
dagScheduler.cancelJob(jobId, reason)
def cancel(reason: Option[String] = None, quiet: Boolean = false): Unit = {
dagScheduler.cancelJob(jobId, reason, quiet)
}

@@ -76,4 +77,10 @@
}
}

override def jobCancel(exception: Exception): Unit = {
if (!jobPromise.tryFailure(new SparkAQEStageCancelException())) {
logWarning("Ignore failure", exception)
}
}

}
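The quiet path ends in the waiter's promise. A simplified standalone model (hypothetical `MiniJobWaiter`, not the Spark source) of how the promise is failed with the marker exception rather than the user-facing cancellation error:

```scala
import scala.concurrent.Promise
import scala.util.Failure

class SparkAQEStageCancelException extends RuntimeException

// Simplified stand-in for JobWaiter: the quiet path fails the promise with
// the marker exception so downstream AQE code can recognize and ignore it.
class MiniJobWaiter[T] {
  val jobPromise: Promise[T] = Promise[T]()
  def jobFailed(e: Exception): Unit = jobPromise.tryFailure(e) // loud path
  def jobCancel(e: Exception): Unit =
    if (!jobPromise.tryFailure(new SparkAQEStageCancelException)) {
      println(s"ignored: ${e.getMessage}") // promise already completed
    }
}

val w = new MiniJobWaiter[Long]
w.jobCancel(new Exception("Job 7 cancelled: empty relation"))
w.jobPromise.future.value match {
  case Some(Failure(_: SparkAQEStageCancelException)) => println("quiet cancel observed")
  case other => println(s"unexpected: $other")
}
```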
@@ -61,6 +61,8 @@ abstract class PropagateEmptyRelationBase extends Rule[LogicalPlan] with CastSup
protected def empty(plan: LogicalPlan): LogicalPlan =
LocalRelation(plan.output, data = Seq.empty, isStreaming = plan.isStreaming)

protected def collectCancelableCandidates(maybeCancel: LogicalPlan*): Unit = {}

// Construct a project list from plan's output, while the value is always NULL.
private def nullValueProjectList(plan: LogicalPlan): Seq[NamedExpression] =
plan.output.map{ a => Alias(cast(Literal(null), a.dataType), a.name)(a.exprId) }
@@ -69,7 +71,8 @@

protected def commonApplyFunc: PartialFunction[LogicalPlan, LogicalPlan] = {
case p: Union if p.children.exists(isEmpty) =>
val newChildren = p.children.filterNot(isEmpty)
val (candidates, newChildren) = p.children.partition(isEmpty)
collectCancelableCandidates(candidates: _*)
if (newChildren.isEmpty) {
empty(p)
} else {
@@ -106,29 +109,47 @@
}
if (isLeftEmpty || isRightEmpty || isFalseCondition) {
joinType match {
case _: InnerLike => empty(p)
case _: InnerLike =>
collectCancelableCandidates(p.left, p.right)
empty(p)
// Intersect is handled as LeftSemi by `ReplaceIntersectWithSemiJoin` rule.
// Except is handled as LeftAnti by `ReplaceExceptWithAntiJoin` rule.
case LeftOuter | LeftSemi | LeftAnti if isLeftEmpty => empty(p)
case LeftSemi if isRightEmpty | isFalseCondition => empty(p)
case LeftOuter | LeftSemi | LeftAnti if isLeftEmpty =>
collectCancelableCandidates(p.right)
empty(p)
case LeftSemi if isRightEmpty | isFalseCondition =>
if (isRightEmpty) {
collectCancelableCandidates(p.left)
} else {
collectCancelableCandidates(p.left, p.right)
}
empty(p)
case LeftAnti if (isRightEmpty | isFalseCondition) && canExecuteWithoutJoin(p.left) =>
if (!isRightEmpty) {
collectCancelableCandidates(p.right)
}
p.left
case FullOuter if isLeftEmpty && isRightEmpty => empty(p)
case LeftOuter | FullOuter if isRightEmpty && canExecuteWithoutJoin(p.left) =>
Project(p.left.output ++ nullValueProjectList(p.right), p.left)
case RightOuter if isRightEmpty => empty(p)
case RightOuter if isRightEmpty =>
collectCancelableCandidates(p.left)
empty(p)
case RightOuter | FullOuter if isLeftEmpty && canExecuteWithoutJoin(p.right) =>
Project(nullValueProjectList(p.left) ++ p.right.output, p.right)
case LeftOuter if isFalseCondition && canExecuteWithoutJoin(p.left) =>
collectCancelableCandidates(p.right)
Project(p.left.output ++ nullValueProjectList(p.right), p.left)
case RightOuter if isFalseCondition && canExecuteWithoutJoin(p.right) =>
collectCancelableCandidates(p.left)
Project(nullValueProjectList(p.left) ++ p.right.output, p.right)
case _ => p
}
} else if (joinType == LeftSemi && conditionOpt.isEmpty &&
nonEmpty(p.right) && canExecuteWithoutJoin(p.left)) {
p.left
} else if (joinType == LeftAnti && conditionOpt.isEmpty && nonEmpty(p.right)) {
collectCancelableCandidates(p.left)
empty(p)
} else {
p
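To make the optimization concrete, an illustrative end-to-end repro (assuming a build that contains this patch): an inner join whose right side turns out empty only at runtime. Once AQE materializes the empty side, the join is rewritten to an empty relation, and with this change the left side's still-running shuffle stage becomes a cancellation candidate instead of materializing for nothing:

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("aqe-empty-cancel").getOrCreate()
import spark.implicits._

spark.conf.set("spark.sql.adaptive.enabled", "true")

val big   = spark.range(0, 10000000).toDF("id")
val empty = spark.range(0, 1000).toDF("id").filter($"id" < 0) // empty only at runtime

// Inner join with an empty side: PropagateEmptyRelation replaces it with an
// empty relation, and big's shuffle stage no longer needs to finish.
val joined = big.join(empty, "id")
assert(joined.collect().isEmpty)

spark.stop()
```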
@@ -894,6 +894,14 @@ object SQLConf {
.checkValue(_ > 0, "The initial number of partitions must be positive.")
.createOptional

val ADAPTIVE_EMPTY_TRIGGER_CANCEL_ENABLED =
buildConf("spark.sql.adaptive.empty.trigger.cancel.enabled")
.doc(s"When true and '${ADAPTIVE_EXECUTION_ENABLED.key}' is true, when propagate " +
" empty relation, Spark will try to cancel QueryStage that is unnecessary.")
.version("3.5.5")
.booleanConf
.createWithDefault(true)

lazy val ALLOW_COLLATIONS_IN_MAP_KEYS =
buildConf("spark.sql.collation.allowInMapKeys")
.doc("Allow for non-UTF8_BINARY collated strings inside of map's keys")
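The flag defaults to true; a sketch of turning it off for a session (assuming an active `SparkSession` named `spark`):

```scala
// Disable quiet cancellation of stages made redundant by empty-relation
// propagation, while leaving AQE itself enabled.
spark.conf.set("spark.sql.adaptive.empty.trigger.cancel.enabled", "false")
```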
@@ -29,7 +29,10 @@ import org.apache.spark.util.Utils
/**
* The optimizer for re-optimizing the logical plan used by AdaptiveSparkPlanExec.
*/
class AQEOptimizer(conf: SQLConf, extendedRuntimeOptimizerRules: Seq[Rule[LogicalPlan]])
class AQEOptimizer(
conf: SQLConf,
stagesToCancel: collection.mutable.Map[Int, (String, ExchangeQueryStageExec)],
extendedRuntimeOptimizerRules: Seq[Rule[LogicalPlan]])
extends RuleExecutor[LogicalPlan] {

private def fixedPoint =
@@ -39,7 +42,7 @@ class AQEOptimizer(conf: SQLConf, extendedRuntimeOptimizerRules: Seq[Rule[Logica

private val defaultBatches = Seq(
Batch("Propagate Empty Relations", fixedPoint,
AQEPropagateEmptyRelation,
AQEPropagateEmptyRelation(stagesToCancel),
ConvertToLocalRelation,
UpdateAttributeNullability),
Batch("Dynamic Join Selection", Once, DynamicJoinSelection),
@@ -22,9 +22,11 @@ import org.apache.spark.sql.catalyst.planning.ExtractSingleColumnNullAwareAntiJo
import org.apache.spark.sql.catalyst.plans.logical.EmptyRelation
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.TreePattern.{LOCAL_RELATION, LOGICAL_QUERY_STAGE, TRUE_OR_FALSE_LITERAL}
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.sql.execution.aggregate.BaseAggregateExec
import org.apache.spark.sql.execution.exchange.{REPARTITION_BY_COL, REPARTITION_BY_NUM, ShuffleExchangeLike}
import org.apache.spark.sql.execution.joins.HashedRelationWithAllNullKeys
import org.apache.spark.sql.internal.SQLConf

/**
* This rule runs in the AQE optimizer and optimizes more cases
@@ -33,7 +35,9 @@ import org.apache.spark.sql.execution.joins.HashedRelationWithAllNullKeys
* Broadcasted [[HashedRelation]] is [[HashedRelationWithAllNullKeys]]. Eliminate join to an
* empty [[LocalRelation]].
*/
object AQEPropagateEmptyRelation extends PropagateEmptyRelationBase {
case class AQEPropagateEmptyRelation(
stagesToCancel: collection.mutable.Map[Int, (String, ExchangeQueryStageExec)])
extends PropagateEmptyRelationBase {
override protected def isEmpty(plan: LogicalPlan): Boolean =
super.isEmpty(plan) || (!isRootRepartition(plan) && getEstimatedRowCount(plan).contains(0))

@@ -42,6 +46,17 @@ object AQEPropagateEmptyRelation extends PropagateEmptyRelationBase {

override protected def empty(plan: LogicalPlan): LogicalPlan = EmptyRelation(plan)

override protected def collectCancelableCandidates(candidates: LogicalPlan*): Unit = {
if (!conf.getConf(SQLConf.ADAPTIVE_EMPTY_TRIGGER_CANCEL_ENABLED)) return
candidates.foreach(_.foreach {
case LogicalQueryStage(_, physicalPlan: SparkPlan) =>
physicalPlan.collect {
case s: ShuffleQueryStageExec if !s.isMaterialized => s
}.foreach(s => stagesToCancel(s.id) = ("empty relation", s))
case _ =>
})
}

private def isRootRepartition(plan: LogicalPlan): Boolean = plan match {
case l: LogicalQueryStage if l.getTagValue(ROOT_REPARTITION).isDefined => true
case _ => false
@@ -77,6 +92,7 @@

private def eliminateSingleColumnNullAwareAntiJoin: PartialFunction[LogicalPlan, LogicalPlan] = {
case j @ ExtractSingleColumnNullAwareAntiJoin(_, _) if isRelationWithAllNullKeys(j.right) =>
collectCancelableCandidates(j.left)
empty(j)
}

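A simplified standalone model of the candidate collection above (hypothetical `MiniStage` types in place of `LogicalQueryStage`/`ShuffleQueryStageExec`): only shuffle stages that have not yet materialized are remembered, keyed by stage id, for a later cancellation sweep:

```scala
import scala.collection.mutable

// Minimal stand-ins for query stages (for illustration only).
sealed trait MiniStage { def id: Int; def isMaterialized: Boolean }
case class MiniShuffleStage(id: Int, isMaterialized: Boolean) extends MiniStage

// Mirror of the collection logic: remember every not-yet-materialized
// shuffle stage under a reason string, keyed by stage id.
val stagesToCancel = mutable.Map.empty[Int, (String, MiniStage)]

def collectCancelableCandidates(candidates: Seq[MiniStage]): Unit =
  candidates.collect {
    case s: MiniShuffleStage if !s.isMaterialized => s
  }.foreach(s => stagesToCancel(s.id) = ("empty relation", s))

collectCancelableCandidates(Seq(MiniShuffleStage(1, isMaterialized = true),
                                MiniShuffleStage(2, isMaterialized = false)))
println(stagesToCancel.keySet) // Set(2): only the in-flight stage is a candidate
```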
@@ -26,8 +26,7 @@ import scala.concurrent.ExecutionContext
import scala.jdk.CollectionConverters._
import scala.util.control.NonFatal

import org.apache.spark.SparkException
import org.apache.spark.broadcast
import org.apache.spark.{broadcast, SparkAQEStageCancelException, SparkException}
import org.apache.spark.internal.{MDC, MessageWithContext}
import org.apache.spark.internal.LogKeys._
import org.apache.spark.rdd.RDD
@@ -81,8 +80,12 @@ case class AdaptiveSparkPlanExec(

@transient private val planChangeLogger = new PlanChangeLogger[SparkPlan]()

@transient private val stagesToCancel:
collection.mutable.Map[Int, (String, ExchangeQueryStageExec)] =
new collection.mutable.HashMap[Int, (String, ExchangeQueryStageExec)]()

// The logical plan optimizer for re-optimizing the current logical plan.
@transient private val optimizer = new AQEOptimizer(conf,
@transient private val optimizer = new AQEOptimizer(conf, stagesToCancel,
context.session.sessionState.adaptiveRulesHolder.runtimeOptimizerRules)

// `EnsureRequirements` may remove user-specified repartition and assume the query plan won't
@@ -309,7 +312,12 @@
}
events.offer(StageSuccess(stage, res.get))
} else {
events.offer(StageFailure(stage, res.failed.get))
res.failed.get match {
// A quiet stage cancellation does not need to trigger another round of re-optimization
case _: SparkAQEStageCancelException => // ignore
case err: Throwable =>
events.offer(StageFailure(stage, err))
}
}
// explicitly clean up the resources in this stage
stage.cleanupResources()
@@ -367,7 +375,15 @@
currentPhysicalPlan = newPhysicalPlan
currentLogicalPlan = newLogicalPlan
stagesToReplace = Seq.empty[QueryStageExec]

stagesToCancel.values.foreach { case (reason, stage) =>
if (!stage.isCancelled) {
stage.cancel(reason, quiet = true)
context.stageCache.remove(stage.plan.canonicalized)
}
}
}
stagesToCancel.clear()
}
}
// Now that some stages have finished, we can try creating new stages.
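Finally, a self-contained sketch of the sweep added above (hypothetical `MiniStage` class, unrelated to Spark's real stage types): every remembered stage is cancelled quietly exactly once, then the map is cleared for the next round:

```scala
import scala.collection.mutable

// Hypothetical stand-in for a cancellable query stage.
final class MiniStage(val id: Int) {
  private var cancelled = false
  def isCancelled: Boolean = cancelled
  def cancel(reason: String, quiet: Boolean): Unit = {
    cancelled = true
    println(s"stage $id cancelled ($reason, quiet=$quiet)")
  }
}

val stagesToCancel = mutable.Map(2 -> ("empty relation", new MiniStage(2)))

// Mirror of the sweep in AdaptiveSparkPlanExec: cancel quietly, then clear.
stagesToCancel.values.foreach { case (reason, stage) =>
  if (!stage.isCancelled) stage.cancel(reason, quiet = true)
}
stagesToCancel.clear()
```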