
Commit 1ac23de

dagrawal3409 authored and cloud-fan committed
[SPARK-32613][CORE] Fix regressions in DecommissionWorkerSuite
### What changes were proposed in this pull request?

The DecommissionWorkerSuite started becoming flaky, and it revealed a real regression. The recently closed #29211 necessitates remembering the decommissioning shortly beyond the removal of the executor.

In addition to fixing this issue, this change ensures that DecommissionWorkerSuite continues to pass when executors haven't had a chance to exit eagerly; that is, the old behavior from before #29211 still works. It also adds more tests to TaskSchedulerImpl to ensure that the decommissioning information is indeed purged after a timeout, and hardens DecommissionWorkerSuite to make it wait for successful job completion.

### Why are the changes needed?

First, the intended behavior of decommissioning: if a fetch failure happens where the source executor was decommissioned, we want to treat that as an eager signal to clear all shuffle state associated with that executor. In addition, if we know that the host was decommissioned, we want to forget about all map statuses from all other executors on that decommissioned host. This is what the test "decommission workers ensure that fetch failures lead to rerun" is trying to verify. This invariant is important to ensure that decommissioning a host does not lead to multiple fetch failures that might fail the job. Such a fetch failure can happen before the executor is truly marked "lost" because of heartbeat delays.

- However, #29211 eagerly exits the executors when they are done decommissioning. This removal of the executor was racing with the fetch failure: by the time the fetch failure is triggered, the executor is already removed and has thus forgotten its decommissioning information. (I tested this by delaying the decommissioning.) The fix is to keep the decommissioning information around for some time after removal, with some extra logic to finally purge it after a timeout.
- In addition, the executor loss can also bump up `shuffleFileLostEpoch` (added in #28848). This happens because when the executor is lost, it forgets the shuffle state about just that executor and increments the `shuffleFileLostEpoch`. This incrementing precludes clearing the state of the entire host when the fetch failure happens, because the failed task is still reusing the old epoch. The fix here is also simple: ignore the `shuffleFileLostEpoch` when the shuffle status is being cleared due to a fetch failure resulting from host decommission.

I am deliberately keeping both of these fixes very local to decommissioning to avoid other regressions. The epoch versioning logic in particular is tricky (it hasn't been fundamentally changed since it was first introduced in 2013).

### Does this PR introduce _any_ user-facing change?

No

### How was this patch tested?

Manually ran DecommissionWorkerSuite several times using a script and ensured it all passed.

### (Internal) Configs added

I added two configs, one of which is meant for testing only (a short usage sketch follows the commit trailers below):

- `spark.test.executor.decommission.initial.sleep.millis`: Initial delay used by the decommissioner shutdown thread. The default is the same as before: 1 second. This is used for testing only and is kept "hidden" (i.e., not added as a config constant, to avoid config bloat).
- `spark.executor.decommission.removed.infoCacheTTL`: Number of seconds to keep the decommission entries of removed executors around. It defaults to 5 minutes. It should be roughly the average time it takes for all of the shuffle data to be fetched from the mapper to the reducer, which can take a while since the reducers also do a multi-step sort.
Closes #29422 from agrawaldevesh/decom_fixes.

Authored-by: Devesh Agrawal <devesh.agrawal@gmail.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
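For illustration only (not part of this commit), a minimal Scala sketch of setting the two new keys on a `SparkConf`. The keys come from this change; the values are assumptions that mirror the new test cases and a plausible production override:

```scala
import org.apache.spark.SparkConf

// Illustrative values only; the two keys below are the configs introduced by this commit.
val conf = new SparkConf()
  // Remember decommission info of removed executors for 10 minutes instead of the default 5m.
  .set("spark.executor.decommission.removed.infoCacheTTL", "10m")
  // Hidden, test-only knob: stall the decommissioned executor's shutdown thread for an hour,
  // as the new "stalled" suite case does (the "eager" case uses 0).
  .set("spark.test.executor.decommission.initial.sleep.millis", (3600 * 1000).toString)
```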
1 parent b33066f · commit 1ac23de

File tree: 6 files changed (+153, -37 lines)


core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala

Lines changed: 8 additions & 2 deletions
```diff
@@ -294,10 +294,15 @@ private[spark] class CoarseGrainedExecutorBackend(
       override def run(): Unit = {
         var lastTaskRunningTime = System.nanoTime()
         val sleep_time = 1000 // 1s
-
+        // This config is internal and only used by unit tests to force an executor
+        // to hang around for longer when decommissioned.
+        val initialSleepMillis = env.conf.getInt(
+          "spark.test.executor.decommission.initial.sleep.millis", sleep_time)
+        if (initialSleepMillis > 0) {
+          Thread.sleep(initialSleepMillis)
+        }
         while (true) {
           logInfo("Checking to see if we can shutdown.")
-          Thread.sleep(sleep_time)
           if (executor == null || executor.numRunningTasks == 0) {
             if (env.conf.get(STORAGE_DECOMMISSION_ENABLED)) {
               logInfo("No running tasks, checking migrations")
@@ -323,6 +328,7 @@ private[spark] class CoarseGrainedExecutorBackend(
             // move forward.
             lastTaskRunningTime = System.nanoTime()
           }
+          Thread.sleep(sleep_time)
         }
       }
     }
```

core/src/main/scala/org/apache/spark/internal/config/package.scala

Lines changed: 10 additions & 0 deletions
```diff
@@ -1877,6 +1877,16 @@ package object config {
       .timeConf(TimeUnit.SECONDS)
       .createOptional
 
+  private[spark] val DECOMMISSIONED_EXECUTORS_REMEMBER_AFTER_REMOVAL_TTL =
+    ConfigBuilder("spark.executor.decommission.removed.infoCacheTTL")
+      .doc("Duration for which a decommissioned executor's information will be kept after its" +
+        "removal. Keeping the decommissioned info after removal helps pinpoint fetch failures to " +
+        "decommissioning even after the mapper executor has been decommissioned. This allows " +
+        "eager recovery from fetch failures caused by decommissioning, increasing job robustness.")
+      .version("3.1.0")
+      .timeConf(TimeUnit.SECONDS)
+      .createWithDefaultString("5m")
+
   private[spark] val STAGING_DIR = ConfigBuilder("spark.yarn.stagingDir")
     .doc("Staging directory used while submitting applications.")
     .version("2.0.0")
```

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 29 additions & 12 deletions
```diff
@@ -1846,7 +1846,14 @@ private[spark] class DAGScheduler(
             execId = bmAddress.executorId,
             fileLost = true,
             hostToUnregisterOutputs = hostToUnregisterOutputs,
-            maybeEpoch = Some(task.epoch))
+            maybeEpoch = Some(task.epoch),
+            // shuffleFileLostEpoch is ignored when a host is decommissioned because some
+            // decommissioned executors on that host might have been removed before this fetch
+            // failure and might have bumped up the shuffleFileLostEpoch. We ignore that, and
+            // proceed with unconditional removal of shuffle outputs from all executors on that
+            // host, including from those that we still haven't confirmed as lost due to heartbeat
+            // delays.
+            ignoreShuffleFileLostEpoch = isHostDecommissioned)
         }
       }
 
@@ -2012,7 +2019,8 @@ private[spark] class DAGScheduler(
       execId: String,
       fileLost: Boolean,
       hostToUnregisterOutputs: Option[String],
-      maybeEpoch: Option[Long] = None): Unit = {
+      maybeEpoch: Option[Long] = None,
+      ignoreShuffleFileLostEpoch: Boolean = false): Unit = {
     val currentEpoch = maybeEpoch.getOrElse(mapOutputTracker.getEpoch)
     logDebug(s"Considering removal of executor $execId; " +
       s"fileLost: $fileLost, currentEpoch: $currentEpoch")
@@ -2022,16 +2030,25 @@ private[spark] class DAGScheduler(
       blockManagerMaster.removeExecutor(execId)
       clearCacheLocs()
     }
-    if (fileLost &&
-      (!shuffleFileLostEpoch.contains(execId) || shuffleFileLostEpoch(execId) < currentEpoch)) {
-      shuffleFileLostEpoch(execId) = currentEpoch
-      hostToUnregisterOutputs match {
-        case Some(host) =>
-          logInfo(s"Shuffle files lost for host: $host (epoch $currentEpoch)")
-          mapOutputTracker.removeOutputsOnHost(host)
-        case None =>
-          logInfo(s"Shuffle files lost for executor: $execId (epoch $currentEpoch)")
-          mapOutputTracker.removeOutputsOnExecutor(execId)
+    if (fileLost) {
+      val remove = if (ignoreShuffleFileLostEpoch) {
+        true
+      } else if (!shuffleFileLostEpoch.contains(execId) ||
+          shuffleFileLostEpoch(execId) < currentEpoch) {
+        shuffleFileLostEpoch(execId) = currentEpoch
+        true
+      } else {
+        false
+      }
+      if (remove) {
+        hostToUnregisterOutputs match {
+          case Some(host) =>
+            logInfo(s"Shuffle files lost for host: $host (epoch $currentEpoch)")
+            mapOutputTracker.removeOutputsOnHost(host)
+          case None =>
+            logInfo(s"Shuffle files lost for executor: $execId (epoch $currentEpoch)")
+            mapOutputTracker.removeOutputsOnExecutor(execId)
+        }
       }
     }
   }
```
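To make the control flow above easier to follow, here is a minimal standalone sketch (not the Spark implementation; the function name is made up for illustration) of the epoch gate: the per-executor `shuffleFileLostEpoch` check is bypassed when the fetch failure is traced to a decommissioned host.

```scala
import scala.collection.mutable

// Sketch of the decision added above: when ignoreEpoch is true (fetch failure from a
// decommissioned host), shuffle outputs are dropped unconditionally; otherwise the usual
// per-executor epoch check applies and the epoch is recorded.
def shouldRemoveShuffleFiles(
    execId: String,
    currentEpoch: Long,
    shuffleFileLostEpoch: mutable.Map[String, Long],
    ignoreEpoch: Boolean): Boolean = {
  if (ignoreEpoch) {
    true
  } else if (!shuffleFileLostEpoch.contains(execId) ||
      shuffleFileLostEpoch(execId) < currentEpoch) {
    shuffleFileLostEpoch(execId) = currentEpoch
    true
  } else {
    false
  }
}
```

For example, with `shuffleFileLostEpoch = Map("exec-1" -> 7L)` and `currentEpoch = 7`, the normal path returns `false` (that loss was already handled at this epoch), while the host-decommission path still returns `true` and lets the scheduler unregister every shuffle output on the host.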

core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala

Lines changed: 25 additions & 4 deletions
```diff
@@ -26,6 +26,9 @@ import scala.collection.mutable
 import scala.collection.mutable.{ArrayBuffer, Buffer, HashMap, HashSet}
 import scala.util.Random
 
+import com.google.common.base.Ticker
+import com.google.common.cache.CacheBuilder
+
 import org.apache.spark._
 import org.apache.spark.TaskState.TaskState
 import org.apache.spark.executor.ExecutorMetrics
@@ -136,7 +139,21 @@ private[spark] class TaskSchedulerImpl(
   // IDs of the tasks running on each executor
   private val executorIdToRunningTaskIds = new HashMap[String, HashSet[Long]]
 
-  private val executorsPendingDecommission = new HashMap[String, ExecutorDecommissionInfo]
+  // We add executors here when we first get decommission notification for them. Executors can
+  // continue to run even after being asked to decommission, but they will eventually exit.
+  val executorsPendingDecommission = new HashMap[String, ExecutorDecommissionInfo]
+
+  // When they exit and we know of that via heartbeat failure, we will add them to this cache.
+  // This cache is consulted to know if a fetch failure is because a source executor was
+  // decommissioned.
+  lazy val decommissionedExecutorsRemoved = CacheBuilder.newBuilder()
+    .expireAfterWrite(
+      conf.get(DECOMMISSIONED_EXECUTORS_REMEMBER_AFTER_REMOVAL_TTL), TimeUnit.SECONDS)
+    .ticker(new Ticker{
+      override def read(): Long = TimeUnit.MILLISECONDS.toNanos(clock.getTimeMillis())
+    })
+    .build[String, ExecutorDecommissionInfo]()
+    .asMap()
 
   def runningTasksByExecutors: Map[String, Int] = synchronized {
     executorIdToRunningTaskIds.toMap.mapValues(_.size).toMap
@@ -910,7 +927,7 @@ private[spark] class TaskSchedulerImpl(
       // if we heard isHostDecommissioned ever true, then we keep that one since it is
       // most likely coming from the cluster manager and thus authoritative
       val oldDecomInfo = executorsPendingDecommission.get(executorId)
-      if (oldDecomInfo.isEmpty || !oldDecomInfo.get.isHostDecommissioned) {
+      if (!oldDecomInfo.exists(_.isHostDecommissioned)) {
        executorsPendingDecommission(executorId) = decommissionInfo
      }
    }
@@ -921,7 +938,9 @@ private[spark] class TaskSchedulerImpl(
 
   override def getExecutorDecommissionInfo(executorId: String)
     : Option[ExecutorDecommissionInfo] = synchronized {
-      executorsPendingDecommission.get(executorId)
+      executorsPendingDecommission
+        .get(executorId)
+        .orElse(Option(decommissionedExecutorsRemoved.get(executorId)))
   }
 
   override def executorLost(executorId: String, givenReason: ExecutorLossReason): Unit = {
@@ -1027,7 +1046,9 @@ private[spark] class TaskSchedulerImpl(
       }
     }
 
-    executorsPendingDecommission -= executorId
+
+    val decomInfo = executorsPendingDecommission.remove(executorId)
+    decomInfo.foreach(decommissionedExecutorsRemoved.put(executorId, _))
 
     if (reason != LossReasonPending) {
       executorIdToHost -= executorId
```
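The cache above uses Guava's `CacheBuilder` with a custom `Ticker` so that entry expiry follows the scheduler's own clock rather than wall-clock time, which is what keeps the TTL testable. A self-contained sketch of that pattern follows, with an illustrative manual clock and string values standing in for Spark's types (nothing here is Spark code):

```scala
import java.util.concurrent.TimeUnit
import java.util.concurrent.atomic.AtomicLong

import com.google.common.base.Ticker
import com.google.common.cache.CacheBuilder

object ExpiringDecomInfoSketch {
  def main(args: Array[String]): Unit = {
    // Manual clock (milliseconds) standing in for the scheduler's clock.
    val nowMillis = new AtomicLong(0L)

    val cache = CacheBuilder.newBuilder()
      .expireAfterWrite(300, TimeUnit.SECONDS) // e.g. the 5-minute default TTL
      .ticker(new Ticker {
        // Guava tickers report nanoseconds; derive them from the manual clock.
        override def read(): Long = TimeUnit.MILLISECONDS.toNanos(nowMillis.get())
      })
      .build[String, String]()

    cache.put("executor-1", "decommissioned (host lost)")
    assert(cache.getIfPresent("executor-1") != null) // still remembered within the TTL

    nowMillis.addAndGet(301 * 1000L)                 // advance the clock past the TTL
    assert(cache.getIfPresent("executor-1") == null) // entry has expired
  }
}
```

Because expiry is driven by the injected ticker, a test can advance a fake clock instead of sleeping, which is presumably what lets the new TaskSchedulerImpl tests mentioned in the commit message check that decommission info is purged after the timeout without real waiting.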

core/src/test/scala/org/apache/spark/deploy/DecommissionWorkerSuite.scala

Lines changed: 53 additions & 11 deletions
```diff
@@ -84,6 +84,19 @@ class DecommissionWorkerSuite
     }
   }
 
+  // Unlike TestUtils.withListener, it also waits for the job to be done
+  def withListener(sc: SparkContext, listener: RootStageAwareListener)
+      (body: SparkListener => Unit): Unit = {
+    sc.addSparkListener(listener)
+    try {
+      body(listener)
+      sc.listenerBus.waitUntilEmpty()
+      listener.waitForJobDone()
+    } finally {
+      sc.listenerBus.removeListener(listener)
+    }
+  }
+
   test("decommission workers should not result in job failure") {
     val maxTaskFailures = 2
     val numTimesToKillWorkers = maxTaskFailures + 1
@@ -109,7 +122,7 @@ class DecommissionWorkerSuite
         }
       }
     }
-    TestUtils.withListener(sc, listener) { _ =>
+    withListener(sc, listener) { _ =>
       val jobResult = sc.parallelize(1 to 1, 1).map { _ =>
         Thread.sleep(5 * 1000L); 1
       }.count()
@@ -164,7 +177,7 @@ class DecommissionWorkerSuite
         }
       }
     }
-    TestUtils.withListener(sc, listener) { _ =>
+    withListener(sc, listener) { _ =>
       val jobResult = sc.parallelize(1 to 2, 2).mapPartitionsWithIndex((pid, _) => {
         val sleepTimeSeconds = if (pid == 0) 1 else 10
         Thread.sleep(sleepTimeSeconds * 1000L)
@@ -190,10 +203,11 @@ class DecommissionWorkerSuite
     }
   }
 
-  test("decommission workers ensure that fetch failures lead to rerun") {
+  def testFetchFailures(initialSleepMillis: Int): Unit = {
     createWorkers(2)
     sc = createSparkContext(
       config.Tests.TEST_NO_STAGE_RETRY.key -> "false",
+      "spark.test.executor.decommission.initial.sleep.millis" -> initialSleepMillis.toString,
       config.UNREGISTER_OUTPUT_ON_HOST_ON_FETCH_FAILURE.key -> "true")
     val executorIdToWorkerInfo = getExecutorToWorkerAssignments
     val executorToDecom = executorIdToWorkerInfo.keysIterator.next
@@ -212,22 +226,29 @@ class DecommissionWorkerSuite
       override def handleRootTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
         val taskInfo = taskEnd.taskInfo
         if (taskInfo.executorId == executorToDecom && taskInfo.attemptNumber == 0 &&
-          taskEnd.stageAttemptId == 0) {
+          taskEnd.stageAttemptId == 0 && taskEnd.stageId == 0) {
           decommissionWorkerOnMaster(workerToDecom,
             "decommission worker after task on it is done")
         }
       }
     }
-    TestUtils.withListener(sc, listener) { _ =>
+    withListener(sc, listener) { _ =>
       val jobResult = sc.parallelize(1 to 2, 2).mapPartitionsWithIndex((_, _) => {
         val executorId = SparkEnv.get.executorId
-        val sleepTimeSeconds = if (executorId == executorToDecom) 10 else 1
-        Thread.sleep(sleepTimeSeconds * 1000L)
+        val context = TaskContext.get()
+        // Only sleep in the first attempt to create the required window for decommissioning.
+        // Subsequent attempts don't need to be delayed to speed up the test.
+        if (context.attemptNumber() == 0 && context.stageAttemptNumber() == 0) {
+          val sleepTimeSeconds = if (executorId == executorToDecom) 10 else 1
+          Thread.sleep(sleepTimeSeconds * 1000L)
+        }
         List(1).iterator
       }, preservesPartitioning = true)
         .repartition(1).mapPartitions(iter => {
         val context = TaskContext.get()
         if (context.attemptNumber == 0 && context.stageAttemptNumber() == 0) {
+          // Wait a bit for the decommissioning to be triggered in the listener
+          Thread.sleep(5000)
           // MapIndex is explicitly -1 to force the entire host to be decommissioned
           // However, this will cause both the tasks in the preceding stage since the host here is
           // "localhost" (shortcoming of this single-machine unit test in that all the workers
@@ -246,6 +267,14 @@ class DecommissionWorkerSuite
     assert(tasksSeen.size === 6, s"Expected 6 tasks but got $tasksSeen")
   }
 
+  test("decommission stalled workers ensure that fetch failures lead to rerun") {
+    testFetchFailures(3600 * 1000)
+  }
+
+  test("decommission eager workers ensure that fetch failures lead to rerun") {
+    testFetchFailures(0)
+  }
+
   private abstract class RootStageAwareListener extends SparkListener {
     private var rootStageId: Option[Int] = None
     private val tasksFinished = new ConcurrentLinkedQueue[String]()
@@ -265,23 +294,31 @@ class DecommissionWorkerSuite
     override def onJobEnd(jobEnd: SparkListenerJobEnd): Unit = {
       jobEnd.jobResult match {
         case JobSucceeded => jobDone.set(true)
+        case JobFailed(exception) => logError(s"Job failed", exception)
       }
     }
 
     protected def handleRootTaskEnd(end: SparkListenerTaskEnd) = {}
 
     protected def handleRootTaskStart(start: SparkListenerTaskStart) = {}
 
+    private def getSignature(taskInfo: TaskInfo, stageId: Int, stageAttemptId: Int):
+      String = {
+      s"${stageId}:${stageAttemptId}:" +
+        s"${taskInfo.index}:${taskInfo.attemptNumber}-${taskInfo.status}"
+    }
+
     override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
+      val signature = getSignature(taskStart.taskInfo, taskStart.stageId, taskStart.stageAttemptId)
+      logInfo(s"Task started: $signature")
       if (isRootStageId(taskStart.stageId)) {
         rootTasksStarted.add(taskStart.taskInfo)
        handleRootTaskStart(taskStart)
       }
     }
 
     override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
-      val taskSignature = s"${taskEnd.stageId}:${taskEnd.stageAttemptId}:" +
-        s"${taskEnd.taskInfo.index}:${taskEnd.taskInfo.attemptNumber}"
+      val taskSignature = getSignature(taskEnd.taskInfo, taskEnd.stageId, taskEnd.stageAttemptId)
       logInfo(s"Task End $taskSignature")
       tasksFinished.add(taskSignature)
       if (isRootStageId(taskEnd.stageId)) {
@@ -291,8 +328,13 @@ class DecommissionWorkerSuite
     }
 
     def getTasksFinished(): Seq[String] = {
-      assert(jobDone.get(), "Job isn't successfully done yet")
-      tasksFinished.asScala.toSeq
+      tasksFinished.asScala.toList
+    }
+
+    def waitForJobDone(): Unit = {
+      eventually(timeout(10.seconds), interval(100.milliseconds)) {
+        assert(jobDone.get(), "Job isn't successfully done yet")
+      }
     }
   }
```
298340
