@@ -27,7 +27,11 @@ import org.apache.spark.util.{RpcUtils, ThreadUtils}
27
27
private sealed trait OutputCommitCoordinationMessage extends Serializable
28
28
29
29
private case object StopCoordinator extends OutputCommitCoordinationMessage
30
- private case class AskPermissionToCommitOutput (stage : Int , partition : Int , attemptNumber : Int )
30
+ private case class AskPermissionToCommitOutput (
31
+ stage : Int ,
32
+ stageAttempt : Int ,
33
+ partition : Int ,
34
+ attemptNumber : Int )
31
35
32
36
/**
33
37
* Authority that decides whether tasks can commit output to HDFS. Uses a "first committer wins"
@@ -45,13 +49,15 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
45
49
// Initialized by SparkEnv
46
50
var coordinatorRef : Option [RpcEndpointRef ] = None
47
51
48
- private type StageId = Int
49
- private type PartitionId = Int
50
- private type TaskAttemptNumber = Int
51
- private val NO_AUTHORIZED_COMMITTER : TaskAttemptNumber = - 1
52
+ // Class used to identify a committer. The task ID for a committer is implicitly defined by
53
+ // the partition being processed, but the coordinator needs to keep track of both the stage
54
+ // attempt and the task attempt, because in some situations the same task may be running
55
+ // concurrently in two different attempts of the same stage.
56
+ private case class TaskIdentifier (stageAttempt : Int , taskAttempt : Int )
57
+
52
58
private case class StageState (numPartitions : Int ) {
53
- val authorizedCommitters = Array .fill[TaskAttemptNumber ](numPartitions)(NO_AUTHORIZED_COMMITTER )
54
- val failures = mutable.Map [PartitionId , mutable.Set [TaskAttemptNumber ]]()
59
+ val authorizedCommitters = Array .fill[TaskIdentifier ](numPartitions)(null )
60
+ val failures = mutable.Map [Int , mutable.Set [TaskIdentifier ]]()
55
61
}
56
62
57
63
/**
@@ -64,7 +70,7 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
64
70
*
65
71
* Access to this map should be guarded by synchronizing on the OutputCommitCoordinator instance.
66
72
*/
67
- private val stageStates = mutable.Map [StageId , StageState ]()
73
+ private val stageStates = mutable.Map [Int , StageState ]()
68
74
69
75
/**
70
76
* Returns whether the OutputCommitCoordinator's internal data structures are all empty.
@@ -87,10 +93,11 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
87
93
* @return true if this task is authorized to commit, false otherwise
88
94
*/
89
95
def canCommit (
90
- stage : StageId ,
91
- partition : PartitionId ,
92
- attemptNumber : TaskAttemptNumber ): Boolean = {
93
- val msg = AskPermissionToCommitOutput (stage, partition, attemptNumber)
96
+ stage : Int ,
97
+ stageAttempt : Int ,
98
+ partition : Int ,
99
+ attemptNumber : Int ): Boolean = {
100
+ val msg = AskPermissionToCommitOutput (stage, stageAttempt, partition, attemptNumber)
94
101
coordinatorRef match {
95
102
case Some (endpointRef) =>
96
103
ThreadUtils .awaitResult(endpointRef.ask[Boolean ](msg),
@@ -103,26 +110,35 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
103
110
}
104
111
105
112
/**
106
- * Called by the DAGScheduler when a stage starts.
113
+ * Called by the DAGScheduler when a stage starts. Initializes the stage's state if it hasn't
114
+ * yet been initialized.
107
115
*
108
116
* @param stage the stage id.
109
117
* @param maxPartitionId the maximum partition id that could appear in this stage's tasks (i.e.
110
118
* the maximum possible value of `context.partitionId`).
111
119
*/
112
- private [scheduler] def stageStart (stage : StageId , maxPartitionId : Int ): Unit = synchronized {
113
- stageStates(stage) = new StageState (maxPartitionId + 1 )
120
+ private [scheduler] def stageStart (stage : Int , maxPartitionId : Int ): Unit = synchronized {
121
+ stageStates.get(stage) match {
122
+ case Some (state) =>
123
+ require(state.authorizedCommitters.length == maxPartitionId + 1 )
124
+ logInfo(s " Reusing state from previous attempt of stage $stage. " )
125
+
126
+ case _ =>
127
+ stageStates(stage) = new StageState (maxPartitionId + 1 )
128
+ }
114
129
}
115
130
116
131
// Called by DAGScheduler
117
- private [scheduler] def stageEnd (stage : StageId ): Unit = synchronized {
132
+ private [scheduler] def stageEnd (stage : Int ): Unit = synchronized {
118
133
stageStates.remove(stage)
119
134
}
120
135
121
136
// Called by DAGScheduler
122
137
private [scheduler] def taskCompleted (
123
- stage : StageId ,
124
- partition : PartitionId ,
125
- attemptNumber : TaskAttemptNumber ,
138
+ stage : Int ,
139
+ stageAttempt : Int ,
140
+ partition : Int ,
141
+ attemptNumber : Int ,
126
142
reason : TaskEndReason ): Unit = synchronized {
127
143
val stageState = stageStates.getOrElse(stage, {
128
144
logDebug(s " Ignoring task completion for completed stage " )
@@ -131,16 +147,17 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
131
147
reason match {
132
148
case Success =>
133
149
// The task output has been committed successfully
134
- case denied : TaskCommitDenied =>
135
- logInfo(s " Task was denied committing, stage: $stage, partition: $partition , " +
136
- s " attempt: $attemptNumber" )
137
- case otherReason =>
150
+ case _ : TaskCommitDenied =>
151
+ logInfo(s " Task was denied committing, stage: $stage. $stageAttempt , " +
152
+ s " partition: $partition , attempt: $attemptNumber" )
153
+ case _ =>
138
154
// Mark the attempt as failed to blacklist from future commit protocol
139
- stageState.failures.getOrElseUpdate(partition, mutable.Set ()) += attemptNumber
140
- if (stageState.authorizedCommitters(partition) == attemptNumber) {
155
+ val taskId = TaskIdentifier (stageAttempt, attemptNumber)
156
+ stageState.failures.getOrElseUpdate(partition, mutable.Set ()) += taskId
157
+ if (stageState.authorizedCommitters(partition) == taskId) {
141
158
logDebug(s " Authorized committer (attemptNumber= $attemptNumber, stage= $stage, " +
142
159
s " partition= $partition) failed; clearing lock " )
143
- stageState.authorizedCommitters(partition) = NO_AUTHORIZED_COMMITTER
160
+ stageState.authorizedCommitters(partition) = null
144
161
}
145
162
}
146
163
}
@@ -155,47 +172,41 @@ private[spark] class OutputCommitCoordinator(conf: SparkConf, isDriver: Boolean)
155
172
156
173
// Marked private[scheduler] instead of private so this can be mocked in tests
157
174
private [scheduler] def handleAskPermissionToCommit (
158
- stage : StageId ,
159
- partition : PartitionId ,
160
- attemptNumber : TaskAttemptNumber ): Boolean = synchronized {
175
+ stage : Int ,
176
+ stageAttempt : Int ,
177
+ partition : Int ,
178
+ attemptNumber : Int ): Boolean = synchronized {
161
179
stageStates.get(stage) match {
162
- case Some (state) if attemptFailed(state, partition, attemptNumber) =>
163
- logInfo(s " Denying attemptNumber= $attemptNumber to commit for stage= $stage, " +
164
- s " partition= $partition as task attempt $attemptNumber has already failed. " )
180
+ case Some (state) if attemptFailed(state, stageAttempt, partition, attemptNumber) =>
181
+ logInfo(s " Commit denied for stage= $stage. $stageAttempt , partition= $partition : " +
182
+ s " task attempt $attemptNumber already marked as failed. " )
165
183
false
166
184
case Some (state) =>
167
- state.authorizedCommitters(partition) match {
168
- case NO_AUTHORIZED_COMMITTER =>
169
- logDebug(s " Authorizing attemptNumber= $attemptNumber to commit for stage= $stage, " +
170
- s " partition= $partition" )
171
- state.authorizedCommitters(partition) = attemptNumber
172
- true
173
- case existingCommitter =>
174
- // Coordinator should be idempotent when receiving AskPermissionToCommit.
175
- if (existingCommitter == attemptNumber) {
176
- logWarning(s " Authorizing duplicate request to commit for " +
177
- s " attemptNumber= $attemptNumber to commit for stage= $stage, " +
178
- s " partition= $partition; existingCommitter = $existingCommitter. " +
179
- s " This can indicate dropped network traffic. " )
180
- true
181
- } else {
182
- logDebug(s " Denying attemptNumber= $attemptNumber to commit for stage= $stage, " +
183
- s " partition= $partition; existingCommitter = $existingCommitter" )
184
- false
185
- }
185
+ val existing = state.authorizedCommitters(partition)
186
+ if (existing == null ) {
187
+ logDebug(s " Commit allowed for stage= $stage. $stageAttempt, partition= $partition, " +
188
+ s " task attempt $attemptNumber" )
189
+ state.authorizedCommitters(partition) = TaskIdentifier (stageAttempt, attemptNumber)
190
+ true
191
+ } else {
192
+ logDebug(s " Commit denied for stage= $stage. $stageAttempt, partition= $partition: " +
193
+ s " already committed by $existing" )
194
+ false
186
195
}
187
196
case None =>
188
- logDebug(s " Stage $stage has completed, so not allowing " +
189
- s " attempt number $attemptNumber of partition $partition to commit " )
197
+ logDebug(s " Commit denied for stage= $stage . $stageAttempt , partition= $partition : " +
198
+ " stage already marked as completed. " )
190
199
false
191
200
}
192
201
}
193
202
194
203
private def attemptFailed (
195
204
stageState : StageState ,
196
- partition : PartitionId ,
197
- attempt : TaskAttemptNumber ): Boolean = synchronized {
198
- stageState.failures.get(partition).exists(_.contains(attempt))
205
+ stageAttempt : Int ,
206
+ partition : Int ,
207
+ attempt : Int ): Boolean = synchronized {
208
+ val failInfo = TaskIdentifier (stageAttempt, attempt)
209
+ stageState.failures.get(partition).exists(_.contains(failInfo))
199
210
}
200
211
}
201
212
@@ -215,9 +226,10 @@ private[spark] object OutputCommitCoordinator {
215
226
}
216
227
217
228
override def receiveAndReply (context : RpcCallContext ): PartialFunction [Any , Unit ] = {
218
- case AskPermissionToCommitOutput (stage, partition, attemptNumber) =>
229
+ case AskPermissionToCommitOutput (stage, stageAttempt, partition, attemptNumber) =>
219
230
context.reply(
220
- outputCommitCoordinator.handleAskPermissionToCommit(stage, partition, attemptNumber))
231
+ outputCommitCoordinator.handleAskPermissionToCommit(stage, stageAttempt, partition,
232
+ attemptNumber))
221
233
}
222
234
}
223
235
}
0 commit comments