
Commit d76633e

jiangxb1987 authored and cloud-fan committed
[SPARK-18406][CORE] Race between end-of-task and completion iterator read lock release
## What changes were proposed in this pull request?

When a TaskContext is not propagated properly to all child threads of a task, as in the cases reported in this issue, we fail to get the TID from TaskContext; the lock then cannot be released, which leads to assertion failures. To resolve this, we have to explicitly pass the TID value to the `unlock` method.

## How was this patch tested?

Added a new (previously failing) regression test case in `RDDSuite`.

Author: Xingbo Jiang <xingbo.jiang@databricks.com>

Closes #18076 from jiangxb1987/completion-iterator.
1 parent 9434280 commit d76633e
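
The root cause is that TaskContext lives in a plain ThreadLocal, so a thread spawned inside a task does not inherit its parent's context. A minimal sketch of the failure mode (`TaskContext.get()` and `taskAttemptId()` are real Spark APIs; the scenario is illustrative and assumes it runs inside a task):

```scala
import org.apache.spark.TaskContext

// On the task's main thread, the context is available:
val tid: Long = TaskContext.get().taskAttemptId()

val child = new Thread(new Runnable {
  override def run(): Unit = {
    // TaskContext is a plain ThreadLocal and is not inherited by child
    // threads, so this returns null here. Any lock bookkeeping that re-reads
    // the current TID at this point cannot identify the owning task.
    assert(TaskContext.get() == null)
  }
})
child.start()
child.join()
```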

File tree

4 files changed: +49 −15 lines changed

core/src/main/scala/org/apache/spark/network/BlockDataManager.scala

Lines changed: 1 addition & 1 deletion
@@ -46,5 +46,5 @@ trait BlockDataManager {
   /**
    * Release locks acquired by [[putBlockData()]] and [[getBlockData()]].
    */
-  def releaseLock(blockId: BlockId): Unit
+  def releaseLock(blockId: BlockId, taskAttemptId: Option[Long]): Unit
 }
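
Callers of the trait now choose between implicit and explicit TID resolution. A hedged sketch of a call site (BlockDataManager is Spark-internal, so this only compiles inside Spark; `manager` and `blockId` are placeholders for illustration):

```scala
import org.apache.spark.TaskContext
import org.apache.spark.network.BlockDataManager
import org.apache.spark.storage.BlockId

// Capture the TID while still on the task's main thread, then pass it along
// so the release works even on a thread that has no TaskContext set.
def releaseFromAnyThread(manager: BlockDataManager, blockId: BlockId): Unit = {
  val capturedTid: Option[Long] = Option(TaskContext.get()).map(_.taskAttemptId())
  manager.releaseLock(blockId, capturedTid)
}
```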

core/src/main/scala/org/apache/spark/storage/BlockInfoManager.scala

Lines changed: 10 additions & 5 deletions
@@ -281,22 +281,27 @@ private[storage] class BlockInfoManager extends Logging {

   /**
    * Release a lock on the given block.
+   * In case a TaskContext is not propagated properly to all child threads for the task, we fail to
+   * get the TID from TaskContext, so we have to explicitly pass the TID value to release the lock.
+   *
+   * See SPARK-18406 for more discussion of this issue.
    */
-  def unlock(blockId: BlockId): Unit = synchronized {
-    logTrace(s"Task $currentTaskAttemptId releasing lock for $blockId")
+  def unlock(blockId: BlockId, taskAttemptId: Option[TaskAttemptId] = None): Unit = synchronized {
+    val taskId = taskAttemptId.getOrElse(currentTaskAttemptId)
+    logTrace(s"Task $taskId releasing lock for $blockId")
     val info = get(blockId).getOrElse {
       throw new IllegalStateException(s"Block $blockId not found")
     }
     if (info.writerTask != BlockInfo.NO_WRITER) {
       info.writerTask = BlockInfo.NO_WRITER
-      writeLocksByTask.removeBinding(currentTaskAttemptId, blockId)
+      writeLocksByTask.removeBinding(taskId, blockId)
     } else {
       assert(info.readerCount > 0, s"Block $blockId is not locked for reading")
       info.readerCount -= 1
-      val countsForTask = readLocksByTask(currentTaskAttemptId)
+      val countsForTask = readLocksByTask(taskId)
       val newPinCountForTask: Int = countsForTask.remove(blockId, 1) - 1
       assert(newPinCountForTask >= 0,
-        s"Task $currentTaskAttemptId release lock on block $blockId more times than it acquired it")
+        s"Task $taskId release lock on block $blockId more times than it acquired it")
     }
     notifyAll()
   }
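
The `getOrElse` fallback is what keeps the old behavior intact: an explicit TID wins, and only when none is supplied does the method fall back to the thread-local lookup. A standalone sketch of those semantics, with stand-in values:

```scala
// `getOrElse` takes its default by-name, so the thread-local lookup is only
// evaluated when no explicit TID was passed.
def resolveTaskId(explicit: Option[Long], currentTaskAttemptId: => Long): Long =
  explicit.getOrElse(currentTaskAttemptId)

assert(resolveTaskId(Some(42L), sys.error("never evaluated")) == 42L)
assert(resolveTaskId(None, 7L) == 7L)
```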

core/src/main/scala/org/apache/spark/storage/BlockManager.scala

Lines changed: 21 additions & 8 deletions
@@ -501,14 +501,20 @@ private[spark] class BlockManager(
       case Some(info) =>
         val level = info.level
         logDebug(s"Level for block $blockId is $level")
+        val taskAttemptId = Option(TaskContext.get()).map(_.taskAttemptId())
         if (level.useMemory && memoryStore.contains(blockId)) {
           val iter: Iterator[Any] = if (level.deserialized) {
             memoryStore.getValues(blockId).get
           } else {
             serializerManager.dataDeserializeStream(
               blockId, memoryStore.getBytes(blockId).get.toInputStream())(info.classTag)
           }
-          val ci = CompletionIterator[Any, Iterator[Any]](iter, releaseLock(blockId))
+          // We need to capture the current taskId in case the iterator completion is triggered
+          // from a different thread which does not have TaskContext set; see SPARK-18406 for
+          // discussion.
+          val ci = CompletionIterator[Any, Iterator[Any]](iter, {
+            releaseLock(blockId, taskAttemptId)
+          })
           Some(new BlockResult(ci, DataReadMethod.Memory, info.size))
         } else if (level.useDisk && diskStore.contains(blockId)) {
           val diskData = diskStore.getBytes(blockId)
@@ -525,8 +531,9 @@ private[spark] class BlockManager(
               serializerManager.dataDeserializeStream(blockId, stream)(info.classTag)
             }
           }
-          val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn,
-            releaseLockAndDispose(blockId, diskData))
+          val ci = CompletionIterator[Any, Iterator[Any]](iterToReturn, {
+            releaseLockAndDispose(blockId, diskData, taskAttemptId)
+          })
           Some(new BlockResult(ci, DataReadMethod.Disk, info.size))
         } else {
           handleLocalReadFailure(blockId)
@@ -711,10 +718,13 @@ private[spark] class BlockManager(
   }

   /**
-   * Release a lock on the given block.
+   * Release a lock on the given block with explicit TID.
+   * The param `taskAttemptId` should be passed in case we can't get the correct TID from
+   * TaskContext, for example, the input iterator of a cached RDD iterates to the end in a child
+   * thread.
    */
-  def releaseLock(blockId: BlockId): Unit = {
-    blockInfoManager.unlock(blockId)
+  def releaseLock(blockId: BlockId, taskAttemptId: Option[Long] = None): Unit = {
+    blockInfoManager.unlock(blockId, taskAttemptId)
   }

   /**
@@ -1467,8 +1477,11 @@ private[spark] class BlockManager(
     }
   }

-  def releaseLockAndDispose(blockId: BlockId, data: BlockData): Unit = {
-    blockInfoManager.unlock(blockId)
+  def releaseLockAndDispose(
+      blockId: BlockId,
+      data: BlockData,
+      taskAttemptId: Option[Long] = None): Unit = {
+    releaseLock(blockId, taskAttemptId)
     data.dispose()
   }
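
The essential move above is evaluating `TaskContext.get()` eagerly, before the `CompletionIterator` is constructed, so the completion callback closes over a plain value instead of re-reading thread-local state whenever it eventually runs. The same pattern in isolation (a sketch; `release` stands in for the lock-release call):

```scala
import org.apache.spark.TaskContext

def makeCompletionCallback(release: Option[Long] => Unit): () => Unit = {
  // Evaluated now, on the task's main thread, where TaskContext is set.
  val taskAttemptId = Option(TaskContext.get()).map(_.taskAttemptId())
  // The closure carries the captured value, not the ThreadLocal, so it is
  // safe to invoke from any thread.
  () => release(taskAttemptId)
}
```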

core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala

Lines changed: 17 additions & 1 deletion
@@ -30,7 +30,7 @@ import org.apache.hadoop.mapred.{FileSplit, TextInputFormat}
 import org.apache.spark._
 import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.rdd.RDDSuiteUtils._
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{ThreadUtils, Utils}

 class RDDSuite extends SparkFunSuite with SharedSparkContext {
   var tempDir: File = _
@@ -1082,6 +1082,22 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext {
     assert(totalPartitionCount == 10)
   }

+  test("SPARK-18406: race between end-of-task and completion iterator read lock release") {
+    val rdd = sc.parallelize(1 to 1000, 10)
+    rdd.cache()
+
+    rdd.mapPartitions { iter =>
+      ThreadUtils.runInNewThread("TestThread") {
+        // Iterate to the end of the input iterator, to cause the CompletionIterator completion to
+        // fire outside of the task's main thread.
+        while (iter.hasNext) {
+          iter.next()
+        }
+        iter
+      }
+    }.collect()
+  }
+
   // NOTE
   // Below tests calling sc.stop() have to be the last tests in this suite. If there are tests
   // running after them and if they access sc those tests will fail as sc is already closed, because
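
The test forces the `CompletionIterator` callback to fire on a thread with no TaskContext; without the fix it triggered the assertion failures described in the commit message. `ThreadUtils.runInNewThread` runs a block on a fresh thread and rethrows its exceptions on the caller; a simplified sketch of that idea (not Spark's actual implementation, which also rewrites stack traces):

```scala
def runInNewThreadSketch[T](threadName: String)(body: => T): T = {
  var result: Option[T] = None
  var error: Option[Throwable] = None
  val t = new Thread(threadName) {
    override def run(): Unit =
      try result = Some(body)
      catch { case e: Throwable => error = Some(e) }
  }
  t.start()
  t.join() // join() establishes happens-before, so the plain vars are safe to read
  error.foreach(e => throw e)
  result.get
}
```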
