
Commit 19135f2

[SPARK-7884] Allow Spark shuffle APIs to be more customizable
This commit updates the shuffle read path to give ShuffleReader implementations more control over the deserialization process. The BlockStoreShuffleFetcher.fetch() method has been renamed to BlockStoreShuffleFetcher.fetchBlockStreams(). Previously, this method returned a record iterator; it now returns an iterator of (BlockId, InputStream) pairs, leaving deserialization of records to ShuffleReader.read(). This change creates a cleaner separation of concerns and gives ShuffleReader implementations more flexibility in how records are deserialized.
1 parent 6396cc0 commit 19135f2
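
To make the new contract concrete, here is a minimal sketch (not part of this commit) of the deserialization step a reader can now perform itself, given the (BlockId, InputStream) pairs returned by fetchBlockStreams(). It mirrors what HashShuffleReader.read() does in this change; the helper name deserializeRecords and its parameter list are assumptions for illustration.

import java.io.InputStream

import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.storage.{BlockId, BlockManager}

// Hypothetical helper (not in this commit): given the stream-level iterator that
// fetchBlockStreams() now produces, the reader decides how each block is
// decompressed and deserialized.
def deserializeRecords(
    streams: Iterator[(BlockId, InputStream)],
    blockManager: BlockManager,
    serializerInstance: SerializerInstance): Iterator[(Any, Any)] = {
  streams.flatMap { case (blockId, rawStream) =>
    // Decompression is now the reader's responsibility, not the fetcher's.
    val decompressed = blockManager.wrapForCompression(blockId, rawStream)
    // Any deserialization strategy could be plugged in here instead.
    serializerInstance.deserializeStream(decompressed).asKeyValueIterator
  }
}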

4 files changed, +119 -71 lines changed

core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala

Lines changed: 13 additions & 22 deletions
@@ -17,23 +17,22 @@

 package org.apache.spark.shuffle.hash

-import scala.collection.mutable.ArrayBuffer
-import scala.collection.mutable.HashMap
+import java.io.InputStream
+
+import scala.collection.mutable.{ArrayBuffer, HashMap}
 import scala.util.{Failure, Success, Try}

 import org.apache.spark._
-import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId}
 import org.apache.spark.util.CompletionIterator

 private[hash] object BlockStoreShuffleFetcher extends Logging {
-  def fetch[T](
+  def fetchBlockStreams(
       shuffleId: Int,
       reduceId: Int,
-      context: TaskContext,
-      serializer: Serializer)
-    : Iterator[T] =
+      context: TaskContext)
+    : Iterator[(BlockId, InputStream)] =
   {
     logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
     val blockManager = SparkEnv.get.blockManager
@@ -53,12 +52,12 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
       (address, splits.map(s => (ShuffleBlockId(shuffleId, s._1, reduceId), s._2)))
     }

-    def unpackBlock(blockPair: (BlockId, Try[Iterator[Any]])) : Iterator[T] = {
+    def unpackBlock(blockPair: (BlockId, Try[InputStream])) : (BlockId, InputStream) = {
       val blockId = blockPair._1
       val blockOption = blockPair._2
       blockOption match {
-        case Success(block) => {
-          block.asInstanceOf[Iterator[T]]
+        case Success(inputStream) => {
+          (blockId, inputStream)
         }
         case Failure(e) => {
           blockId match {
@@ -78,21 +77,13 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
       SparkEnv.get.blockManager.shuffleClient,
       blockManager,
       blocksByAddress,
-      serializer,
       // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
       SparkEnv.get.conf.getSizeAsMb("spark.reducer.maxSizeInFlight", "48m") * 1024 * 1024)
-    val itr = blockFetcherItr.flatMap(unpackBlock)

-    val completionIter = CompletionIterator[T, Iterator[T]](itr, {
-      context.taskMetrics.updateShuffleReadMetrics()
-    })
+    val itr = blockFetcherItr.map(unpackBlock)

-    new InterruptibleIterator[T](context, completionIter) {
-      val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
-      override def next(): T = {
-        readMetrics.incRecordsRead(1)
-        delegate.next()
-      }
-    }
+    CompletionIterator[(BlockId, InputStream), Iterator[(BlockId, InputStream)]](itr, {
+      context.taskMetrics().updateShuffleReadMetrics()
+    })
   }
 }

core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala

Lines changed: 25 additions & 2 deletions
@@ -17,10 +17,10 @@

 package org.apache.spark.shuffle.hash

-import org.apache.spark.{InterruptibleIterator, TaskContext}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
 import org.apache.spark.util.collection.ExternalSorter
+import org.apache.spark.{InterruptibleIterator, SparkEnv, TaskContext}

 private[spark] class HashShuffleReader[K, C](
     handle: BaseShuffleHandle[K, _, C],
@@ -33,11 +33,34 @@ private[spark] class HashShuffleReader[K, C](
     "Hash shuffle currently only supports fetching one partition")

   private val dep = handle.dependency
+  private val blockManager = SparkEnv.get.blockManager

   /** Read the combined key-values for this reduce task */
   override def read(): Iterator[Product2[K, C]] = {
+    val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams(
+      handle.shuffleId, startPartition, context)
+
+    // Wrap the streams for compression based on configuration
+    val wrappedStreams = blockStreams.map { case (blockId, inputStream) =>
+      blockManager.wrapForCompression(blockId, inputStream)
+    }
+
     val ser = Serializer.getSerializer(dep.serializer)
-    val iter = BlockStoreShuffleFetcher.fetch(handle.shuffleId, startPartition, context, ser)
+    val serializerInstance = ser.newInstance()
+
+    // Create a key/value iterator for each stream
+    val recordIterator = wrappedStreams.flatMap { wrappedStream =>
+      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
+    }
+
+    // Update read metrics for each record materialized
+    val iter = new InterruptibleIterator[Any](context, recordIterator) {
+      val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
+      override def next(): Any = {
+        readMetrics.incRecordsRead(1)
+        delegate.next()
+      }
+    }.asInstanceOf[Iterator[Nothing]]

     val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
       if (dep.mapSideCombine) {

core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala

Lines changed: 57 additions & 32 deletions
@@ -17,23 +17,24 @@

 package org.apache.spark.storage

+import java.io.InputStream
 import java.util.concurrent.LinkedBlockingQueue

-import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
 import scala.util.{Failure, Try}

-import org.apache.spark.{Logging, TaskContext}
-import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
 import org.apache.spark.network.buffer.ManagedBuffer
-import org.apache.spark.serializer.{SerializerInstance, Serializer}
-import org.apache.spark.util.{CompletionIterator, Utils}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
+import org.apache.spark.util.Utils
+import org.apache.spark.{Logging, TaskContext}

 /**
  * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
  * manager. For remote blocks, it fetches them using the provided BlockTransferService.
  *
- * This creates an iterator of (BlockID, values) tuples so the caller can handle blocks in a
- * pipelined fashion as they are received.
+ * This creates an iterator of (BlockID, Try[InputStream]) tuples so the caller can handle blocks
+ * in a pipelined fashion as they are received.
  *
  * The implementation throttles the remote fetches to they don't exceed maxBytesInFlight to avoid
  * using too much memory.
@@ -44,7 +45,6 @@ import org.apache.spark.util.{CompletionIterator, Utils}
  * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
  *                        For each block we also require the size (in bytes as a long field) in
  *                        order to throttle the memory usage.
- * @param serializer serializer used to deserialize the data.
 * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
 */
 private[spark]
@@ -53,9 +53,8 @@ final class ShuffleBlockFetcherIterator(
     shuffleClient: ShuffleClient,
     blockManager: BlockManager,
     blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
-    serializer: Serializer,
     maxBytesInFlight: Long)
-  extends Iterator[(BlockId, Try[Iterator[Any]])] with Logging {
+  extends Iterator[(BlockId, Try[InputStream])] with Logging {

   import ShuffleBlockFetcherIterator._

@@ -79,11 +78,11 @@ final class ShuffleBlockFetcherIterator(
   private[this] val localBlocks = new ArrayBuffer[BlockId]()

   /** Remote blocks to fetch, excluding zero-sized blocks. */
-  private[this] val remoteBlocks = new HashSet[BlockId]()
+  private[this] val remoteBlocks = new mutable.HashSet[BlockId]()

   /**
    * A queue to hold our results. This turns the asynchronous model provided by
-   * [[BlockTransferService]] into a synchronous model (iterator).
+   * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator).
    */
   private[this] val results = new LinkedBlockingQueue[FetchResult]

@@ -97,14 +96,12 @@ final class ShuffleBlockFetcherIterator(
    * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
    * the number of bytes in flight is limited to maxBytesInFlight.
    */
-  private[this] val fetchRequests = new Queue[FetchRequest]
+  private[this] val fetchRequests = new mutable.Queue[FetchRequest]

   /** Current bytes in flight from our requests */
   private[this] var bytesInFlight = 0L

-  private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
-
-  private[this] val serializerInstance: SerializerInstance = serializer.newInstance()
+  private[this] val shuffleMetrics = context.taskMetrics().createShuffleReadMetricsForDependency()

   /**
    * Whether the iterator is still active. If isZombie is true, the callback interface will no
@@ -114,17 +111,23 @@ final class ShuffleBlockFetcherIterator(

   initialize()

-  /**
-   * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
-   */
-  private[this] def cleanup() {
-    isZombie = true
+  // Decrements the buffer reference count.
+  // The currentResult is set to null to prevent releasing the buffer again on cleanup()
+  private[storage] def releaseCurrentResultBuffer(): Unit = {
     // Release the current buffer if necessary
     currentResult match {
       case SuccessFetchResult(_, _, buf) => buf.release()
       case _ =>
     }
+    currentResult = null
+  }

+  /**
+   * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
+   */
+  private[this] def cleanup() {
+    isZombie = true
+    releaseCurrentResultBuffer()
     // Release buffers in the results queue
     val iter = results.iterator()
     while (iter.hasNext) {
@@ -272,7 +275,7 @@ final class ShuffleBlockFetcherIterator(

   override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch

-  override def next(): (BlockId, Try[Iterator[Any]]) = {
+  override def next(): (BlockId, Try[InputStream]) = {
     numBlocksProcessed += 1
     val startFetchWait = System.currentTimeMillis()
     currentResult = results.take()
@@ -290,29 +293,51 @@ final class ShuffleBlockFetcherIterator(
       sendRequest(fetchRequests.dequeue())
     }

-    val iteratorTry: Try[Iterator[Any]] = result match {
+    val iteratorTry: Try[InputStream] = result match {
       case FailureFetchResult(_, e) =>
         Failure(e)
       case SuccessFetchResult(blockId, _, buf) =>
         // There is a chance that createInputStream can fail (e.g. fetching a local file that does
         // not exist, SPARK-4085). In that case, we should propagate the right exception so
         // the scheduler gets a FetchFailedException.
-        Try(buf.createInputStream()).map { is0 =>
-          val is = blockManager.wrapForCompression(blockId, is0)
-          val iter = serializerInstance.deserializeStream(is).asKeyValueIterator
-          CompletionIterator[Any, Iterator[Any]](iter, {
-            // Once the iterator is exhausted, release the buffer and set currentResult to null
-            // so we don't release it again in cleanup.
-            currentResult = null
-            buf.release()
-          })
+        Try(buf.createInputStream()).map { inputStream =>
+          new WrappedInputStream(inputStream, this)
         }
     }

     (result.blockId, iteratorTry)
   }
 }

+// Helper class that ensures a ManagerBuffer is released upon InputStream.close()
+private class WrappedInputStream(delegate: InputStream, iterator: ShuffleBlockFetcherIterator)
+  extends InputStream {
+  private var closed = false
+
+  override def read(): Int = delegate.read()
+
+  override def close(): Unit = {
+    if (!closed) {
+      delegate.close()
+      iterator.releaseCurrentResultBuffer()
+      closed = true
+    }
+  }
+
+  override def available(): Int = delegate.available()
+
+  override def mark(readlimit: Int): Unit = delegate.mark(readlimit)
+
+  override def skip(n: Long): Long = delegate.skip(n)
+
+  override def markSupported(): Boolean = delegate.markSupported()
+
+  override def read(b: Array[Byte]): Int = delegate.read(b)
+
+  override def read(b: Array[Byte], off: Int, len: Int): Int = delegate.read(b, off, len)
+
+  override def reset(): Unit = delegate.reset()
+}

 private[storage]
 object ShuffleBlockFetcherIterator {
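
The WrappedInputStream helper above ties buffer lifetime to stream lifetime: the backing ManagedBuffer is released when the stream is closed (or when cleanup() runs for the task). A consumer-side sketch of that contract, with the names consumeBlocks, fetchIterator, and processBlock assumed purely for illustration:

import java.io.InputStream

import scala.util.Try

import org.apache.spark.storage.BlockId

// Sketch only: drain the fetch iterator while guaranteeing each stream is closed,
// so WrappedInputStream.close() runs and the backing buffer is released; the
// `closed` flag in WrappedInputStream keeps the release from happening twice.
def consumeBlocks(
    fetchIterator: Iterator[(BlockId, Try[InputStream])])(
    processBlock: (BlockId, InputStream) => Unit): Unit = {
  fetchIterator.foreach { case (blockId, streamTry) =>
    val stream = streamTry.get // surfaces the fetch failure, if any
    try {
      processBlock(blockId, stream)
    } finally {
      stream.close()
    }
  }
}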
