[SPARK-25822][PYSPARK] Fix a race condition when releasing a Python worker

## What changes were proposed in this pull request?

There is a race condition when releasing a Python worker. If `ReaderIterator.handleEndOfDataSection` is not running in the task thread and a task is terminated early (such as by `take(N)`), the task completion listener may close the worker while `handleEndOfDataSection` can still put that worker back into the worker pool for reuse.
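To make the race concrete, here is a minimal sketch of the old check-then-act logic (the `Worker` and `WorkerPool` traits and the method names are hypothetical stand-ins, not the real Spark classes): both threads can observe `released == false` before either acts, so the listener closes the socket while the reader still returns it to the pool.

```scala
import java.util.concurrent.atomic.AtomicBoolean

trait Worker { def close(): Unit }
trait WorkerPool { def release(worker: Worker): Unit }

class RacyRelease(worker: Worker, pool: WorkerPool, reuseWorker: Boolean) {
  private val released = new AtomicBoolean(false)

  // Task completion listener, run when the task ends (possibly early, e.g. take(N)).
  def onTaskCompletion(): Unit = {
    if (!reuseWorker || !released.get) { // check ...
      worker.close()                     // ... then act: not atomic
    }
  }

  // Models ReaderIterator.handleEndOfDataSection, which may run on another thread.
  def onEndOfStream(): Unit = {
    if (reuseWorker) {
      pool.release(worker) // may hand back a worker that is being closed
      released.set(true)   // the flag flips too late to stop the close above
    }
  }
}
```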

zsxwing@0e07b48 is a patch that reproduces this issue.

A user also reported this on the mailing list: http://mail-archives.apache.org/mod_mbox/spark-user/201610.mbox/%3CCAAUq=H+YLUEpd23nwvq13Ms5hOStkhX3ao4f4zQV6sgO5zM-xA@mail.gmail.com%3E

This PR fixes the issue by using `compareAndSet` to make sure we never return a closed worker to the worker pool.
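Sketched below with the same hypothetical `Worker`/`WorkerPool` traits from the sketch above, the fix collapses the check and the act into one atomic `compareAndSet`: exactly one caller wins the flag flip, so each worker is either released once or closed once, never both.

```scala
class SafeRelease(worker: Worker, pool: WorkerPool, reuseWorker: Boolean) {
  private val releasedOrClosed = new AtomicBoolean(false)

  def onTaskCompletion(): Unit = {
    // compareAndSet(false, true) succeeds for exactly one thread.
    if (!reuseWorker || releasedOrClosed.compareAndSet(false, true)) {
      worker.close()
    }
    // If the CAS failed, the reader already released the worker to the
    // pool, so closing it here would break the next task that reuses it.
  }

  def onEndOfStream(): Unit = {
    if (reuseWorker && releasedOrClosed.compareAndSet(false, true)) {
      pool.release(worker) // only reached if this thread won the race
    }
  }
}
```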

## How was this patch tested?

Jenkins.

Closes apache#22816 from zsxwing/fix-socket-closed.

Authored-by: Shixiong Zhu <zsxwing@gmail.com>
Signed-off-by: Takuya UESHIN <ueshin@databricks.com>
zsxwing authored and ueshin committed Oct 26, 2018
1 parent 24e8c27 commit 86d469a
Showing 3 changed files with 15 additions and 14 deletions.
21 changes: 11 additions & 10 deletions core/src/main/scala/org/apache/spark/api/python/PythonRunner.scala
@@ -106,15 +106,17 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
       envVars.put("PYSPARK_EXECUTOR_MEMORY_MB", memoryMb.get.toString)
     }
     val worker: Socket = env.createPythonWorker(pythonExec, envVars.asScala.toMap)
-    // Whether is the worker released into idle pool
-    val released = new AtomicBoolean(false)
+    // Whether the worker is released into the idle pool or closed. Any code that tries to release
+    // or close a worker should use `releasedOrClosed.compareAndSet` to flip the state, so that
+    // there is exactly one winner that gets to release or close the worker.
+    val releasedOrClosed = new AtomicBoolean(false)
 
     // Start a thread to feed the process input from our parent's iterator
     val writerThread = newWriterThread(env, worker, inputIterator, partitionIndex, context)
 
     context.addTaskCompletionListener[Unit] { _ =>
       writerThread.shutdownOnTaskCompletion()
-      if (!reuseWorker || !released.get) {
+      if (!reuseWorker || releasedOrClosed.compareAndSet(false, true)) {
         try {
           worker.close()
         } catch {
@@ -131,7 +133,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
     val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize))
 
     val stdoutIterator = newReaderIterator(
-      stream, writerThread, startTime, env, worker, released, context)
+      stream, writerThread, startTime, env, worker, releasedOrClosed, context)
     new InterruptibleIterator(context, stdoutIterator)
   }
 
@@ -148,7 +150,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
       startTime: Long,
       env: SparkEnv,
       worker: Socket,
-      released: AtomicBoolean,
+      releasedOrClosed: AtomicBoolean,
       context: TaskContext): Iterator[OUT]
 
   /**
@@ -392,7 +394,7 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
       startTime: Long,
       env: SparkEnv,
       worker: Socket,
-      released: AtomicBoolean,
+      releasedOrClosed: AtomicBoolean,
       context: TaskContext)
     extends Iterator[OUT] {
 
@@ -463,9 +465,8 @@ private[spark] abstract class BasePythonRunner[IN, OUT](
         }
         // Check whether the worker is ready to be re-used.
         if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
-          if (reuseWorker) {
+          if (reuseWorker && releasedOrClosed.compareAndSet(false, true)) {
             env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker)
-            released.set(true)
           }
         }
         eos = true
@@ -565,9 +566,9 @@ private[spark] class PythonRunner(funcs: Seq[ChainedPythonFunctions])
       startTime: Long,
       env: SparkEnv,
       worker: Socket,
-      released: AtomicBoolean,
+      releasedOrClosed: AtomicBoolean,
       context: TaskContext): Iterator[Array[Byte]] = {
-    new ReaderIterator(stream, writerThread, startTime, env, worker, released, context) {
+    new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) {
 
       protected override def read(): Array[Byte] = {
         if (writerThread.exception.isDefined) {
4 changes: 2 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/python/ArrowPythonRunner.scala
@@ -117,9 +117,9 @@ class ArrowPythonRunner(
       startTime: Long,
       env: SparkEnv,
       worker: Socket,
-      released: AtomicBoolean,
+      releasedOrClosed: AtomicBoolean,
       context: TaskContext): Iterator[ColumnarBatch] = {
-    new ReaderIterator(stream, writerThread, startTime, env, worker, released, context) {
+    new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) {
 
       private val allocator = ArrowUtils.rootAllocator.newChildAllocator(
         s"stdin reader for $pythonExec", 0, Long.MaxValue)
4 changes: 2 additions & 2 deletions sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDFRunner.scala
@@ -59,9 +59,9 @@ class PythonUDFRunner(
       startTime: Long,
       env: SparkEnv,
       worker: Socket,
-      released: AtomicBoolean,
+      releasedOrClosed: AtomicBoolean,
       context: TaskContext): Iterator[Array[Byte]] = {
-    new ReaderIterator(stream, writerThread, startTime, env, worker, released, context) {
+    new ReaderIterator(stream, writerThread, startTime, env, worker, releasedOrClosed, context) {
 
       protected override def read(): Array[Byte] = {
         if (writerThread.exception.isDefined) {
