Commit b3325ad

colinmjj authored and GitHub Enterprise committed
[HADP-55679] Fix NPE problem caused by incorrect taskId (apache#631)
1 parent 62a7a28 commit b3325ad

File tree: 3 files changed (+18, -17 lines)
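The fix, as the diff below suggests, swaps the key used for TaskSetResultStore from a task's index within its TaskSet to the task's RDD partitionId. The two coincide on a stage's first attempt but can diverge on a retried TaskSet that re-runs only a subset of partitions, so keying by index could address the wrong store slot. A minimal, self-contained sketch of that divergence (TaskInfoSketch and the sample ids are hypothetical stand-ins, not Spark code):

// Hypothetical stand-in for the relevant fields of
// org.apache.spark.scheduler.TaskInfo.
final case class TaskInfoSketch(taskId: Long, index: Int, partitionId: Int)

object IndexVsPartitionId {
  def main(args: Array[String]): Unit = {
    // First attempt of a 4-partition stage: index == partitionId for every task.
    val attempt0 = (0 until 4).map(p => TaskInfoSketch(p.toLong, index = p, partitionId = p))

    // A retry that re-runs only partitions 2 and 3: indices restart at 0,
    // so index != partitionId for every task in the retried TaskSet.
    val attempt1 = Seq(
      TaskInfoSketch(taskId = 4L, index = 0, partitionId = 2),
      TaskInfoSketch(taskId = 5L, index = 1, partitionId = 3))

    (attempt0 ++ attempt1).foreach { t =>
      println(s"tid=${t.taskId} index=${t.index} partitionId=${t.partitionId}")
    }
  }
}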


core/src/main/scala/org/apache/spark/scheduler/SpillableTaskResultGetter.scala

Lines changed: 8 additions & 7 deletions
@@ -196,7 +196,7 @@ private[spark] class SpillableTaskResultGetter(sparkEnv: SparkEnv, scheduler: Ta
       taskSetManager: TaskSetManager,
       taskDataSeq: Seq[(Long, ByteBuffer)]): Unit = {
     val tids = taskDataSeq.map(_._1)
-    val taskId2TaskIdx = scheduler.removeMultiRunningTasks(taskSetManager, tids)
+    val taskId2TaskPartitionId = scheduler.removeMultiRunningTasks(taskSetManager, tids)
 
     // Killed tasks due to result size exceeds
     val killTaskIds = new ArrayBuffer[Long]
@@ -212,7 +212,7 @@ private[spark] class SpillableTaskResultGetter(sparkEnv: SparkEnv, scheduler: Ta
       val tid = t._1
       val serializedData = t._2
       try {
-        val taskIdx = taskId2TaskIdx.get(tid).get
+        val taskPartitionId = taskId2TaskPartitionId.get(tid).get
         serializer.get().deserialize[TaskResult[_]](serializedData) match {
           case directResult: DirectTaskResult[_] =>
             val start = System.currentTimeMillis()
@@ -244,7 +244,7 @@ private[spark] class SpillableTaskResultGetter(sparkEnv: SparkEnv, scheduler: Ta
            if (resultValue == null) {
              logWarning(s"TID ${tid} deserializeDirectResult is null")
              // There is possible lock contention to the TaskSetResultStore
-             store.save(taskIdx, null, 0, taskSetManager.taskSet.id)
+             store.save(taskPartitionId, null, 0, taskSetManager.taskSet.id)
              if (store.isFinished) {
                resultStoreMap.remove(taskSetManager.taskSet.stageId.toString)
              }
@@ -257,7 +257,7 @@ private[spark] class SpillableTaskResultGetter(sparkEnv: SparkEnv, scheduler: Ta
              getLargeResultExecutor.execute(
                new SpillDirectResultTask(store,
                  tid,
-                 taskIdx,
+                 taskPartitionId,
                  resultValue,
                  resultSize,
                  directResult,
@@ -266,7 +266,7 @@ private[spark] class SpillableTaskResultGetter(sparkEnv: SparkEnv, scheduler: Ta
            } else {
              // There is possible lock contention to the TaskSetResultStore
              val (returnResult, spilledSize) = store.save(
-               taskIdx, resultValue, resultSize, taskSetManager.taskSet.id)
+               taskPartitionId, resultValue, resultSize, taskSetManager.taskSet.id)
              if (spilledSize > 0) {
                taskSetManager.totalResultInMemorySize.addAndGet(-spilledSize)
              }
@@ -314,7 +314,8 @@ private[spark] class SpillableTaskResultGetter(sparkEnv: SparkEnv, scheduler: Ta
            // There is possible lock contention to the TaskSetResultStore
            if (store.maybeSpill(size)) {
              getLargeResultExecutor.execute(
-               new FetchLargeResultTask(tid, taskIdx, blockId, size, taskSetManager, store))
+               new FetchLargeResultTask(tid, taskPartitionId, blockId, size,
+                 taskSetManager, store))
            } else {
              successInDirectTaskIds += tid
              val result =
@@ -324,7 +325,7 @@ private[spark] class SpillableTaskResultGetter(sparkEnv: SparkEnv, scheduler: Ta
                failureTaskIds += tid
              } else {
                val (returnResult, spilledSize) = store.save(
-                 taskIdx, result.value(), size, taskSetManager.taskSet.id)
+                 taskPartitionId, result.value(), size, taskSetManager.taskSet.id)
                if (spilledSize > 0) {
                  taskSetManager.totalResultInMemorySize.addAndGet(-spilledSize)
                }
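To make the failure mode concrete, a runnable toy under the same assumption (all names hypothetical, not the patched code): a store with one slot per partition is fed the task's index on a retried TaskSet, the slot for the real partition stays null, and the read back by partitionId throws, matching the NPE the commit title describes.

object WrongKeyNpeSketch {
  def main(args: Array[String]): Unit = {
    // Toy store: one result slot per partition of a 4-partition stage.
    val results = new Array[AnyRef](4)

    // Retried TaskSet re-running partitions 2 and 3 (indices restart at 0).
    val retried = Seq((0, 2), (1, 3)) // (index, partitionId)

    // Buggy keying: save each result under index instead of partitionId.
    retried.foreach { case (index, _) => results(index) = "result" }

    // Reading back by partitionId finds null and throws NullPointerException.
    retried.foreach { case (_, partitionId) =>
      println(results(partitionId).toString)
    }
  }
}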

core/src/main/scala/org/apache/spark/scheduler/TaskSchedulerImpl.scala

Lines changed: 2 additions & 2 deletions
@@ -1036,9 +1036,9 @@ private[spark] class TaskSchedulerImpl(
       taskSetManager: TaskSetManager,
       tids: Seq[Long]): Map[Long, Int] = synchronized {
     tids.map { tid =>
-      val taskIdx = taskSetManager.taskInfos(tid).index
+      val taskPartitionId = taskSetManager.taskInfos(tid).partitionId
       taskSetManager.removeRunningTask(tid)
-      (tid, taskIdx)
+      (tid, taskPartitionId)
     }.toMap
   }
 
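A minimal sketch of the mapping the patched method now returns, taskId -> partitionId instead of taskId -> index (TaskInfoSketch is again a hypothetical stand-in; only the Map[Long, Int] shape mirrors the diff):

object RemoveMultiRunningTasksSketch {
  final case class TaskInfoSketch(index: Int, partitionId: Int)

  // Mirrors the patched keying: each tid maps to its partitionId.
  def taskIdToPartitionId(
      taskInfos: Map[Long, TaskInfoSketch],
      tids: Seq[Long]): Map[Long, Int] =
    tids.map { tid => (tid, taskInfos(tid).partitionId) }.toMap

  def main(args: Array[String]): Unit = {
    val infos = Map(
      4L -> TaskInfoSketch(index = 0, partitionId = 2),
      5L -> TaskInfoSketch(index = 1, partitionId = 3))
    // Prints Map(4 -> 2, 5 -> 3): downstream saves land in the right slots.
    println(taskIdToPartitionId(infos, Seq(4L, 5L)))
  }
}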

core/src/main/scala/org/apache/spark/scheduler/TaskSetResultStore.scala

Lines changed: 8 additions & 8 deletions
@@ -66,18 +66,18 @@ private[spark] case class TaskSetResultStore(
   }
 
   def save(
-      taskIdx: Int,
+      taskPartitionId: Int,
       resultValue: Any,
       size: Long,
       taskSetId: String): (Any, Long) = synchronized {
-    if (!closed.get && !finished(taskIdx)) {
-      finished(taskIdx) = true
+    if (!closed.get && !finished(taskPartitionId)) {
+      finished(taskPartitionId) = true
       numFinished += 1
       if (spillContext.nonEmpty) {
-        spillContext.get.resultHandler(taskIdx, resultValue)
+        spillContext.get.resultHandler(taskPartitionId, resultValue)
       }
-      bufferedResultMap.put(taskIdx, resultValue)
-      bufferedResultSize.put(taskIdx, size)
+      bufferedResultMap.put(taskPartitionId, resultValue)
+      bufferedResultSize.put(taskPartitionId, size)
       totalBufferedSize += size
       var spilledSize = 0L;
       if (totalBufferedSize > spillThreshold) {
@@ -87,10 +87,10 @@ private[spark] case class TaskSetResultStore(
         spilledSize += close(taskSetId)
       }
       if (isSpilled) (spilledPartitionResults, spilledSize) else (resultValue, 0L)
-    } else if (!finished(taskIdx)) {
+    } else if (!finished(taskPartitionId)) {
       throw new IllegalStateException("Cannot write to a closed TaskSetResultStore.")
     } else {
-      logInfo(s"Duplicated task result of speculative task $taskSetId:$taskIdx found")
+      logInfo(s"Duplicated task result of speculative task $taskSetId:$taskPartitionId found")
       if (isSpilled) (spilledPartitionResults, 0L) else (resultValue, 0L)
     }
   }
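Finally, a self-contained sketch of the store-side contract (a hypothetical class, not the real TaskSetResultStore): the finished flags have one slot per partition, so the first result and any speculative duplicate for the same partition must address the same slot, which only holds when callers pass partitionId.

import scala.collection.mutable

final class ResultStoreSketch(numPartitions: Int) {
  private val finished = Array.fill(numPartitions)(false)
  private val bufferedResultMap = mutable.Map.empty[Int, Any]

  def save(taskPartitionId: Int, resultValue: Any): Unit = synchronized {
    if (!finished(taskPartitionId)) {
      finished(taskPartitionId) = true
      bufferedResultMap.put(taskPartitionId, resultValue)
    } else {
      // Speculative duplicate for an already-finished partition: drop it,
      // mirroring the "Duplicated task result" branch in the diff.
      println(s"Duplicated task result of speculative task for partition $taskPartitionId")
    }
  }
}

object ResultStoreSketchDemo {
  def main(args: Array[String]): Unit = {
    val store = new ResultStoreSketch(4)
    store.save(2, "result")             // first copy: stored
    store.save(2, "speculative-result") // duplicate: hits the log branch
  }
}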
