
Commit c8e909c

Marcelo Vanzin authored and tgravescs committed
[SPARK-24589][CORE] Correctly identify tasks in output commit coordinator.
When an output stage is retried, it's possible that tasks from the previous attempt are still running. In that case, there would be a new task for the same partition in the new attempt, and the coordinator would allow both tasks to commit their output since it did not keep track of stage attempts.

The change adds more information to the stage state tracked by the coordinator, so that only one task is allowed to commit the output in the above case. The stage state in the coordinator is also maintained across stage retries, so that a stray speculative task from a previous stage attempt is not allowed to commit.

This also removes some code added in SPARK-18113 that allowed for duplicate commit requests; with the RPC code used in Spark 2, that situation cannot happen, so there is no need to handle it.

Author: Marcelo Vanzin <vanzin@cloudera.com>

Closes apache#21577 from vanzin/SPARK-24552.
1 parent b56e9c6 commit c8e909c
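To make the race described above concrete, here is a small, self-contained Scala sketch (toy code, not Spark source; everything except TaskIdentifier, which mirrors the case class added in this commit, is illustrative) of why a bare task attempt number is ambiguous once a stage is retried: attempt 0 of the retried stage looks identical to attempt 0 of the original stage unless the stage attempt is part of the committer's identity.

    // Toy illustration only; no Spark dependencies.
    case class TaskIdentifier(stageAttempt: Int, taskAttempt: Int)

    object CommitKeySketch {
      def main(args: Array[String]): Unit = {
        // Old scheme: the authorized committer for a partition is just a task attempt number.
        var oldCommitter: Int = -1                      // -1 == "no authorized committer"
        // New scheme: the committer is identified by (stage attempt, task attempt).
        var newCommitter: TaskIdentifier = null

        // Task attempt 0 of stage attempt 0 asks first and is authorized.
        oldCommitter = 0
        newCommitter = TaskIdentifier(stageAttempt = 0, taskAttempt = 0)

        // The stage is retried; task attempt 0 of stage attempt 1 asks for the same partition.
        // Old scheme: it matches the recorded committer, so the duplicate-request path could
        // let it commit too. New scheme: it is a different identifier and is denied.
        val duplicateAllowedOld = oldCommitter == 0
        val duplicateAllowedNew = newCommitter == TaskIdentifier(stageAttempt = 1, taskAttempt = 0)

        println(s"old scheme would authorize the second task: $duplicateAllowedOld")  // true
        println(s"new scheme authorizes the second task:      $duplicateAllowedNew")  // false
      }
    }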

5 files changed: 179 additions, 105 deletions

core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala

Lines changed: 4 additions & 4 deletions
@@ -69,9 +69,9 @@ object SparkHadoopMapRedUtil extends Logging {
 
     if (shouldCoordinateWithDriver) {
       val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator
-      val taskAttemptNumber = TaskContext.get().attemptNumber()
-      val stageId = TaskContext.get().stageId()
-      val canCommit = outputCommitCoordinator.canCommit(stageId, splitId, taskAttemptNumber)
+      val ctx = TaskContext.get()
+      val canCommit = outputCommitCoordinator.canCommit(ctx.stageId(), ctx.stageAttemptNumber(),
+        splitId, ctx.attemptNumber())
 
       if (canCommit) {
         performCommit()
@@ -81,7 +81,7 @@ object SparkHadoopMapRedUtil extends Logging {
         logInfo(message)
         // We need to abort the task so that the driver can reschedule new attempts, if necessary
         committer.abortTask(mrTaskContext)
-        throw new CommitDeniedException(message, stageId, splitId, taskAttemptNumber)
+        throw new CommitDeniedException(message, ctx.stageId(), splitId, ctx.attemptNumber())
       }
     } else {
       // Speculation is disabled or a user has chosen to manually bypass the commit coordination
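For reference, a minimal sketch of the call a task now makes, assuming it runs with a live TaskContext and from code inside an org.apache.spark package (the coordinator class is Spark-private); the accessors used are the same ones shown in the diff above.

    import org.apache.spark.{SparkEnv, TaskContext}

    // Sketch: ask the driver-side coordinator for permission to commit this partition's output.
    def askPermissionToCommit(splitId: Int): Boolean = {
      val ctx = TaskContext.get()
      SparkEnv.get.outputCommitCoordinator.canCommit(
        ctx.stageId(), ctx.stageAttemptNumber(), splitId, ctx.attemptNumber())
    }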

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 15 additions & 8 deletions
@@ -1171,6 +1171,7 @@ class DAGScheduler(
 
     outputCommitCoordinator.taskCompleted(
       stageId,
+      task.stageAttemptId,
       task.partitionId,
       event.taskInfo.attemptNumber, // this is a task attempt number
       event.reason)
@@ -1330,23 +1331,24 @@
             s" ${task.stageAttemptId} and there is a more recent attempt for that stage " +
             s"(attempt ${failedStage.latestInfo.attemptNumber}) running")
         } else {
+          failedStage.fetchFailedAttemptIds.add(task.stageAttemptId)
+          val shouldAbortStage =
+            failedStage.fetchFailedAttemptIds.size >= maxConsecutiveStageAttempts ||
+            disallowStageRetryForTest
+
           // It is likely that we receive multiple FetchFailed for a single stage (because we have
           // multiple tasks running concurrently on different executors). In that case, it is
           // possible the fetch failure has already been handled by the scheduler.
           if (runningStages.contains(failedStage)) {
             logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
               s"due to a fetch failure from $mapStage (${mapStage.name})")
-            markStageAsFinished(failedStage, Some(failureMessage))
+            markStageAsFinished(failedStage, errorMessage = Some(failureMessage),
+              willRetry = !shouldAbortStage)
           } else {
             logDebug(s"Received fetch failure from $task, but its from $failedStage which is no " +
               s"longer running")
           }
 
-          failedStage.fetchFailedAttemptIds.add(task.stageAttemptId)
-          val shouldAbortStage =
-            failedStage.fetchFailedAttemptIds.size >= maxConsecutiveStageAttempts ||
-            disallowStageRetryForTest
-
           if (shouldAbortStage) {
             val abortMessage = if (disallowStageRetryForTest) {
               "Fetch failure will not retry stage due to testing config"
@@ -1545,7 +1547,10 @@
   /**
   * Marks a stage as finished and removes it from the list of running stages.
   */
-  private def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None): Unit = {
+  private def markStageAsFinished(
+      stage: Stage,
+      errorMessage: Option[String] = None,
+      willRetry: Boolean = false): Unit = {
     val serviceTime = stage.latestInfo.submissionTime match {
       case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0)
       case _ => "Unknown"
@@ -1564,7 +1569,9 @@
       logInfo(s"$stage (${stage.name}) failed in $serviceTime s due to ${errorMessage.get}")
     }
 
-    outputCommitCoordinator.stageEnd(stage.id)
+    if (!willRetry) {
+      outputCommitCoordinator.stageEnd(stage.id)
+    }
     listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
     runningStages -= stage
   }
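A toy sketch of the retry-aware cleanup added above (assumed names; `Coordinator` is a stand-in for OutputCommitCoordinator, whose stageEnd is private[scheduler]): the committer state for a stage is dropped only when the stage will not be resubmitted, so authorizations and failure blacklists survive a fetch-failure retry.

    // Toy stand-in for the coordinator; not the Spark class.
    trait Coordinator { def stageEnd(stageId: Int): Unit }

    // `willRetry` comes from !shouldAbortStage, as in the diff above.
    def finishStage(stageId: Int, willRetry: Boolean, coordinator: Coordinator): Unit = {
      if (!willRetry) {
        coordinator.stageEnd(stageId)  // drop committer state only when the stage is done for good
      }
      // When the stage will be resubmitted, the state survives and the next attempt's
      // stageStart() reuses it (see the OutputCommitCoordinator diff below).
    }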

core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala

Lines changed: 70 additions & 58 deletions
@@ -27,7 +27,11 @@ import org.apache.spark.util.{RpcUtils, ThreadUtils}
 private sealed trait OutputCommitCoordinationMessage extends Serializable
 
 private case object StopCoordinator extends OutputCommitCoordinationMessage
-private case class AskPermissionToCommitOutput(stage: Int, partition: Int, attemptNumber: Int)
+private case class AskPermissionToCommitOutput(
+    stage: Int,
+    stageAttempt: Int,
+    partition: Int,
+    attemptNumber: Int)
 
 /**
  * Authority that decides whether tasks can commit output to HDFS. Uses a "first committer wins"
@@ -45,13 +49,15 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
   // Initialized by SparkEnv
   var coordinatorRef: Option[RpcEndpointRef] = None
 
-  private type StageId = Int
-  private type PartitionId = Int
-  private type TaskAttemptNumber = Int
-  private val NO_AUTHORIZED_COMMITTER: TaskAttemptNumber = -1
+  // Class used to identify a committer. The task ID for a committer is implicitly defined by
+  // the partition being processed, but the coordinator needs to keep track of both the stage
+  // attempt and the task attempt, because in some situations the same task may be running
+  // concurrently in two different attempts of the same stage.
+  private case class TaskIdentifier(stageAttempt: Int, taskAttempt: Int)
+
   private case class StageState(numPartitions: Int) {
-    val authorizedCommitters = Array.fill[TaskAttemptNumber](numPartitions)(NO_AUTHORIZED_COMMITTER)
-    val failures = mutable.Map[PartitionId, mutable.Set[TaskAttemptNumber]]()
+    val authorizedCommitters = Array.fill[TaskIdentifier](numPartitions)(null)
+    val failures = mutable.Map[Int, mutable.Set[TaskIdentifier]]()
   }
 
   /**
@@ -64,7 +70,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
    *
    * Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance.
    */
-  private val stageStates = mutable.Map[StageId, StageState]()
+  private val stageStates = mutable.Map[Int, StageState]()
 
   /**
   * Returns whether the OutputCommitCoordinator's internal data structures are all empty.
@@ -87,10 +93,11 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
   * @return true if this task is authorized to commit, false otherwise
   */
  def canCommit(
-      stage: StageId,
-      partition: PartitionId,
-      attemptNumber: TaskAttemptNumber): Boolean = {
-    val msg = AskPermissionToCommitOutput(stage, partition, attemptNumber)
+      stage: Int,
+      stageAttempt: Int,
+      partition: Int,
+      attemptNumber: Int): Boolean = {
+    val msg = AskPermissionToCommitOutput(stage, stageAttempt, partition, attemptNumber)
     coordinatorRef match {
       case Some(endpointRef) =>
         ThreadUtils.awaitResult(endpointRef.ask[Boolean](msg),
@@ -103,26 +110,35 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
   }
 
   /**
-   * Called by the DAGScheduler when a stage starts.
+   * Called by the DAGScheduler when a stage starts. Initializes the stage's state if it hasn't
+   * yet been initialized.
    *
   * @param stage the stage id.
   * @param maxPartitionId the maximum partition id that could appear in this stage's tasks (i.e.
   *                       the maximum possible value of `context.partitionId`).
   */
-  private[scheduler] def stageStart(stage: StageId, maxPartitionId: Int): Unit = synchronized {
-    stageStates(stage) = new StageState(maxPartitionId + 1)
+  private[scheduler] def stageStart(stage: Int, maxPartitionId: Int): Unit = synchronized {
+    stageStates.get(stage) match {
+      case Some(state) =>
+        require(state.authorizedCommitters.length == maxPartitionId + 1)
+        logInfo(s"Reusing state from previous attempt of stage $stage.")
+
+      case _ =>
+        stageStates(stage) = new StageState(maxPartitionId + 1)
+    }
   }
 
   // Called by DAGScheduler
-  private[scheduler] def stageEnd(stage: StageId): Unit = synchronized {
+  private[scheduler] def stageEnd(stage: Int): Unit = synchronized {
     stageStates.remove(stage)
   }
 
   // Called by DAGScheduler
   private[scheduler] def taskCompleted(
-      stage: StageId,
-      partition: PartitionId,
-      attemptNumber: TaskAttemptNumber,
+      stage: Int,
+      stageAttempt: Int,
+      partition: Int,
+      attemptNumber: Int,
       reason: TaskEndReason): Unit = synchronized {
     val stageState = stageStates.getOrElse(stage, {
       logDebug(s"Ignoring task completion for completed stage")
@@ -131,16 +147,17 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
     reason match {
       case Success =>
         // The task output has been committed successfully
-      case denied: TaskCommitDenied =>
-        logInfo(s"Task was denied committing, stage: $stage, partition: $partition, " +
-          s"attempt: $attemptNumber")
-      case otherReason =>
+      case _: TaskCommitDenied =>
+        logInfo(s"Task was denied committing, stage: $stage.$stageAttempt, " +
+          s"partition: $partition, attempt: $attemptNumber")
+      case _ =>
         // Mark the attempt as failed to blacklist from future commit protocol
-        stageState.failures.getOrElseUpdate(partition, mutable.Set()) += attemptNumber
-        if (stageState.authorizedCommitters(partition) == attemptNumber) {
+        val taskId = TaskIdentifier(stageAttempt, attemptNumber)
+        stageState.failures.getOrElseUpdate(partition, mutable.Set()) += taskId
+        if (stageState.authorizedCommitters(partition) == taskId) {
           logDebug(s"Authorized committer (attemptNumber=$attemptNumber, stage=$stage, " +
             s"partition=$partition) failed; clearing lock")
-          stageState.authorizedCommitters(partition) = NO_AUTHORIZED_COMMITTER
+          stageState.authorizedCommitters(partition) = null
         }
     }
   }
@@ -155,47 +172,41 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
 
   // Marked private[scheduler] instead of private so this can be mocked in tests
   private[scheduler] def handleAskPermissionToCommit(
-      stage: StageId,
-      partition: PartitionId,
-      attemptNumber: TaskAttemptNumber): Boolean = synchronized {
+      stage: Int,
+      stageAttempt: Int,
+      partition: Int,
+      attemptNumber: Int): Boolean = synchronized {
     stageStates.get(stage) match {
-      case Some(state) if attemptFailed(state, partition, attemptNumber) =>
-        logInfo(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage," +
-          s" partition=$partition as task attempt $attemptNumber has already failed.")
+      case Some(state) if attemptFailed(state, stageAttempt, partition, attemptNumber) =>
+        logInfo(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " +
+          s"task attempt $attemptNumber already marked as failed.")
         false
       case Some(state) =>
-        state.authorizedCommitters(partition) match {
-          case NO_AUTHORIZED_COMMITTER =>
-            logDebug(s"Authorizing attemptNumber=$attemptNumber to commit for stage=$stage, " +
-              s"partition=$partition")
-            state.authorizedCommitters(partition) = attemptNumber
-            true
-          case existingCommitter =>
-            // Coordinator should be idempotent when receiving AskPermissionToCommit.
-            if (existingCommitter == attemptNumber) {
-              logWarning(s"Authorizing duplicate request to commit for " +
-                s"attemptNumber=$attemptNumber to commit for stage=$stage," +
-                s" partition=$partition; existingCommitter = $existingCommitter." +
-                s" This can indicate dropped network traffic.")
-              true
-            } else {
-              logDebug(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage, " +
-                s"partition=$partition; existingCommitter = $existingCommitter")
-              false
-            }
+        val existing = state.authorizedCommitters(partition)
+        if (existing == null) {
+          logDebug(s"Commit allowed for stage=$stage.$stageAttempt, partition=$partition, " +
+            s"task attempt $attemptNumber")
+          state.authorizedCommitters(partition) = TaskIdentifier(stageAttempt, attemptNumber)
+          true
+        } else {
+          logDebug(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " +
+            s"already committed by $existing")
+          false
         }
       case None =>
-        logDebug(s"Stage $stage has completed, so not allowing" +
-          s" attempt number $attemptNumber of partition $partition to commit")
+        logDebug(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " +
+          "stage already marked as completed.")
        false
     }
   }
 
   private def attemptFailed(
       stageState: StageState,
-      partition: PartitionId,
-      attempt: TaskAttemptNumber): Boolean = synchronized {
-    stageState.failures.get(partition).exists(_.contains(attempt))
+      stageAttempt: Int,
+      partition: Int,
+      attempt: Int): Boolean = synchronized {
+    val failInfo = TaskIdentifier(stageAttempt, attempt)
+    stageState.failures.get(partition).exists(_.contains(failInfo))
   }
 }
@@ -215,9 +226,10 @@ private[spark] object OutputCommitCoordinator {
     }
 
     override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
-      case AskPermissionToCommitOutput(stage, partition, attemptNumber) =>
+      case AskPermissionToCommitOutput(stage, stageAttempt, partition, attemptNumber) =>
        context.reply(
-          outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, attemptNumber))
+          outputCommitCoordinator.handleAskPermissionToCommit(stage, stageAttempt, partition,
+            attemptNumber))
     }
   }
 }
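To summarize the new decision logic in one place, here is a compact, standalone arbiter (assumed toy names, not the Spark class) that applies the same rules as the code above: a failed (stageAttempt, taskAttempt) pair may never commit, the first asker for a partition wins, and any later asker for that partition is denied.

    import scala.collection.mutable

    // Toy stand-in for OutputCommitCoordinator's per-stage state and decision rule.
    case class TaskIdentifier(stageAttempt: Int, taskAttempt: Int)

    class ToyCommitArbiter(numPartitions: Int) {
      // A null slot means no authorized committer yet, mirroring the real array initialization.
      private val authorizedCommitters = Array.fill[TaskIdentifier](numPartitions)(null)
      private val failures = mutable.Map[Int, mutable.Set[TaskIdentifier]]()

      // Record a non-successful task so the same attempt can never be authorized later.
      def taskFailed(partition: Int, id: TaskIdentifier): Unit = synchronized {
        failures.getOrElseUpdate(partition, mutable.Set()) += id
        if (authorizedCommitters(partition) == id) {
          authorizedCommitters(partition) = null      // release the commit lock
        }
      }

      // "First committer wins": authorize only if this attempt has not failed and the slot is free.
      def canCommit(partition: Int, id: TaskIdentifier): Boolean = synchronized {
        val alreadyFailed = failures.get(partition).exists(_.contains(id))
        if (alreadyFailed) {
          false
        } else if (authorizedCommitters(partition) == null) {
          authorizedCommitters(partition) = id
          true
        } else {
          false                                       // another attempt already holds the slot
        }
      }
    }

    object ToyCommitArbiterDemo {
      def main(args: Array[String]): Unit = {
        val arbiter = new ToyCommitArbiter(numPartitions = 4)
        // The task from stage attempt 0 wins; the retried stage's task for the same partition is denied.
        println(arbiter.canCommit(3, TaskIdentifier(stageAttempt = 0, taskAttempt = 0)))  // true
        println(arbiter.canCommit(3, TaskIdentifier(stageAttempt = 1, taskAttempt = 0)))  // false
      }
    }

Note that, per the commit message, a duplicate request from the already-authorized attempt is now also denied, since the RPC layer used in Spark 2 does not redeliver the ask.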
