Commit 751b008

Marcelo Vanzin authored and tgravescs committed
[SPARK-24589][CORE] Correctly identify tasks in output commit coordinator.
When an output stage is retried, it's possible that tasks from the previous attempt are still running. In that case, there would be a new task for the same partition in the new attempt, and the coordinator would allow both tasks to commit their output since it did not keep track of stage attempts. The change adds more information to the stage state tracked by the coordinator, so that only one task is allowed to commit the output in the above case.

The stage state in the coordinator is also maintained across stage retries, so that a stray speculative task from a previous stage attempt is not allowed to commit.

This also removes some code added in SPARK-18113 that allowed for duplicate commit requests; with the RPC code used in Spark 2, that situation cannot happen, so there is no need to handle it.

Author: Marcelo Vanzin <vanzin@cloudera.com>

Closes #21577 from vanzin/SPARK-24552.

(cherry picked from commit c8e909c)
Signed-off-by: Thomas Graves <tgraves@apache.org>
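The heart of the change, restated as a minimal sketch: the coordinator now identifies a committer by both its stage attempt and its task attempt, and the first task to ask for a given partition wins. The snippet below is an illustration only (the object and method names are invented for this example, not Spark's API); the real logic is in OutputCommitCoordinator.handleAskPermissionToCommit in the diff further down.

import scala.collection.mutable

// Illustrative sketch of "first committer wins", keyed by (stageAttempt, taskAttempt).
object CommitSketch {
  case class TaskIdentifier(stageAttempt: Int, taskAttempt: Int)

  // Authorized committer per partition; absence means no task may commit yet.
  private val authorized = mutable.Map[Int, TaskIdentifier]()

  def canCommit(stageAttempt: Int, partition: Int, taskAttempt: Int): Boolean = synchronized {
    authorized.get(partition) match {
      case None =>
        // First committer wins, regardless of which stage attempt it belongs to.
        authorized(partition) = TaskIdentifier(stageAttempt, taskAttempt)
        true
      case Some(_) =>
        // Some task (possibly from another attempt of the same stage) already holds the
        // commit lock for this partition, so this task is denied and will abort.
        false
    }
  }
}

Because the per-partition state is kept across stage retries (it is only cleared when the stage will not be retried), a leftover task from attempt 0 and its replacement in attempt 1 can never both be authorized for the same partition.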
1 parent 7bfefc9 commit 751b008

File tree: 4 files changed (+173, -102 lines)


core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala

Lines changed: 4 additions & 4 deletions
@@ -69,9 +69,9 @@ object SparkHadoopMapRedUtil extends Logging {
 
       if (shouldCoordinateWithDriver) {
         val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator
-        val taskAttemptNumber = TaskContext.get().attemptNumber()
-        val stageId = TaskContext.get().stageId()
-        val canCommit = outputCommitCoordinator.canCommit(stageId, splitId, taskAttemptNumber)
+        val ctx = TaskContext.get()
+        val canCommit = outputCommitCoordinator.canCommit(ctx.stageId(), ctx.stageAttemptNumber(),
+          splitId, ctx.attemptNumber())
 
         if (canCommit) {
           performCommit()
@@ -81,7 +81,7 @@ object SparkHadoopMapRedUtil extends Logging {
           logInfo(message)
           // We need to abort the task so that the driver can reschedule new attempts, if necessary
           committer.abortTask(mrTaskContext)
-          throw new CommitDeniedException(message, stageId, splitId, taskAttemptNumber)
+          throw new CommitDeniedException(message, ctx.stageId(), splitId, ctx.attemptNumber())
         }
       } else {
         // Speculation is disabled or a user has chosen to manually bypass the commit coordination

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 15 additions & 8 deletions
@@ -1151,6 +1151,7 @@ class DAGScheduler(
 
     outputCommitCoordinator.taskCompleted(
       stageId,
+      task.stageAttemptId,
       task.partitionId,
       event.taskInfo.attemptNumber, // this is a task attempt number
       event.reason)
@@ -1309,23 +1310,24 @@ class DAGScheduler(
             s" ${task.stageAttemptId} and there is a more recent attempt for that stage " +
             s"(attempt ID ${failedStage.latestInfo.attemptId}) running")
         } else {
+          failedStage.fetchFailedAttemptIds.add(task.stageAttemptId)
+          val shouldAbortStage =
+            failedStage.fetchFailedAttemptIds.size >= maxConsecutiveStageAttempts ||
+            disallowStageRetryForTest
+
           // It is likely that we receive multiple FetchFailed for a single stage (because we have
           // multiple tasks running concurrently on different executors). In that case, it is
           // possible the fetch failure has already been handled by the scheduler.
           if (runningStages.contains(failedStage)) {
             logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
               s"due to a fetch failure from $mapStage (${mapStage.name})")
-            markStageAsFinished(failedStage, Some(failureMessage))
+            markStageAsFinished(failedStage, errorMessage = Some(failureMessage),
+              willRetry = !shouldAbortStage)
           } else {
             logDebug(s"Received fetch failure from $task, but its from $failedStage which is no " +
               s"longer running")
           }
 
-          failedStage.fetchFailedAttemptIds.add(task.stageAttemptId)
-          val shouldAbortStage =
-            failedStage.fetchFailedAttemptIds.size >= maxConsecutiveStageAttempts ||
-            disallowStageRetryForTest
-
           if (shouldAbortStage) {
             val abortMessage = if (disallowStageRetryForTest) {
               "Fetch failure will not retry stage due to testing config"
@@ -1471,7 +1473,10 @@ class DAGScheduler(
   /**
    * Marks a stage as finished and removes it from the list of running stages.
    */
-  private def markStageAsFinished(stage: Stage, errorMessage: Option[String] = None): Unit = {
+  private def markStageAsFinished(
+      stage: Stage,
+      errorMessage: Option[String] = None,
+      willRetry: Boolean = false): Unit = {
     val serviceTime = stage.latestInfo.submissionTime match {
       case Some(t) => "%.03f".format((clock.getTimeMillis() - t) / 1000.0)
       case _ => "Unknown"
@@ -1490,7 +1495,9 @@ class DAGScheduler(
       logInfo(s"$stage (${stage.name}) failed in $serviceTime s due to ${errorMessage.get}")
     }
 
-    outputCommitCoordinator.stageEnd(stage.id)
+    if (!willRetry) {
+      outputCommitCoordinator.stageEnd(stage.id)
+    }
     listenerBus.post(SparkListenerStageCompleted(stage.latestInfo))
     runningStages -= stage
   }

core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala

Lines changed: 70 additions & 58 deletions
@@ -27,7 +27,11 @@ import org.apache.spark.util.{RpcUtils, ThreadUtils}
 private sealed trait OutputCommitCoordinationMessage extends Serializable
 
 private case object StopCoordinator extends OutputCommitCoordinationMessage
-private case class AskPermissionToCommitOutput(stage: Int, partition: Int, attemptNumber: Int)
+private case class AskPermissionToCommitOutput(
+    stage: Int,
+    stageAttempt: Int,
+    partition: Int,
+    attemptNumber: Int)
 
 /**
  * Authority that decides whether tasks can commit output to HDFS. Uses a "first committer wins"
@@ -45,13 +49,15 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
   // Initialized by SparkEnv
   var coordinatorRef: Option[RpcEndpointRef] = None
 
-  private type StageId = Int
-  private type PartitionId = Int
-  private type TaskAttemptNumber = Int
-  private val NO_AUTHORIZED_COMMITTER: TaskAttemptNumber = -1
+  // Class used to identify a committer. The task ID for a committer is implicitly defined by
+  // the partition being processed, but the coordinator needs to keep track of both the stage
+  // attempt and the task attempt, because in some situations the same task may be running
+  // concurrently in two different attempts of the same stage.
+  private case class TaskIdentifier(stageAttempt: Int, taskAttempt: Int)
+
   private case class StageState(numPartitions: Int) {
-    val authorizedCommitters = Array.fill[TaskAttemptNumber](numPartitions)(NO_AUTHORIZED_COMMITTER)
-    val failures = mutable.Map[PartitionId, mutable.Set[TaskAttemptNumber]]()
+    val authorizedCommitters = Array.fill[TaskIdentifier](numPartitions)(null)
+    val failures = mutable.Map[Int, mutable.Set[TaskIdentifier]]()
   }
 
   /**
@@ -64,7 +70,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
    *
    * Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance.
    */
-  private val stageStates = mutable.Map[StageId, StageState]()
+  private val stageStates = mutable.Map[Int, StageState]()
 
   /**
    * Returns whether the OutputCommitCoordinator's internal data structures are all empty.
@@ -87,10 +93,11 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
    * @return true if this task is authorized to commit, false otherwise
    */
   def canCommit(
-      stage: StageId,
-      partition: PartitionId,
-      attemptNumber: TaskAttemptNumber): Boolean = {
-    val msg = AskPermissionToCommitOutput(stage, partition, attemptNumber)
+      stage: Int,
+      stageAttempt: Int,
+      partition: Int,
+      attemptNumber: Int): Boolean = {
+    val msg = AskPermissionToCommitOutput(stage, stageAttempt, partition, attemptNumber)
     coordinatorRef match {
       case Some(endpointRef) =>
         ThreadUtils.awaitResult(endpointRef.ask[Boolean](msg),
@@ -103,26 +110,35 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
   }
 
   /**
-   * Called by the DAGScheduler when a stage starts.
+   * Called by the DAGScheduler when a stage starts. Initializes the stage's state if it hasn't
+   * yet been initialized.
    *
    * @param stage the stage id.
   * @param maxPartitionId the maximum partition id that could appear in this stage's tasks (i.e.
    *                       the maximum possible value of `context.partitionId`).
    */
-  private[scheduler] def stageStart(stage: StageId, maxPartitionId: Int): Unit = synchronized {
-    stageStates(stage) = new StageState(maxPartitionId + 1)
+  private[scheduler] def stageStart(stage: Int, maxPartitionId: Int): Unit = synchronized {
+    stageStates.get(stage) match {
+      case Some(state) =>
+        require(state.authorizedCommitters.length == maxPartitionId + 1)
+        logInfo(s"Reusing state from previous attempt of stage $stage.")
+
+      case _ =>
+        stageStates(stage) = new StageState(maxPartitionId + 1)
+    }
   }
 
   // Called by DAGScheduler
-  private[scheduler] def stageEnd(stage: StageId): Unit = synchronized {
+  private[scheduler] def stageEnd(stage: Int): Unit = synchronized {
     stageStates.remove(stage)
   }
 
   // Called by DAGScheduler
   private[scheduler] def taskCompleted(
-      stage: StageId,
-      partition: PartitionId,
-      attemptNumber: TaskAttemptNumber,
+      stage: Int,
+      stageAttempt: Int,
+      partition: Int,
+      attemptNumber: Int,
       reason: TaskEndReason): Unit = synchronized {
     val stageState = stageStates.getOrElse(stage, {
       logDebug(s"Ignoring task completion for completed stage")
@@ -131,16 +147,17 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
     reason match {
       case Success =>
         // The task output has been committed successfully
-      case denied: TaskCommitDenied =>
-        logInfo(s"Task was denied committing, stage: $stage, partition: $partition, " +
-          s"attempt: $attemptNumber")
-      case otherReason =>
+      case _: TaskCommitDenied =>
+        logInfo(s"Task was denied committing, stage: $stage.$stageAttempt, " +
+          s"partition: $partition, attempt: $attemptNumber")
+      case _ =>
         // Mark the attempt as failed to blacklist from future commit protocol
-        stageState.failures.getOrElseUpdate(partition, mutable.Set()) += attemptNumber
-        if (stageState.authorizedCommitters(partition) == attemptNumber) {
+        val taskId = TaskIdentifier(stageAttempt, attemptNumber)
+        stageState.failures.getOrElseUpdate(partition, mutable.Set()) += taskId
+        if (stageState.authorizedCommitters(partition) == taskId) {
           logDebug(s"Authorized committer (attemptNumber=$attemptNumber, stage=$stage, " +
             s"partition=$partition) failed; clearing lock")
-          stageState.authorizedCommitters(partition) = NO_AUTHORIZED_COMMITTER
+          stageState.authorizedCommitters(partition) = null
         }
     }
   }
@@ -155,47 +172,41 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
 
   // Marked private[scheduler] instead of private so this can be mocked in tests
   private[scheduler] def handleAskPermissionToCommit(
-      stage: StageId,
-      partition: PartitionId,
-      attemptNumber: TaskAttemptNumber): Boolean = synchronized {
+      stage: Int,
+      stageAttempt: Int,
+      partition: Int,
+      attemptNumber: Int): Boolean = synchronized {
     stageStates.get(stage) match {
-      case Some(state) if attemptFailed(state, partition, attemptNumber) =>
-        logInfo(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage," +
-          s" partition=$partition as task attempt $attemptNumber has already failed.")
+      case Some(state) if attemptFailed(state, stageAttempt, partition, attemptNumber) =>
+        logInfo(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " +
+          s"task attempt $attemptNumber already marked as failed.")
         false
       case Some(state) =>
-        state.authorizedCommitters(partition) match {
-          case NO_AUTHORIZED_COMMITTER =>
-            logDebug(s"Authorizing attemptNumber=$attemptNumber to commit for stage=$stage, " +
-              s"partition=$partition")
-            state.authorizedCommitters(partition) = attemptNumber
-            true
-          case existingCommitter =>
-            // Coordinator should be idempotent when receiving AskPermissionToCommit.
-            if (existingCommitter == attemptNumber) {
-              logWarning(s"Authorizing duplicate request to commit for " +
-                s"attemptNumber=$attemptNumber to commit for stage=$stage," +
-                s" partition=$partition; existingCommitter = $existingCommitter." +
-                s" This can indicate dropped network traffic.")
-              true
-            } else {
-              logDebug(s"Denying attemptNumber=$attemptNumber to commit for stage=$stage, " +
-                s"partition=$partition; existingCommitter = $existingCommitter")
-              false
-            }
+        val existing = state.authorizedCommitters(partition)
+        if (existing == null) {
+          logDebug(s"Commit allowed for stage=$stage.$stageAttempt, partition=$partition, " +
+            s"task attempt $attemptNumber")
+          state.authorizedCommitters(partition) = TaskIdentifier(stageAttempt, attemptNumber)
+          true
+        } else {
+          logDebug(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " +
+            s"already committed by $existing")
+          false
         }
       case None =>
-        logDebug(s"Stage $stage has completed, so not allowing" +
-          s" attempt number $attemptNumber of partition $partition to commit")
+        logDebug(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " +
          "stage already marked as completed.")
        false
    }
  }
 
   private def attemptFailed(
       stageState: StageState,
-      partition: PartitionId,
-      attempt: TaskAttemptNumber): Boolean = synchronized {
-    stageState.failures.get(partition).exists(_.contains(attempt))
+      stageAttempt: Int,
+      partition: Int,
+      attempt: Int): Boolean = synchronized {
+    val failInfo = TaskIdentifier(stageAttempt, attempt)
+    stageState.failures.get(partition).exists(_.contains(failInfo))
   }
 }
 
@@ -215,9 +226,10 @@ private[spark] object OutputCommitCoordinator {
     }
 
     override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
-      case AskPermissionToCommitOutput(stage, partition, attemptNumber) =>
+      case AskPermissionToCommitOutput(stage, stageAttempt, partition, attemptNumber) =>
         context.reply(
-          outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, attemptNumber))
+          outputCommitCoordinator.handleAskPermissionToCommit(stage, stageAttempt, partition,
+            attemptNumber))
     }
   }
 }
