
[SPARK-7884] Move block deserialization from BlockStoreShuffleFetcher to ShuffleReader #6423

Closed · wants to merge 20 commits

Changes from all commits — 20 commits
19135f2  [SPARK-7884] Allow Spark shuffle APIs to be more customizable (massie, May 27, 2015)
b70c945  Make BlockStoreShuffleFetcher visible to shuffle package (massie, Jun 2, 2015)
208b7a5  Small code style changes (massie, Jun 8, 2015)
7c8f73e  Close Block InputStream immediately after all records are read (massie, Jun 9, 2015)
01e8721  Explicitly cast iterator in branches for type clarity (massie, Jun 9, 2015)
7e8e0fe  Minor Scala style fixes (massie, Jun 9, 2015)
f93841e  Update shuffle read metrics in ShuffleReader instead of BlockStoreShu… (massie, Jun 9, 2015)
28f8085  Small import nit (massie, Jun 10, 2015)
7eedd1d  Small Scala import cleanup (massie, Jun 10, 2015)
5c30405  Return visibility of BlockStoreShuffleFetcher to private[hash] (massie, Jun 10, 2015)
4abb855  Consolidate metric code. Make it clear why InterrubtibleIterator is n… (massie, Jun 10, 2015)
f458489  Remove unnecessary map() on return Iterator (massie, Jun 11, 2015)
7429a98  Update tests to check that BufferReleasingStream is closing delegate … (massie, Jun 12, 2015)
4ea1712  Small code cleanup for readability (massie, Jun 12, 2015)
a011bfa  Use PrivateMethodTester on check that delegate stream is closed (massie, Jun 18, 2015)
f98a1b9  Add test to ensure HashShuffleReader is freeing resources (massie, Jun 22, 2015)
5186da0  Revert "Add test to ensure HashShuffleReader is freeing resources" (kayousterhout, Jun 23, 2015)
290f1eb  Added test for HashShuffleReader.read() (kayousterhout, Jun 23, 2015)
d0a1b39  Merge pull request #1 from kayousterhout/massie_shuffle-api-cleanup (massie, Jun 23, 2015)
8b0632c  Minor Scala style fixes (massie, Jun 23, 2015)

BlockStoreShuffleFetcher.scala
@@ -17,29 +17,29 @@

package org.apache.spark.shuffle.hash

import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.util.{Failure, Success, Try}
import java.io.InputStream

import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.util.{Failure, Success}

import org.apache.spark._
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId}
import org.apache.spark.util.CompletionIterator
import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator,
ShuffleBlockId}

private[hash] object BlockStoreShuffleFetcher extends Logging {
def fetch[T](
def fetchBlockStreams(
shuffleId: Int,
reduceId: Int,
context: TaskContext,
serializer: Serializer)
: Iterator[T] =
blockManager: BlockManager,
mapOutputTracker: MapOutputTracker)
: Iterator[(BlockId, InputStream)] =
{
logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
val blockManager = SparkEnv.get.blockManager

val startTime = System.currentTimeMillis
val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
val statuses = mapOutputTracker.getServerStatuses(shuffleId, reduceId)
logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
shuffleId, reduceId, System.currentTimeMillis - startTime))

@@ -53,12 +53,21 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
(address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
}

def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = {
val blockFetcherItr = new ShuffleBlockFetcherIterator(
context,
blockManager.shuffleClient,
blockManager,
blocksByAddress,
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)

// Make sure that fetch failures are wrapped inside a FetchFailedException for the scheduler
blockFetcherItr.map { blockPair =>
val blockId = blockPair._1
val blockOption = blockPair._2
blockOption match {
case Success(block) => {
block.asInstanceOf[Iterator[T]]
case Success(inputStream) => {
(blockId, inputStream)
}
case Failure(e) => {
blockId match {
@@ -72,27 +81,5 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
}
}
}

val blockFetcherItr = new ShuffleBlockFetcherIterator(
context,
SparkEnv.get.blockManager.shuffleClient,
blockManager,
blocksByAddress,
serializer,
// Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
val itr = blockFetcherItr.flatMap(unpackBlock)

val completionIter = CompletionIterator[T, Iterator[T]](itr, {
context.taskMetrics.updateShuffleReadMetrics()
})

new InterruptibleIterator[T](context, completionIter) {
val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
override def next(): T = {
readMetrics.incRecordsRead(1)
delegate.next()
}
}
}
}

HashShuffleReader.scala
@@ -17,16 +17,20 @@

package org.apache.spark.shuffle.hash

import org.apache.spark.{InterruptibleIterator, TaskContext}
import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
import org.apache.spark.storage.BlockManager
import org.apache.spark.util.CompletionIterator
import org.apache.spark.util.collection.ExternalSorter

private[spark] class HashShuffleReader[K, C](
handle: BaseShuffleHandle[K, _, C],
startPartition: Int,
endPartition: Int,
context: TaskContext)
context: TaskContext,
blockManager: BlockManager = SparkEnv.get.blockManager,
mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
extends ShuffleReader[K, C]
{
require(endPartition == startPartition + 1,
@@ -36,20 +40,52 @@ private[spark] class HashShuffleReader[K, C](

/** Read the combined key-values for this reduce task */
override def read(): Iterator[Product2[K, C]] = {
val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams(
handle.shuffleId, startPartition, context, blockManager, mapOutputTracker)

// Wrap the streams for compression based on configuration
val wrappedStreams = blockStreams.map { case (blockId, inputStream) =>
blockManager.wrapForCompression(blockId, inputStream)
}

val ser = Serializer.getSerializer(dep.serializer)
val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser)
val serializerInstance = ser.newInstance()

// Create a key/value iterator for each stream
val recordIter = wrappedStreams.flatMap { wrappedStream =>
// Note: the asKeyValueIterator below wraps a key/value iterator inside of a
// NextIterator. The NextIterator makes sure that close() is called on the
// underlying InputStream when all records have been read.
serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
}

// Update the context task metrics for each record read.
val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
recordIter.map(record => {
readMetrics.incRecordsRead(1)
record
}),
context.taskMetrics().updateShuffleReadMetrics())

// An interruptible iterator must be used here in order to support task cancellation
val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
if (dep.mapSideCombine) {
new InterruptibleIterator(context, dep.aggregator.get.combineCombinersByKey(iter, context))
// We are reading values that are already combined
val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
} else {
new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(iter, context))
// We don't know the value type, but also don't care -- the dependency *should*
// have made sure it's compatible with this aggregator, which will convert the value
// type to the combined type C
val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
}
} else {
require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")

// Convert the Product2s to pairs since this is what downstream RDDs currently expect
iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2))
interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]]
}

// Sort the output if there is a sort ordering defined.
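The blockManager and mapOutputTracker parameters defaulted to SparkEnv.get in the constructor above are what make the new test of HashShuffleReader.read() (commit 290f1eb) possible: a unit test can now inject test doubles instead of depending on a live SparkEnv. A hedged sketch of the idea follows; it is not the PR's actual test code, and every `test*` value is a placeholder the test would have to construct or mock.

// Hedged sketch (not from the PR): exercising HashShuffleReader with injected doubles.
// testHandle, testTaskContext, testBlockManager and testMapOutputTracker are placeholders.
val reader = new HashShuffleReader[String, String](
  testHandle,               // a BaseShuffleHandle describing the shuffle under test
  startPartition = 0,
  endPartition = 1,         // read() requires exactly one partition per reader
  context = testTaskContext,
  blockManager = testBlockManager,
  mapOutputTracker = testMapOutputTracker)
val records: Iterator[Product2[String, String]] = reader.read()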

ShuffleBlockFetcherIterator.scala
@@ -17,23 +17,23 @@

package org.apache.spark.storage

import java.io.InputStream
import java.util.concurrent.LinkedBlockingQueue

import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
import scala.util.{Failure, Try}

import org.apache.spark.{Logging, TaskContext}
import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
import org.apache.spark.network.buffer.ManagedBuffer
import org.apache.spark.serializer.{SerializerInstance, Serializer}
import org.apache.spark.util.{CompletionIterator, Utils}
import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
import org.apache.spark.util.Utils

/**
* An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
* manager. For remote blocks, it fetches them using the provided BlockTransferService.
*
* This creates an iterator of (BlockID, values) tuples so the caller can handle blocks in a
* pipelined fashion as they are received.
* This creates an iterator of (BlockID, Try[InputStream]) tuples so the caller can handle blocks
* in a pipelined fashion as they are received.
*
* The implementation throttles the remote fetches so they don't exceed maxBytesInFlight to avoid
* using too much memory.
@@ -44,7 +44,6 @@ import org.apache.spark.util.{CompletionIterator, Utils}
* @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
* For each block we also require the size (in bytes as a long field) in
* order to throttle the memory usage.
* @param serializer serializer used to deserialize the data.
* @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
*/
private[spark]
@@ -53,9 +52,8 @@ final class ShuffleBlockFetcherIterator(
shuffleClient: ShuffleClient,
blockManager: BlockManager,
blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
serializer: Serializer,
maxBytesInFlight: Long)
extends Iterator[(BlockId, Try[Iterator[Any]])] with Logging {
extends Iterator[(BlockId, Try[InputStream])] with Logging {

import ShuffleBlockFetcherIterator._

@@ -83,7 +81,7 @@

/**
* A queue to hold our results. This turns the asynchronous model provided by
* [[BlockTransferService]] into a synchronous model (iterator).
* [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator).
*/
private[this] val results = new LinkedBlockingQueue[FetchResult]

@@ -102,9 +100,7 @@
/** Current bytes in flight from our requests */
private[this] var bytesInFlight = 0L

private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()

private[this] val serializerInstance: SerializerInstance = serializer.newInstance()
private[this] val shuffleMetrics = context.taskMetrics().createShuffleReadMetricsForDependency()

/**
* Whether the iterator is still active. If isZombie is true, the callback interface will no
@@ -114,17 +110,23 @@

initialize()

/**
* Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
*/
private[this] def cleanup() {
isZombie = true
// Decrements the buffer reference count.
// The currentResult is set to null to prevent releasing the buffer again on cleanup()
private[storage] def releaseCurrentResultBuffer(): Unit = {
// Release the current buffer if necessary
currentResult match {
case SuccessFetchResult(_, _, buf) => buf.release()
case _ =>
}
currentResult = null
}

/**
* Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
*/
private[this] def cleanup() {
isZombie = true
releaseCurrentResultBuffer()
// Release buffers in the results queue
val iter = results.iterator()
while (iter.hasNext) {
@@ -272,7 +274,13 @@

override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch

override def next(): (BlockId, Try[Iterator[Any]]) = {
/**
* Fetches the next (BlockId, Try[InputStream]). If a task fails, the ManagedBuffers
* underlying each InputStream will be freed by the cleanup() method registered with the
* TaskCompletionListener. However, callers should close() these InputStreams
* as soon as they are no longer needed, in order to release memory as early as possible.
*/
override def next(): (BlockId, Try[InputStream]) = {
Comment (Contributor):

I just started writing a big comment about why I thought this was unnecessary, because I thought the TaskCompletionListener would take care of it for us ... and then I realized that this is necessary so that we release buffers as soon as we finish with them while fetching a bunch of blocks.

Would you mind adding a small doc here saying that callers of this class should always wrap the results in a CompletionIterator that closes the InputStream, to be sure buffers are released as soon as possible? That would help callers and future readers of this code.
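
To make the wrapping the reviewer asks for concrete, here is a minimal, self-contained sketch of the close-on-exhaustion pattern, kept free of Spark internals; the object and method names below are purely illustrative and not part of this PR or of Spark (the PR itself relies on the NextIterator returned by deserializeStream(...).asKeyValueIterator instead).

import java.io.{ByteArrayInputStream, InputStream}

object CloseOnExhaustionSketch {
  /** Wrap a record iterator so the backing stream is closed once the last record is read. */
  def closeWhenExhausted[T](records: Iterator[T], stream: InputStream): Iterator[T] =
    new Iterator[T] {
      private var closed = false
      override def hasNext: Boolean = {
        val more = records.hasNext
        if (!more && !closed) { stream.close(); closed = true }
        more
      }
      override def next(): T = records.next()
    }

  def main(args: Array[String]): Unit = {
    val stream = new ByteArrayInputStream("a,b,c".getBytes("UTF-8"))
    // Stand-in for the key/value iterator a deserialization stream would produce.
    val records = scala.io.Source.fromInputStream(stream).mkString.split(",").iterator
    closeWhenExhausted(records, stream).foreach(println) // the stream is closed right after "c"
  }
}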

numBlocksProcessed += 1
val startFetchWait = System.currentTimeMillis()
currentResult = results.take()
@@ -290,29 +298,56 @@
sendRequest(fetchRequests.dequeue())
}

val iteratorTry: Try[Iterator[Any]] = result match {
val iteratorTry: Try[InputStream] = result match {
case FailureFetchResult(_, e) =>
Failure(e)
case SuccessFetchResult(blockId, _, buf) =>
// There is a chance that createInputStream can fail (e.g. fetching a local file that does
// not exist, SPARK-4085). In that case, we should propagate the right exception so
// the scheduler gets a FetchFailedException.
Try(buf.createInputStream()).map { is0 =>
val is = blockManager.wrapForCompression(blockId, is0)
val iter = serializerInstance.deserializeStream(is).asKeyValueIterator
CompletionIterator[Any, Iterator[Any]](iter, {
// Once the iterator is exhausted, release the buffer and set currentResult to null
// so we don't release it again in cleanup.
currentResult = null
buf.release()
})
Try(buf.createInputStream()).map { inputStream =>
Comment (Contributor):

We might also treat this patch as an opportunity to revisit why we're using Try here. It might be fine to keep Try as the return type, but I'm not necessarily convinced that we should be calling Try.apply() here, since I think it obscures whether we'll need to perform any cleanup after errors (for instance, do we need to free buf? Is buf guaranteed to be non-null, or could this fail with an NPE on the buf.createInputStream() call?). I feel that Try.apply() makes it easy to overlook these concerns.

Comment (Contributor, author):

The ShuffleBlockFetcherIterator has no information about server statuses from the map output tracker, shuffle IDs, etc. Using Try allows the BlockStoreShuffleFetcher to reformat exceptions as a FetchFailedException, which is the right exception to return to the scheduler.
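
As a hedged illustration of that rewrapping: the relevant lines of BlockStoreShuffleFetcher are partly collapsed in the hunk above, so this is a reconstruction that may differ in detail from the actual code; statuses and reduceId are values in scope inside fetchBlockStreams.

// Sketch only: convert fetch failures into FetchFailedException for the scheduler.
blockFetcherItr.map { case (blockId, blockTry) =>
  blockTry match {
    case Success(inputStream) =>
      (blockId, inputStream)
    case Failure(e) =>
      blockId match {
        case ShuffleBlockId(shufId, mapId, _) =>
          val address = statuses(mapId.toInt)._1
          throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e)
        case _ =>
          throw new SparkException(
            "Failed to get block " + blockId + ", which is not a shuffle block", e)
      }
  }
}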

Comment (Contributor):

This is technically not your change -- but do you know what happens if the stream is not consumed in full in a task? Does that lead to memory leaks, because close() on the stream is never called?

Comment (Contributor, author):

We don't need to worry about a memory leak when the task exits with success or failure, since there is a cleanup method registered with the task context, e.g.

// Add a task completion callback (called in both success case and failure case) to cleanup.
context.addTaskCompletionListener(_ => cleanup())

However, you're correct that there would be a memory (and file handle) leak if the InputStream isn't closed in the ShuffleReader. This PR prevents that, since serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator returns a NextIterator which closes the stream when the last record is read.

To be more defensive and potentially simplify the code, it might make sense to have a call to ShuffleBlockFetcherIterator.next() not just return the next InputStream but also close() the last one. That would prevent callers from having more than one InputStream open at a time, but I don't think we want more than one open at a time anyway?
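
A minimal, self-contained sketch of that alternative, where next() closes the stream handed out by the previous call; the class and names below are illustrative only and not part of this PR.

import java.io.InputStream
import scala.util.Try

// Illustrative only: wraps an iterator of (blockId, stream) pairs so that each call to
// next() first closes the stream handed out by the previous call. Callers can then hold
// at most one fetched stream open at a time.
class OneOpenStreamIterator[K](underlying: Iterator[(K, InputStream)])
  extends Iterator[(K, InputStream)] {

  private var lastStream: Option[InputStream] = None

  override def hasNext: Boolean = underlying.hasNext

  override def next(): (K, InputStream) = {
    lastStream.foreach(s => Try(s.close())) // best-effort close of the previous stream
    val result = underlying.next()
    lastStream = Some(result._2)
    result
  }
}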

Comment (Contributor, author):

@rxin Let me know if you'd like this change to be made to ShuffleBlockFetcherIterator.

Comment (Contributor):

We might consider deferring this change to a follow-up PR; want to file a JIRA issue so that we don't forget to eventually follow up?

Comment (Contributor, author):

Will do.

new BufferReleasingInputStream(inputStream, this)
}
}

(result.blockId, iteratorTry)
}
}

/**
* Helper class that ensures a ManagedBuffer is released upon InputStream.close()
* Note: the delegate parameter is private[storage] to make it available to tests.
*/
private class BufferReleasingInputStream(
private val delegate: InputStream,
private val iterator: ShuffleBlockFetcherIterator)
extends InputStream {
private[this] var closed = false

override def read(): Int = delegate.read()

override def close(): Unit = {
if (!closed) {
delegate.close()
iterator.releaseCurrentResultBuffer()
closed = true
}
}

override def available(): Int = delegate.available()

override def mark(readlimit: Int): Unit = delegate.mark(readlimit)

override def skip(n: Long): Long = delegate.skip(n)

override def markSupported(): Boolean = delegate.markSupported()

override def read(b: Array[Byte]): Int = delegate.read(b)

override def read(b: Array[Byte], off: Int, len: Int): Int = delegate.read(b, off, len)

override def reset(): Unit = delegate.reset()
}

private[storage]
object ShuffleBlockFetcherIterator {