Skip to content

Commit c3302e8

Browse files
jiangxb1987cloud-fan
authored andcommitted
[SPARK-18406][CORE][BACKPORT-2.1] Race between end-of-task and completion iterator read lock release
This is a backport PR of #18076 to 2.1. ## What changes were proposed in this pull request? When a TaskContext is not propagated properly to all child threads for the task, just like the reported cases in this issue, we fail to get to TID from TaskContext and that causes unable to release the lock and assertion failures. To resolve this, we have to explicitly pass the TID value to the `unlock` method. ## How was this patch tested? Add new failing regression test case in `RDDSuite`. Author: Xingbo Jiang <xingbo.jiang@databricks.com> Closes #18099 from jiangxb1987/completion-iterator-2.1.
1 parent 2f68631 commit c3302e8

File tree

4 files changed

+44
-12
lines changed

4 files changed

+44
-12
lines changed

core/src/main/scala/org/apache/spark/network/BlockDataManager.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,5 +46,5 @@ trait BlockDataManager {
4646
/**
4747
* Release locks acquired by [[putBlockData()]] and [[getBlockData()]].
4848
*/
49-
def releaseLock(blockId: BlockId): Unit
49+
def releaseLock(blockId: BlockId, taskAttemptId: Option[Long]): Unit
5050
}

core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -281,22 +281,27 @@ private[storage] class BlockInfoManager extends Logging {
281281

282282
/**
283283
* Release a lock on the given block.
284+
* In case a TaskContext is not propagated properly to all child threads for the task, we fail to
285+
* get the TID from TaskContext, so we have to explicitly pass the TID value to release the lock.
286+
*
287+
* See SPARK-18406 for more discussion of this issue.
284288
*/
285-
def unlock(blockId: BlockId): Unit = synchronized {
286-
logTrace(s"Task $currentTaskAttemptId releasing lock for $blockId")
289+
def unlock(blockId: BlockId, taskAttemptId: Option[TaskAttemptId] = None): Unit = synchronized {
290+
val taskId = taskAttemptId.getOrElse(currentTaskAttemptId)
291+
logTrace(s"Task $taskId releasing lock for $blockId")
287292
val info = get(blockId).getOrElse {
288293
throw new IllegalStateException(s"Block $blockId not found")
289294
}
290295
if (info.writerTask != BlockInfo.NO_WRITER) {
291296
info.writerTask = BlockInfo.NO_WRITER
292-
writeLocksByTask.removeBinding(currentTaskAttemptId, blockId)
297+
writeLocksByTask.removeBinding(taskId, blockId)
293298
} else {
294299
assert(info.readerCount > 0, s"Block $blockId is not locked for reading")
295300
info.readerCount -= 1
296-
val countsForTask = readLocksByTask(currentTaskAttemptId)
301+
val countsForTask = readLocksByTask(taskId)
297302
val newPinCountForTask: Int = countsForTask.remove(blockId, 1) - 1
298303
assert(newPinCountForTask >= 0,
299-
s"Task $currentTaskAttemptId release lock on block $blockId more times than it acquired it")
304+
s"Task $taskId release lock on block $blockId more times than it acquired it")
300305
}
301306
notifyAll()
302307
}

core/src/main/scala/org/apache/spark/storage/BlockManager.scala

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -454,14 +454,20 @@ private[spark] class BlockManager(
454454
case Some(info) =>
455455
val level = info.level
456456
logDebug(s"Level for block $blockId is $level")
457+
val taskAttemptId = Option(TaskContext.get()).map(_.taskAttemptId())
457458
if (level.useMemory && memoryStore.contains(blockId)) {
458459
val iter: Iterator[Any] = if (level.deserialized) {
459460
memoryStore.getValues(blockId).get
460461
} else {
461462
serializerManager.dataDeserializeStream(
462463
blockId, memoryStore.getBytes(blockId).get.toInputStream())(info.classTag)
463464
}
464-
val ci = CompletionIterator[Any, Iterator[Any]](iter, releaseLock(blockId))
465+
// We need to capture the current taskId in case the iterator completion is triggered
466+
// from a different thread which does not have TaskContext set; see SPARK-18406 for
467+
// discussion.
468+
val ci = CompletionIterator[Any, Iterator[Any]](iter, {
469+
releaseLock(blockId, taskAttemptId)
470+
})
465471
Some(new BlockResult(ci, DataReadMethod.Memory, info.size))
466472
} else if (level.useDisk && diskStore.contains(blockId)) {
467473
val iterToReturn: Iterator[Any] = {
@@ -478,7 +484,9 @@ private[spark] class BlockManager(
478484
serializerManager.dataDeserializeStream(blockId, stream)(info.classTag)
479485
}
480486
}
481-
val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, releaseLock(blockId))
487+
val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, {
488+
releaseLock(blockId, taskAttemptId)
489+
})
482490
Some(new BlockResult(ci, DataReadMethod.Disk, info.size))
483491
} else {
484492
handleLocalReadFailure(blockId)
@@ -654,10 +662,13 @@ private[spark] class BlockManager(
654662
}
655663

656664
/**
657-
* Release a lock on the given block.
665+
* Release a lock on the given block with explicit TID.
666+
* The param `taskAttemptId` should be passed in case we can't get the correct TID from
667+
* TaskContext, for example, the input iterator of a cached RDD iterates to the end in a child
668+
* thread.
658669
*/
659-
def releaseLock(blockId: BlockId): Unit = {
660-
blockInfoManager.unlock(blockId)
670+
def releaseLock(blockId: BlockId, taskAttemptId: Option[Long] = None): Unit = {
671+
blockInfoManager.unlock(blockId, taskAttemptId)
661672
}
662673

663674
/**

core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ import org.apache.hadoop.mapred.{FileSplit, TextInputFormat}
3030
import org.apache.spark._
3131
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
3232
import org.apache.spark.rdd.RDDSuiteUtils._
33-
import org.apache.spark.util.Utils
33+
import org.apache.spark.util.{ThreadUtils, Utils}
3434

3535
class RDDSuite extends SparkFunSuite with SharedSparkContext {
3636
var tempDir: File = _
@@ -1082,6 +1082,22 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext {
10821082
assert(totalPartitionCount == 10)
10831083
}
10841084

1085+
test("SPARK-18406: race between end-of-task and completion iterator read lock release") {
1086+
val rdd = sc.parallelize(1 to 1000, 10)
1087+
rdd.cache()
1088+
1089+
rdd.mapPartitions { iter =>
1090+
ThreadUtils.runInNewThread("TestThread") {
1091+
// Iterate to the end of the input iterator, to cause the CompletionIterator completion to
1092+
// fire outside of the task's main thread.
1093+
while (iter.hasNext) {
1094+
iter.next()
1095+
}
1096+
iter
1097+
}
1098+
}.collect()
1099+
}
1100+
10851101
// NOTE
10861102
// Below tests calling sc.stop() have to be the last tests in this suite. If there are tests
10871103
// running after them and if they access sc those tests will fail as sc is already closed, because

0 commit comments

Comments
 (0)