Skip to content

Commit

Permalink
[SPARK-22982] Remove unsafe asynchronous close() call from FileDownlo…
Browse files Browse the repository at this point in the history
…adChannel

## What changes were proposed in this pull request?

This patch fixes a severe asynchronous IO bug in Spark's Netty-based file transfer code. At a high-level, the problem is that an unsafe asynchronous `close()` of a pipe's source channel creates a race condition where file transfer code closes a file descriptor then attempts to read from it. If the closed file descriptor's number has been reused by an `open()` call then this invalid read may cause unrelated file operations to return incorrect results. **One manifestation of this problem is incorrect query results.**

For a high-level overview of how file download works, take a look at the control flow in `NettyRpcEnv.openChannel()`: this code creates a pipe to buffer results, then submits an asynchronous stream request to a lower-level TransportClient. The callback passes received data to the sink end of the pipe. The source end of the pipe is passed back to the caller of `openChannel()`. Thus `openChannel()` returns immediately and callers interact with the returned pipe source channel.

Because the underlying stream request is asynchronous, errors may occur after `openChannel()` has returned and after that method's caller has started to `read()` from the returned channel. For example, if a client requests an invalid stream from a remote server then the "stream does not exist" error may not be received from the remote server until after `openChannel()` has returned. In order to be able to propagate the "stream does not exist" error to the file-fetching application thread, this code wraps the pipe's source channel in a special `FileDownloadChannel` which adds an `setError(t: Throwable)` method, then calls this `setError()` method in the FileDownloadCallback's `onFailure` method.

It is possible for `FileDownloadChannel`'s `read()` and `setError()` methods to be called concurrently from different threads: the `setError()` method is called from within the Netty RPC system's stream callback handlers, while the `read()` methods are called from higher-level application code performing remote stream reads.

The problem lies in `setError()`: the existing code closed the wrapped pipe source channel. Because `read()` and `setError()` occur in different threads, this means it is possible for one thread to be calling `source.read()` while another asynchronously calls `source.close()`. Java's IO libraries do not guarantee that this will be safe and, in fact, it's possible for these operations to interleave in such a way that a lower-level `read()` system call occurs right after a `close()` call. In the best-case, this fails as a read of a closed file descriptor; in the worst-case, the file descriptor number has been re-used by an intervening `open()` operation and the read corrupts the result of an unrelated file IO operation being performed by a different thread.

The solution here is to remove the `stream.close()` call in `onError()`: the thread that is performing the `read()` calls is responsible for closing the stream in a `finally` block, so there's no need to close it here. If that thread is blocked in a `read()` then it will become unblocked when the sink end of the pipe is closed in `FileDownloadCallback.onFailure()`.

After making this change, we also need to refine the `read()` method to always check for a `setError()` result, even if the underlying channel `read()` call has succeeded.

This patch also makes a slight cleanup to a dodgy-looking `catch e: Exception` block to use a safer `try-finally` error handling idiom.

This bug was introduced in SPARK-11956 / apache#9941 and is present in Spark 1.6.0+.

## How was this patch tested?

This fix was tested manually against a workload which non-deterministically hit this bug.

Author: Josh Rosen <joshrosen@databricks.com>

Closes apache#20179 from JoshRosen/SPARK-22982-fix-unsafe-async-io-in-file-download-channel.
  • Loading branch information
JoshRosen authored and cloud-fan committed Jan 10, 2018
1 parent e599837 commit edf0a48
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 19 deletions.
37 changes: 22 additions & 15 deletions core/src/main/scala/org/apache/spark/rpc/netty/NettyRpcEnv.scala
Original file line number Diff line number Diff line change
Expand Up @@ -332,16 +332,14 @@ private[netty] class NettyRpcEnv(

val pipe = Pipe.open()
val source = new FileDownloadChannel(pipe.source())
try {
Utils.tryWithSafeFinallyAndFailureCallbacks(block = {
val client = downloadClient(parsedUri.getHost(), parsedUri.getPort())
val callback = new FileDownloadCallback(pipe.sink(), source, client)
client.stream(parsedUri.getPath(), callback)
} catch {
case e: Exception =>
pipe.sink().close()
source.close()
throw e
}
})(catchBlock = {
pipe.sink().close()
source.close()
})

source
}
Expand Down Expand Up @@ -370,24 +368,33 @@ private[netty] class NettyRpcEnv(
fileDownloadFactory.createClient(host, port)
}

private class FileDownloadChannel(source: ReadableByteChannel) extends ReadableByteChannel {
private class FileDownloadChannel(source: Pipe.SourceChannel) extends ReadableByteChannel {

@volatile private var error: Throwable = _

def setError(e: Throwable): Unit = {
// This setError callback is invoked by internal RPC threads in order to propagate remote
// exceptions to application-level threads which are reading from this channel. When an
// RPC error occurs, the RPC system will call setError() and then will close the
// Pipe.SinkChannel corresponding to the other end of the `source` pipe. Closing of the pipe
// sink will cause `source.read()` operations to return EOF, unblocking the application-level
// reading thread. Thus there is no need to actually call `source.close()` here in the
// onError() callback and, in fact, calling it here would be dangerous because the close()
// would be asynchronous with respect to the read() call and could trigger race-conditions
// that lead to data corruption. See the PR for SPARK-22982 for more details on this topic.
error = e
source.close()
}

override def read(dst: ByteBuffer): Int = {
Try(source.read(dst)) match {
// See the documentation above in setError(): if an RPC error has occurred then setError()
// will be called to propagate the RPC error and then `source`'s corresponding
// Pipe.SinkChannel will be closed, unblocking this read. In that case, we want to propagate
// the remote RPC exception (and not any exceptions triggered by the pipe close, such as
// ChannelClosedException), hence this `error != null` check:
case _ if error != null => throw error
case Success(bytesRead) => bytesRead
case Failure(readErr) =>
if (error != null) {
throw error
} else {
throw readErr
}
case Failure(readErr) => throw readErr
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
package org.apache.spark.shuffle

import java.io._

import com.google.common.io.ByteStreams
import java.nio.channels.Channels
import java.nio.file.Files

import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.internal.Logging
Expand Down Expand Up @@ -196,11 +196,24 @@ private[spark] class IndexShuffleBlockResolver(
// find out the consolidated file, then the offset within that from our index
val indexFile = getIndexFile(blockId.shuffleId, blockId.mapId)

val in = new DataInputStream(new FileInputStream(indexFile))
// SPARK-22982: if this FileInputStream's position is seeked forward by another piece of code
// which is incorrectly using our file descriptor then this code will fetch the wrong offsets
// (which may cause a reducer to be sent a different reducer's data). The explicit position
// checks added here were a useful debugging aid during SPARK-22982 and may help prevent this
// class of issue from re-occurring in the future which is why they are left here even though
// SPARK-22982 is fixed.
val channel = Files.newByteChannel(indexFile.toPath)
channel.position(blockId.reduceId * 8)
val in = new DataInputStream(Channels.newInputStream(channel))
try {
ByteStreams.skipFully(in, blockId.reduceId * 8)
val offset = in.readLong()
val nextOffset = in.readLong()
val actualPosition = channel.position()
val expectedPosition = blockId.reduceId * 8 + 16
if (actualPosition != expectedPosition) {
throw new Exception(s"SPARK-22982: Incorrect channel position after index file reads: " +
s"expected $expectedPosition but actual position was $actualPosition.")
}
new FileSegmentManagedBuffer(
transportConf,
getDataFile(blockId.shuffleId, blockId.mapId),
Expand Down

0 comments on commit edf0a48

Please sign in to comment.