
Commit 61ac7b8

Consolidate state decommissioning in the TaskSchedulerImpl realm
1 parent b33066f commit 61ac7b8

9 files changed (+71, -52 lines)

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 1 addition & 1 deletion
@@ -1825,7 +1825,7 @@ private[spark] class DAGScheduler(
         if (bmAddress != null) {
           val externalShuffleServiceEnabled = env.blockManager.externalShuffleServiceEnabled
           val isHostDecommissioned = taskScheduler
-            .getExecutorDecommissionInfo(bmAddress.executorId)
+            .getExecutorDecommissionState(bmAddress.executorId)
             .exists(_.isHostDecommissioned)
 
           // Shuffle output of all executors on host `bmAddress.host` may be lost if:

core/src/main/scala/org/apache/spark/scheduler/ExecutorDecommissionInfo.scala

Lines changed: 11 additions & 1 deletion
@@ -18,11 +18,21 @@
 package org.apache.spark.scheduler
 
 /**
- * Provides more detail when an executor is being decommissioned.
+ * Message providing more detail when an executor is being decommissioned.
  * @param message Human readable reason for why the decommissioning is happening.
  * @param isHostDecommissioned Whether the host (aka the `node` or `worker` in other places) is
  *                             being decommissioned too. Used to infer if the shuffle data might
  *                             be lost even if the external shuffle service is enabled.
  */
 private[spark]
 case class ExecutorDecommissionInfo(message: String, isHostDecommissioned: Boolean)
+
+/**
+ * State related to decommissioning that is kept by the TaskSchedulerImpl. This state is derived
+ * from the info message above but it is kept distinct to allow the state to evolve independently
+ * from the message.
+ */
+case class ExecutorDecommissionState(message: String,
+    // Timestamp in milliseconds when decommissioning was triggered
+    tsMillis: Long,
+    isHostDecommissioned: Boolean)
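
Note: the new ExecutorDecommissionState mirrors ExecutorDecommissionInfo but adds the trigger timestamp. As a hedged illustration only (the helper and object below are invented for this note and are not part of the commit; it assumes compilation inside the org.apache.spark.scheduler package, where both case classes are visible), the derivation amounts to:

package org.apache.spark.scheduler

// Hypothetical helper, not in the commit: build the scheduler-side state
// from an incoming decommission message, stamping the time it was received.
object DecommissionStateSketch {
  def toState(info: ExecutorDecommissionInfo, nowMillis: Long): ExecutorDecommissionState =
    ExecutorDecommissionState(info.message, nowMillis, info.isHostDecommissioned)
}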

core/src/main/scala/org/apache/spark/scheduler/TaskScheduler.scala

Lines changed: 1 addition & 1 deletion
@@ -106,7 +106,7 @@ private[spark] trait TaskScheduler {
   /**
    * If an executor is decommissioned, return its corresponding decommission info
    */
-  def getExecutorDecommissionInfo(executorId: String): Option[ExecutorDecommissionInfo]
+  def getExecutorDecommissionState(executorId: String): Option[ExecutorDecommissionState]
 
   /**
    * Process a lost executor

core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala

Lines changed: 9 additions & 6 deletions
@@ -136,7 +136,7 @@ private[spark] class TaskSchedulerImpl(
   // IDs of the tasks running on each executor
   private val executorIdToRunningTaskIds = new HashMap[String, HashSet[Long]]
 
-  private val executorsPendingDecommission = new HashMap[String, ExecutorDecommissionInfo]
+  private val executorsPendingDecommission = new HashMap[String, ExecutorDecommissionState]
 
   def runningTasksByExecutors: Map[String, Int] = synchronized {
     executorIdToRunningTaskIds.toMap.mapValues(_.size).toMap
@@ -276,7 +276,7 @@ private[spark] class TaskSchedulerImpl(
   private[scheduler] def createTaskSetManager(
       taskSet: TaskSet,
       maxTaskFailures: Int): TaskSetManager = {
-    new TaskSetManager(this, taskSet, maxTaskFailures, blacklistTrackerOpt)
+    new TaskSetManager(this, taskSet, maxTaskFailures, blacklistTrackerOpt, clock)
   }
 
   override def cancelTasks(stageId: Int, interruptThread: Boolean): Unit = synchronized {
@@ -911,16 +911,19 @@ private[spark] class TaskSchedulerImpl(
         // most likely coming from the cluster manager and thus authoritative
         val oldDecomInfo = executorsPendingDecommission.get(executorId)
         if (oldDecomInfo.isEmpty || !oldDecomInfo.get.isHostDecommissioned) {
-          executorsPendingDecommission(executorId) = decommissionInfo
+          executorsPendingDecommission(executorId) = ExecutorDecommissionState(
+            decommissionInfo.message,
+            oldDecomInfo.map(_.tsMillis).getOrElse(clock.getTimeMillis()),
+            decommissionInfo.isHostDecommissioned)
         }
       }
     }
     rootPool.executorDecommission(executorId)
     backend.reviveOffers()
   }
 
-  override def getExecutorDecommissionInfo(executorId: String)
-    : Option[ExecutorDecommissionInfo] = synchronized {
+  override def getExecutorDecommissionState(executorId: String)
+    : Option[ExecutorDecommissionState] = synchronized {
     executorsPendingDecommission.get(executorId)
   }
 
@@ -929,7 +932,7 @@ private[spark] class TaskSchedulerImpl(
     val reason = givenReason match {
      // Handle executor process loss due to decommissioning
      case ExecutorProcessLost(message, origWorkerLost, origCausedByApp) =>
-       val executorDecommissionInfo = getExecutorDecommissionInfo(executorId)
+       val executorDecommissionInfo = getExecutorDecommissionState(executorId)
        ExecutorProcessLost(
          message,
          // Also mark the worker lost if we know that the host was decommissioned
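
The interesting behavior above is the merge policy for repeated decommission messages: the originally recorded timestamp is kept, and a state that already marks the host as decommissioned is never overwritten. A standalone sketch of that policy (mergeState and the surrounding object are illustrative assumptions, not the commit's code; same package assumption as the earlier sketch):

package org.apache.spark.scheduler

object DecommissionMergeSketch {
  // Mirrors the update of executorsPendingDecommission: keep the first
  // trigger time, and ignore new messages once the host itself is known
  // to be decommissioned.
  def mergeState(
      old: Option[ExecutorDecommissionState],
      info: ExecutorDecommissionInfo,
      nowMillis: Long): ExecutorDecommissionState = {
    if (old.isEmpty || !old.get.isHostDecommissioned) {
      ExecutorDecommissionState(
        info.message,
        old.map(_.tsMillis).getOrElse(nowMillis),
        info.isHostDecommissioned)
    } else {
      old.get // host-level decommission already recorded; keep it unchanged
    }
  }
}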

core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala

Lines changed: 13 additions & 17 deletions
@@ -168,7 +168,6 @@ private[spark] class TaskSetManager(
 
   // Task index, start and finish time for each task attempt (indexed by task ID)
   private[scheduler] val taskInfos = new HashMap[Long, TaskInfo]
-  private[scheduler] val tidToExecutorKillTimeMapping = new HashMap[Long, Long]
 
   // Use a MedianHeap to record durations of successful tasks so we know when to launch
   // speculative tasks. This is only used when speculation is enabled, to avoid the overhead
@@ -939,7 +938,6 @@ private[spark] class TaskSetManager(
 
   /** If the given task ID is in the set of running tasks, removes it. */
   def removeRunningTask(tid: Long): Unit = {
-    tidToExecutorKillTimeMapping.remove(tid)
     if (runningTasksSet.remove(tid) && parent != null) {
       parent.decreaseRunningTasks(1)
     }
@@ -1050,15 +1048,19 @@ private[spark] class TaskSetManager(
       logDebug("Task length threshold for speculation: " + threshold)
       for (tid <- runningTasksSet) {
         var speculated = checkAndSubmitSpeculatableTask(tid, time, threshold)
-        if (!speculated && tidToExecutorKillTimeMapping.contains(tid)) {
-          // Check whether this task will finish before the exectorKillTime assuming
-          // it will take medianDuration overall. If this task cannot finish within
-          // executorKillInterval, then this task is a candidate for speculation
-          val taskEndTimeBasedOnMedianDuration = taskInfos(tid).launchTime + medianDuration
-          val canExceedDeadline = tidToExecutorKillTimeMapping(tid) <
-            taskEndTimeBasedOnMedianDuration
-          if (canExceedDeadline) {
-            speculated = checkAndSubmitSpeculatableTask(tid, time, 0)
+        if (!speculated && executorDecommissionKillInterval.nonEmpty) {
+          val taskInfo = taskInfos(tid)
+          val decomInfo = sched.getExecutorDecommissionState(taskInfo.executorId)
+          if (decomInfo.nonEmpty) {
+            // Check whether this task will finish before the exectorKillTime assuming
+            // it will take medianDuration overall. If this task cannot finish within
+            // executorKillInterval, then this task is a candidate for speculation
+            val taskEndTimeBasedOnMedianDuration = taskInfos(tid).launchTime + medianDuration
+            val executorDecomTime = decomInfo.get.tsMillis + executorDecommissionKillInterval.get
+            val canExceedDeadline = executorDecomTime < taskEndTimeBasedOnMedianDuration
+            if (canExceedDeadline) {
+              speculated = checkAndSubmitSpeculatableTask(tid, time, 0)
+            }
          }
        }
        foundTasks |= speculated
@@ -1119,12 +1121,6 @@ private[spark] class TaskSetManager(
 
   def executorDecommission(execId: String): Unit = {
     recomputeLocality()
-    executorDecommissionKillInterval.foreach { interval =>
-      val executorKillTime = clock.getTimeMillis() + interval
-      runningTasksSet.filter(taskInfos(_).executorId == execId).foreach { tid =>
-        tidToExecutorKillTimeMapping(tid) = executorKillTime
-      }
-    }
   }
 
   def recomputeLocality(): Unit = {
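
To make the reworked speculation trigger concrete, here is a small numeric sketch. All values are assumed for illustration (they loosely mirror the 5s kill interval and 10s median duration used in the test further below); this is not code from the commit:

// Self-contained Scala illustration of the deadline check; every value is assumed.
object SpeculationDeadlineSketch {
  def main(args: Array[String]): Unit = {
    val decomTsMillis = 10000L          // executor decommissioned at t = 10 s
    val killIntervalMillis = 5000L      // EXECUTOR_DECOMMISSION_KILL_INTERVAL = 5 s
    val medianDurationMillis = 10000L   // median duration of successful tasks = 10 s

    // The executor is expected to be gone at decommission time + kill interval.
    val executorDecomTime = decomTsMillis + killIntervalMillis  // 15 000 ms

    // A task is speculated early only if its projected end time (launch time +
    // median duration) falls after the executor's expected kill time.
    def wouldSpeculate(launchTimeMillis: Long): Boolean =
      executorDecomTime < launchTimeMillis + medianDurationMillis

    println(wouldSpeculate(0L))     // false: projected end 10 s, before the kill at 15 s
    println(wouldSpeculate(6000L))  // true: projected end 16 s, after the kill at 15 s
  }
}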

core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala

Lines changed: 4 additions & 4 deletions
@@ -178,8 +178,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
     override def executorDecommission(
       executorId: String,
       decommissionInfo: ExecutorDecommissionInfo): Unit = {}
-    override def getExecutorDecommissionInfo(
-      executorId: String): Option[ExecutorDecommissionInfo] = None
+    override def getExecutorDecommissionState(
+      executorId: String): Option[ExecutorDecommissionState] = None
   }
 
   /**
@@ -787,8 +787,8 @@ class DAGSchedulerSuite extends SparkFunSuite with LocalSparkContext with TimeLi
       override def executorDecommission(
         executorId: String,
         decommissionInfo: ExecutorDecommissionInfo): Unit = {}
-      override def getExecutorDecommissionInfo(
-        executorId: String): Option[ExecutorDecommissionInfo] = None
+      override def getExecutorDecommissionState(
+        executorId: String): Option[ExecutorDecommissionState] = None
     }
     val noKillScheduler = new DAGScheduler(
       sc,

core/src/test/scala/org/apache/spark/scheduler/ExternalClusterManagerSuite.scala

Lines changed: 2 additions & 2 deletions
@@ -101,6 +101,6 @@ private class DummyTaskScheduler extends TaskScheduler {
   override def executorDecommission(
     executorId: String,
     decommissionInfo: ExecutorDecommissionInfo): Unit = {}
-  override def getExecutorDecommissionInfo(
-    executorId: String): Option[ExecutorDecommissionInfo] = None
+  override def getExecutorDecommissionState(
+    executorId: String): Option[ExecutorDecommissionState] = None
 }

core/src/test/scala/org/apache/spark/scheduler/TaskSchedulerImplSuite.scala

Lines changed: 13 additions & 10 deletions
@@ -1822,31 +1822,34 @@ class TaskSchedulerImplSuite extends SparkFunSuite with LocalSparkContext with B
     scheduler.executorDecommission("executor0", ExecutorDecommissionInfo("0 new", false))
     scheduler.executorDecommission("executor1", ExecutorDecommissionInfo("1 new", false))
 
-    assert(scheduler.getExecutorDecommissionInfo("executor0")
+    def convert(state: Option[ExecutorDecommissionState]): Option[ExecutorDecommissionInfo] =
+      state.map(s => ExecutorDecommissionInfo(s.message, s.isHostDecommissioned))
+
+    assert(convert(scheduler.getExecutorDecommissionState("executor0"))
       === Some(ExecutorDecommissionInfo("0 new", false)))
-    assert(scheduler.getExecutorDecommissionInfo("executor1")
+    assert(convert(scheduler.getExecutorDecommissionState("executor1"))
       === Some(ExecutorDecommissionInfo("1", true)))
-    assert(scheduler.getExecutorDecommissionInfo("executor2").isEmpty)
+    assert(scheduler.getExecutorDecommissionState("executor2").isEmpty)
   }
 
   test("scheduler should ignore decommissioning of removed executors") {
     val scheduler = setupSchedulerForDecommissionTests()
 
     // executor 0 is decommissioned after loosing
-    assert(scheduler.getExecutorDecommissionInfo("executor0").isEmpty)
+    assert(scheduler.getExecutorDecommissionState("executor0").isEmpty)
     scheduler.executorLost("executor0", ExecutorExited(0, false, "normal"))
-    assert(scheduler.getExecutorDecommissionInfo("executor0").isEmpty)
+    assert(scheduler.getExecutorDecommissionState("executor0").isEmpty)
     scheduler.executorDecommission("executor0", ExecutorDecommissionInfo("", false))
-    assert(scheduler.getExecutorDecommissionInfo("executor0").isEmpty)
+    assert(scheduler.getExecutorDecommissionState("executor0").isEmpty)
 
     // executor 1 is decommissioned before loosing
-    assert(scheduler.getExecutorDecommissionInfo("executor1").isEmpty)
+    assert(scheduler.getExecutorDecommissionState("executor1").isEmpty)
     scheduler.executorDecommission("executor1", ExecutorDecommissionInfo("", false))
-    assert(scheduler.getExecutorDecommissionInfo("executor1").isDefined)
+    assert(scheduler.getExecutorDecommissionState("executor1").isDefined)
     scheduler.executorLost("executor1", ExecutorExited(0, false, "normal"))
-    assert(scheduler.getExecutorDecommissionInfo("executor1").isEmpty)
+    assert(scheduler.getExecutorDecommissionState("executor1").isEmpty)
     scheduler.executorDecommission("executor1", ExecutorDecommissionInfo("", false))
-    assert(scheduler.getExecutorDecommissionInfo("executor1").isEmpty)
+    assert(scheduler.getExecutorDecommissionState("executor1").isEmpty)
   }
 
   /**

core/src/test/scala/org/apache/spark/scheduler/TaskSetManagerSuite.scala

Lines changed: 17 additions & 10 deletions
@@ -41,7 +41,7 @@ import org.apache.spark.resource.TestResourceIDs._
 import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
 import org.apache.spark.serializer.SerializerInstance
 import org.apache.spark.storage.BlockManagerId
-import org.apache.spark.util.{AccumulatorV2, ManualClock}
+import org.apache.spark.util.{AccumulatorV2, Clock, ManualClock, SystemClock}
 
 class FakeDAGScheduler(sc: SparkContext, taskScheduler: FakeTaskScheduler)
   extends DAGScheduler(sc) {
@@ -109,8 +109,9 @@ object FakeRackUtil {
 * a list of "live" executors and their hostnames for isExecutorAlive and hasExecutorsAliveOnHost
 * to work, and these are required for locality in TaskSetManager.
 */
-class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* execId, host */)
-  extends TaskSchedulerImpl(sc)
+class FakeTaskScheduler(sc: SparkContext, clock: Clock,
+    liveExecutors: (String, String)* /* execId, host */)
+  extends TaskSchedulerImpl(sc, sc.conf.get(config.TASK_MAX_FAILURES), clock = clock)
 {
   val startedTasks = new ArrayBuffer[Long]
   val endedTasks = new mutable.HashMap[Long, TaskEndReason]
@@ -120,6 +121,10 @@ class FakeTaskScheduler(sc: SparkContext, liveExecutors: (String, String)* /* ex
 
   val executors = new mutable.HashMap[String, String]
 
+  def this(sc: SparkContext, liveExecutors: (String, String)*) {
+    this(sc, new SystemClock, liveExecutors: _*)
+  }
+
   // this must be initialized before addExecutor
   override val defaultRackValue: Option[String] = Some("default")
   for ((execId, host) <- liveExecutors) {
@@ -1922,14 +1927,16 @@ class TaskSetManagerSuite
   test("SPARK-21040: Check speculative tasks are launched when an executor is decommissioned" +
     " and the tasks running on it cannot finish within EXECUTOR_DECOMMISSION_KILL_INTERVAL") {
     sc = new SparkContext("local", "test")
-    sched = new FakeTaskScheduler(sc, ("exec1", "host1"), ("exec2", "host2"), ("exec3", "host3"))
+    val clock = new ManualClock()
+    sched = new FakeTaskScheduler(sc, clock,
+      ("exec1", "host1"), ("exec2", "host2"), ("exec3", "host3"))
+    sched.backend = mock(classOf[SchedulerBackend])
     val taskSet = FakeTask.createTaskSet(4)
     sc.conf.set(config.SPECULATION_ENABLED, true)
     sc.conf.set(config.SPECULATION_MULTIPLIER, 1.5)
     sc.conf.set(config.SPECULATION_QUANTILE, 0.5)
     sc.conf.set(config.EXECUTOR_DECOMMISSION_KILL_INTERVAL.key, "5s")
-    val clock = new ManualClock()
-    val manager = new TaskSetManager(sched, taskSet, MAX_TASK_FAILURES, clock = clock)
+    val manager = sched.createTaskSetManager(taskSet, MAX_TASK_FAILURES)
     val accumUpdatesByTask: Array[Seq[AccumulatorV2[_, _]]] = taskSet.tasks.map { task =>
       task.metrics.internalAccums
     }
@@ -1968,10 +1975,10 @@ class TaskSetManagerSuite
     // decommission exec-2. All tasks running on exec-2 (i.e. TASK 2,3) will be added to
     // executorDecommissionSpeculationTriggerTimeoutOpt
     // (TASK 2 -> 15, TASK 3 -> 15)
-    manager.executorDecommission("exec2")
-    assert(manager.tidToExecutorKillTimeMapping.keySet === Set(2, 3))
-    assert(manager.tidToExecutorKillTimeMapping(2) === 15*1000)
-    assert(manager.tidToExecutorKillTimeMapping(3) === 15*1000)
+    sched.executorDecommission("exec2", ExecutorDecommissionInfo("decom",
+      isHostDecommissioned = false))
+    assert(sched.getExecutorDecommissionState("exec2").map(_.tsMillis) ===
+      Some(clock.getTimeMillis()))
 
     assert(manager.checkSpeculatableTasks(0))
     // TASK 2 started at t=0s, so it can still finish before t=15s (Median task runtime = 10s)
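
A closing note on the test plumbing: FakeTaskScheduler now takes a Clock so that the timestamp recorded in ExecutorDecommissionState is driven by a ManualClock rather than wall time. A minimal sketch of why that makes the assertion deterministic (assumed to be compiled inside Spark's test sources, where the private[spark] ManualClock is visible; the object name is invented for this note):

package org.apache.spark.scheduler

import org.apache.spark.util.ManualClock

object DecommissionClockSketch {
  def main(args: Array[String]): Unit = {
    val clock = new ManualClock()
    clock.advance(10 * 1000)  // simulate 10 s of elapsed time
    // TaskSchedulerImpl stamps ExecutorDecommissionState.tsMillis with
    // clock.getTimeMillis(), so a test driving this clock can assert the
    // exact recorded value instead of comparing against wall-clock time.
    assert(clock.getTimeMillis() == 10 * 1000)
  }
}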
