Commit e6a862e

Author: Marcelo Vanzin

Use task ID in output committer instead of stage / task attempt numbers.

Parent: 264c533
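
Note on the design: the coordinator previously identified a committer by the pair (stage attempt, task attempt number), because attempt numbers restart at zero and can repeat when two attempts of the same stage run the same partition concurrently; the scheduler's task ID is already unique within a SparkContext, so a single Long can replace the pair. A minimal sketch of that idea (illustrative names and IDs, not code from this commit):

// Old key: a pair was needed because attempt numbers alone can collide
// across concurrent attempts of the same stage.
final case class TaskIdentifier(stageAttempt: Int, taskAttempt: Int)

object KeySketch extends App {
  val a = TaskIdentifier(stageAttempt = 0, taskAttempt = 0)
  val b = TaskIdentifier(stageAttempt = 1, taskAttempt = 0) // same attempt number
  assert(a != b) // distinguishable only because the stage attempt is tracked

  // New key: scheduler-assigned task IDs never repeat within an application,
  // so a bare Long identifies a committer with no extra bookkeeping.
  val taskIdOfFirstAttempt = 1234L
  val taskIdOfSecondAttempt = 5678L
  assert(taskIdOfFirstAttempt != taskIdOfSecondAttempt)
}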

File tree

6 files changed: +65 -91 lines changed


core/src/main/scala/org/apache/spark/executor/CommitDeniedException.scala

Lines changed: 2 additions & 2 deletions

@@ -26,8 +26,8 @@ private[spark] class CommitDeniedException(
     msg: String,
     jobID: Int,
     splitID: Int,
-    attemptNumber: Int)
+    taskId: Long)
   extends Exception(msg) {

-  def toTaskCommitDeniedReason: TaskCommitDenied = TaskCommitDenied(jobID, splitID, attemptNumber)
+  def toTaskCommitDeniedReason: TaskCommitDenied = TaskCommitDenied(jobID, splitID, taskId.toInt)
 }
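
A detail worth flagging in this file: TaskCommitDenied still carries an Int, so the Long task ID is narrowed with toInt. On the JVM that conversion keeps only the low 32 bits, as this small illustration of the language semantics (not code from the commit) shows:

// Long-to-Int narrowing keeps only the low 32 bits.
val bigTaskId: Long = (1L << 32) + 5L // larger than Int.MaxValue
assert(bigTaskId.toInt == 5)          // the high bits are silently dropped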

core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala

Lines changed: 3 additions & 3 deletions

@@ -70,8 +70,8 @@ object SparkHadoopMapRedUtil extends Logging {
     if (shouldCoordinateWithDriver) {
       val outputCommitCoordinator = SparkEnv.get.outputCommitCoordinator
       val ctx = TaskContext.get()
-      val canCommit = outputCommitCoordinator.canCommit(ctx.stageId(), ctx.stageAttemptNumber(),
-        splitId, ctx.attemptNumber())
+      val canCommit = outputCommitCoordinator.canCommit(ctx.stageId(), splitId,
+        ctx.taskAttemptId())

       if (canCommit) {
         performCommit()
@@ -81,7 +81,7 @@ object SparkHadoopMapRedUtil extends Logging {
         logInfo(message)
         // We need to abort the task so that the driver can reschedule new attempts, if necessary
         committer.abortTask(mrTaskContext)
-        throw new CommitDeniedException(message, ctx.stageId(), splitId, ctx.attemptNumber())
+        throw new CommitDeniedException(message, ctx.stageId(), splitId, ctx.taskAttemptId())
       }
     } else {
       // Speculation is disabled or a user has chosen to manually bypass the commit coordination
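
The two TaskContext accessors swapped here behave very differently: attemptNumber() is an Int that restarts at 0 each time a partition is retried, so it can repeat across concurrent stage attempts, while taskAttemptId() is a Long that is unique across all task attempts in the same SparkContext. A short sketch of the distinction (it only runs inside a Spark task; on the driver TaskContext.get() returns null):

import org.apache.spark.TaskContext

val ctx = TaskContext.get()
val attemptNumber: Int = ctx.attemptNumber() // 0, 1, 2, ... per retry of this partition
val uniqueTaskId: Long = ctx.taskAttemptId() // never repeats within the application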

core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala

Lines changed: 1 addition & 2 deletions

@@ -1171,9 +1171,8 @@ class DAGScheduler(

     outputCommitCoordinator.taskCompleted(
       stageId,
-      task.stageAttemptId,
       task.partitionId,
-      event.taskInfo.attemptNumber, // this is a task attempt number
+      event.taskInfo.taskId,
       event.reason)

     if (!stageIdToStage.contains(task.stageId)) {

core/src/main/scala/org/apache/spark/scheduler/OutputCommitCoordinator.scala

Lines changed: 24 additions & 40 deletions

@@ -29,9 +29,8 @@ private sealed trait OutputCommitCoordinationMessage extends Serializable
 private case object StopCoordinator extends OutputCommitCoordinationMessage
 private case class AskPermissionToCommitOutput(
     stage: Int,
-    stageAttempt: Int,
     partition: Int,
-    attemptNumber: Int)
+    taskId: Long)

 /**
  * Authority that decides whether tasks can commit output to HDFS. Uses a "first committer wins"
@@ -49,15 +48,11 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
   // Initialized by SparkEnv
   var coordinatorRef: Option[RpcEndpointRef] = None

-  // Class used to identify a committer. The task ID for a committer is implicitly defined by
-  // the partition being processed, but the coordinator needs to keep track of both the stage
-  // attempt and the task attempt, because in some situations the same task may be running
-  // concurrently in two different attempts of the same stage.
-  private case class TaskIdentifier(stageAttempt: Int, taskAttempt: Int)
+  private val NO_AUTHORIZED_COMMITTER: Long = -1L

   private case class StageState(numPartitions: Int) {
-    val authorizedCommitters = Array.fill[TaskIdentifier](numPartitions)(null)
-    val failures = mutable.Map[Int, mutable.Set[TaskIdentifier]]()
+    val authorizedCommitters = Array.fill[Long](numPartitions)(NO_AUTHORIZED_COMMITTER)
+    val failures = mutable.Map[Int, mutable.Set[Long]]()
   }

   /**
@@ -88,16 +83,14 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
    *
    * @param stage the stage number
    * @param partition the partition number
-   * @param attemptNumber how many times this task has been attempted
-   *                      (see [[TaskContext.attemptNumber()]])
+   * @param taskId the task asking for permission to commit
    * @return true if this task is authorized to commit, false otherwise
    */
   def canCommit(
       stage: Int,
-      stageAttempt: Int,
       partition: Int,
-      attemptNumber: Int): Boolean = {
-    val msg = AskPermissionToCommitOutput(stage, stageAttempt, partition, attemptNumber)
+      taskId: Long): Boolean = {
+    val msg = AskPermissionToCommitOutput(stage, partition, taskId)
     coordinatorRef match {
       case Some(endpointRef) =>
         ThreadUtils.awaitResult(endpointRef.ask[Boolean](msg),
@@ -136,9 +129,8 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
   // Called by DAGScheduler
   private[scheduler] def taskCompleted(
       stage: Int,
-      stageAttempt: Int,
       partition: Int,
-      attemptNumber: Int,
+      taskId: Long,
       reason: TaskEndReason): Unit = synchronized {
     val stageState = stageStates.getOrElse(stage, {
       logDebug(s"Ignoring task completion for completed stage")
@@ -148,16 +140,14 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
       case Success =>
         // The task output has been committed successfully
       case _: TaskCommitDenied =>
-        logInfo(s"Task was denied committing, stage: $stage.$stageAttempt, " +
-          s"partition: $partition, attempt: $attemptNumber")
+        logInfo(s"Task was denied committing, stage: $stage, partition: $partition, task: $taskId")
       case _ =>
         // Mark the attempt as failed to blacklist from future commit protocol
-        val taskId = TaskIdentifier(stageAttempt, attemptNumber)
         stageState.failures.getOrElseUpdate(partition, mutable.Set()) += taskId
         if (stageState.authorizedCommitters(partition) == taskId) {
-          logDebug(s"Authorized committer (attemptNumber=$attemptNumber, stage=$stage, " +
+          logDebug(s"Authorized committer (taskId=$taskId, stage=$stage, " +
             s"partition=$partition) failed; clearing lock")
-          stageState.authorizedCommitters(partition) = null
+          stageState.authorizedCommitters(partition) = NO_AUTHORIZED_COMMITTER
         }
     }
   }
@@ -173,40 +163,35 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
   // Marked private[scheduler] instead of private so this can be mocked in tests
   private[scheduler] def handleAskPermissionToCommit(
       stage: Int,
-      stageAttempt: Int,
       partition: Int,
-      attemptNumber: Int): Boolean = synchronized {
+      taskId: Long): Boolean = synchronized {
     stageStates.get(stage) match {
-      case Some(state) if attemptFailed(state, stageAttempt, partition, attemptNumber) =>
-        logInfo(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " +
-          s"task attempt $attemptNumber already marked as failed.")
+      case Some(state) if attemptFailed(state, partition, taskId) =>
+        logInfo(s"Commit denied for stage=$stage, partition=$partition: " +
+          s"task $taskId already marked as failed.")
         false
       case Some(state) =>
         val existing = state.authorizedCommitters(partition)
-        if (existing == null) {
-          logDebug(s"Commit allowed for stage=$stage.$stageAttempt, partition=$partition, " +
-            s"task attempt $attemptNumber")
-          state.authorizedCommitters(partition) = TaskIdentifier(stageAttempt, attemptNumber)
+        if (existing == NO_AUTHORIZED_COMMITTER) {
+          logDebug(s"Commit allowed for stage=$stage, partition=$partition, task $taskId")
+          state.authorizedCommitters(partition) = taskId
           true
         } else {
-          logDebug(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " +
+          logDebug(s"Commit denied for stage=$stage, partition=$partition: " +
            s"already committed by $existing")
          false
        }
      case None =>
-        logDebug(s"Commit denied for stage=$stage.$stageAttempt, partition=$partition: " +
-          "stage already marked as completed.")
+        logDebug(s"Commit denied for stage=$stage, partition=$partition: stage already completed.")
        false
    }
  }

  private def attemptFailed(
      stageState: StageState,
-      stageAttempt: Int,
      partition: Int,
-      attempt: Int): Boolean = synchronized {
-    val failInfo = TaskIdentifier(stageAttempt, attempt)
-    stageState.failures.get(partition).exists(_.contains(failInfo))
+      taskId: Long): Boolean = synchronized {
+    stageState.failures.get(partition).exists(_.contains(taskId))
  }
 }

@@ -226,10 +211,9 @@ private[spark] object OutputCommitCoordinator {
     }

     override def receiveAndReply(context: RpcCallContext): PartialFunction[Any, Unit] = {
-      case AskPermissionToCommitOutput(stage, stageAttempt, partition, attemptNumber) =>
+      case AskPermissionToCommitOutput(stage, partition, taskId) =>
         context.reply(
-          outputCommitCoordinator.handleAskPermissionToCommit(stage, stageAttempt, partition,
-            attemptNumber))
+          outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, taskId))
     }
   }
 }
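
Condensed into a standalone sketch, the coordinator's bookkeeping after this change reduces to one Long slot per partition plus a per-partition set of failed task IDs. This is a hypothetical simplification for illustration, not the Spark class itself:

import scala.collection.mutable

// Simplified first-committer-wins bookkeeping (hypothetical class).
class FirstCommitterWins(numPartitions: Int) {
  private val NO_COMMITTER = -1L // sentinel: no authorized committer yet
  private val authorized = Array.fill[Long](numPartitions)(NO_COMMITTER)
  private val failures = mutable.Map[Int, mutable.Set[Long]]()

  // The first task to ask for an unclaimed partition wins the commit lock;
  // tasks that already failed are never authorized.
  def canCommit(partition: Int, taskId: Long): Boolean = synchronized {
    if (failures.get(partition).exists(_.contains(taskId))) {
      false // this attempt already failed
    } else if (authorized(partition) == NO_COMMITTER) {
      authorized(partition) = taskId // take the lock
      true
    } else {
      false // another task holds the lock
    }
  }

  // On failure, blacklist the task and release the lock if it was the holder.
  def taskFailed(partition: Int, taskId: Long): Unit = synchronized {
    failures.getOrElseUpdate(partition, mutable.Set()) += taskId
    if (authorized(partition) == taskId) {
      authorized(partition) = NO_COMMITTER
    }
  }
}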

core/src/test/scala/org/apache/spark/scheduler/OutputCommitCoordinatorSuite.scala

Lines changed: 26 additions & 34 deletions

@@ -154,7 +154,7 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter {
   test("Job should not complete if all commits are denied") {
     // Create a mock OutputCommitCoordinator that denies all attempts to commit
     doReturn(false).when(outputCommitCoordinator).handleAskPermissionToCommit(
-      Matchers.any(), Matchers.any(), Matchers.any(), Matchers.any())
+      Matchers.any(), Matchers.any(), Matchers.any())
     val rdd: RDD[Int] = sc.parallelize(Seq(1), 1)
     def resultHandler(x: Int, y: Unit): Unit = {}
     val futureAction: SimpleFutureAction[Unit] = sc.submitJob[Int, Unit, Unit](rdd,
@@ -170,73 +170,65 @@ class OutputCommitCoordinatorSuite extends SparkFunSuite with BeforeAndAfter {

   test("Only authorized committer failures can clear the authorized committer lock (SPARK-6614)") {
     val stage: Int = 1
-    val stageAttempt: Int = 1
     val partition: Int = 2
-    val authorizedCommitter: Int = 3
-    val nonAuthorizedCommitter: Int = 100
+    val authorizedCommitter: Long = 3
+    val nonAuthorizedCommitter: Long = 100
     outputCommitCoordinator.stageStart(stage, maxPartitionId = 2)

-    assert(outputCommitCoordinator.canCommit(stage, stageAttempt, partition, authorizedCommitter))
-    assert(!outputCommitCoordinator.canCommit(stage, stageAttempt, partition,
-      nonAuthorizedCommitter))
+    assert(outputCommitCoordinator.canCommit(stage, partition, authorizedCommitter))
+    assert(!outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter))
     // The non-authorized committer fails
-    outputCommitCoordinator.taskCompleted(stage, stageAttempt, partition,
-      attemptNumber = nonAuthorizedCommitter, reason = TaskKilled("test"))
+    outputCommitCoordinator.taskCompleted(stage, partition, nonAuthorizedCommitter,
+      reason = TaskKilled("test"))
     // New tasks should still not be able to commit because the authorized committer has not failed
-    assert(!outputCommitCoordinator.canCommit(stage, stageAttempt, partition,
-      nonAuthorizedCommitter + 1))
+    assert(!outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 1))
     // The authorized committer now fails, clearing the lock
-    outputCommitCoordinator.taskCompleted(stage, stageAttempt, partition,
-      attemptNumber = authorizedCommitter, reason = TaskKilled("test"))
+    outputCommitCoordinator.taskCompleted(stage, partition, authorizedCommitter,
+      reason = TaskKilled("test"))
     // A new task should now be allowed to become the authorized committer
-    assert(outputCommitCoordinator.canCommit(stage, stageAttempt, partition,
-      nonAuthorizedCommitter + 2))
+    assert(outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 2))
     // There can only be one authorized committer
-    assert(!outputCommitCoordinator.canCommit(stage, stageAttempt, partition,
-      nonAuthorizedCommitter + 3))
+    assert(!outputCommitCoordinator.canCommit(stage, partition, nonAuthorizedCommitter + 3))
   }

   test("SPARK-19631: Do not allow failed attempts to be authorized for committing") {
     val stage: Int = 1
-    val stageAttempt: Int = 1
     val partition: Int = 1
-    val failedAttempt: Int = 0
+    val failedAttempt: Long = 0L
     outputCommitCoordinator.stageStart(stage, maxPartitionId = 1)
-    outputCommitCoordinator.taskCompleted(stage, stageAttempt, partition,
-      attemptNumber = failedAttempt,
+    outputCommitCoordinator.taskCompleted(stage, partition, failedAttempt,
       reason = ExecutorLostFailure("0", exitCausedByApp = true, None))
-    assert(!outputCommitCoordinator.canCommit(stage, stageAttempt, partition, failedAttempt))
-    assert(outputCommitCoordinator.canCommit(stage, stageAttempt, partition, failedAttempt + 1))
+    assert(!outputCommitCoordinator.canCommit(stage, partition, failedAttempt))
+    assert(outputCommitCoordinator.canCommit(stage, partition, failedAttempt + 1))
   }

   test("SPARK-24589: Differentiate tasks from different stage attempts") {
     var stage = 1
-    val taskAttempt = 1
     val partition = 1

     outputCommitCoordinator.stageStart(stage, maxPartitionId = 1)
-    assert(outputCommitCoordinator.canCommit(stage, 1, partition, taskAttempt))
-    assert(!outputCommitCoordinator.canCommit(stage, 2, partition, taskAttempt))
+    assert(outputCommitCoordinator.canCommit(stage, partition, 1L))
+    assert(!outputCommitCoordinator.canCommit(stage, partition, 2L))

     // Fail the task in the first attempt, the task in the second attempt should succeed.
     stage += 1
     outputCommitCoordinator.stageStart(stage, maxPartitionId = 1)
-    outputCommitCoordinator.taskCompleted(stage, 1, partition, taskAttempt,
+    outputCommitCoordinator.taskCompleted(stage, partition, 1L,
       ExecutorLostFailure("0", exitCausedByApp = true, None))
-    assert(!outputCommitCoordinator.canCommit(stage, 1, partition, taskAttempt))
-    assert(outputCommitCoordinator.canCommit(stage, 2, partition, taskAttempt))
+    assert(!outputCommitCoordinator.canCommit(stage, partition, 1L))
+    assert(outputCommitCoordinator.canCommit(stage, partition, 2L))

     // Commit the 1st attempt, fail the 2nd attempt, make sure 3rd attempt cannot commit,
     // then fail the 1st attempt and make sure the 4th one can commit again.
     stage += 1
     outputCommitCoordinator.stageStart(stage, maxPartitionId = 1)
-    assert(outputCommitCoordinator.canCommit(stage, 1, partition, taskAttempt))
-    outputCommitCoordinator.taskCompleted(stage, 2, partition, taskAttempt,
+    assert(outputCommitCoordinator.canCommit(stage, partition, 1L))
+    outputCommitCoordinator.taskCompleted(stage, partition, 2,
       ExecutorLostFailure("0", exitCausedByApp = true, None))
-    assert(!outputCommitCoordinator.canCommit(stage, 3, partition, taskAttempt))
-    outputCommitCoordinator.taskCompleted(stage, 1, partition, taskAttempt,
+    assert(!outputCommitCoordinator.canCommit(stage, partition, 3L))
+    outputCommitCoordinator.taskCompleted(stage, partition, 1L,
       ExecutorLostFailure("0", exitCausedByApp = true, None))
-    assert(outputCommitCoordinator.canCommit(stage, 4, partition, taskAttempt))
+    assert(outputCommitCoordinator.canCommit(stage, partition, 4L))
   }

   test("SPARK-24589: Make sure stage state is cleaned up") {

sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2.scala

Lines changed: 9 additions & 10 deletions

@@ -109,11 +109,11 @@ object DataWritingSparkTask extends Logging {
       iter: Iterator[InternalRow],
       useCommitCoordinator: Boolean): WriterCommitMessage = {
     val stageId = context.stageId()
-    val stageAttempt = context.stageAttemptNumber()
     val partId = context.partitionId()
     val attemptId = context.attemptNumber()
     val epochId = Option(context.getLocalProperty(MicroBatchExecution.BATCH_ID_KEY)).getOrElse("0")
     val dataWriter = writeTask.createDataWriter(partId, attemptId, epochId.toLong)
+    val taskId = context.taskAttemptId()

     // write the data and commit this writer.
     Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
@@ -123,33 +123,32 @@ object DataWritingSparkTask extends Logging {

       val msg = if (useCommitCoordinator) {
         val coordinator = SparkEnv.get.outputCommitCoordinator
-        val commitAuthorized = coordinator.canCommit(stageId, stageAttempt, partId, attemptId)
+        val commitAuthorized = coordinator.canCommit(stageId, partId, taskId)
         if (commitAuthorized) {
-          logInfo(s"Writer for stage $stageId / $stageAttempt, " +
-            s"task $partId.$attemptId is authorized to commit.")
+          logInfo(s"Writer for stage $stageId, part $partId, task $taskId is authorized to commit.")
           dataWriter.commit()
         } else {
-          val message = s"Stage $stageId / $stageAttempt, " +
-            s"task $partId.$attemptId: driver did not authorize commit"
+          val message = s"Stage $stageId, part $partId, task $taskId: " +
+            "driver did not authorize commit"
           logInfo(message)
           // throwing CommitDeniedException will trigger the catch block for abort
-          throw new CommitDeniedException(message, stageId, partId, attemptId)
+          throw new CommitDeniedException(message, stageId, partId, taskId)
         }

       } else {
         logInfo(s"Writer for partition ${context.partitionId()} is committing.")
         dataWriter.commit()
       }

-      logInfo(s"Writer for stage $stageId, task $partId.$attemptId committed.")
+      logInfo(s"Writer for stage $stageId, part $partId, task $taskId committed.")

       msg

     })(catchBlock = {
       // If there is an error, abort this writer
-      logError(s"Writer for stage $stageId, task $partId.$attemptId is aborting.")
+      logError(s"Writer for stage $stageId, part $partId, task $taskId is aborting.")
       dataWriter.abort()
-      logError(s"Writer for stage $stageId, task $partId.$attemptId aborted.")
+      logError(s"Writer for stage $stageId, part $partId, task $taskId aborted.")
     })
   }
 }
