 package org.apache.spark.storage
 
+import java.io.InputStream
 import java.util.concurrent.LinkedBlockingQueue
 
-import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
 import scala.util.{Failure, Try}
 
-import org.apache.spark.{Logging, TaskContext}
-import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
 import org.apache.spark.network.buffer.ManagedBuffer
-import org.apache.spark.serializer.{SerializerInstance, Serializer}
-import org.apache.spark.util.{CompletionIterator, Utils}
+import org.apache.spark.network.shuffle.{BlockFetchingListener, ShuffleClient}
+import org.apache.spark.util.Utils
+import org.apache.spark.{Logging, TaskContext}
 
 /**
  * An iterator that fetches multiple blocks. For local blocks, it fetches from the local block
  * manager. For remote blocks, it fetches them using the provided BlockTransferService.
  *
- * This creates an iterator of (BlockID, values) tuples so the caller can handle blocks in a
- * pipelined fashion as they are received.
+ * This creates an iterator of (BlockID, Try[InputStream]) tuples so the caller can handle blocks
+ * in a pipelined fashion as they are received.
  *
  * The implementation throttles the remote fetches so they don't exceed maxBytesInFlight to avoid
  * using too much memory.
@@ -44,7 +45,6 @@ import org.apache.spark.util.{CompletionIterator, Utils}
  * @param blocksByAddress list of blocks to fetch grouped by the [[BlockManagerId]].
  *                        For each block we also require the size (in bytes as a long field) in
  *                        order to throttle the memory usage.
- * @param serializer serializer used to deserialize the data.
  * @param maxBytesInFlight max size (in bytes) of remote blocks to fetch at any given point.
  */
 private[spark]
@@ -53,9 +53,8 @@ final class ShuffleBlockFetcherIterator(
     shuffleClient: ShuffleClient,
     blockManager: BlockManager,
     blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
-    serializer: Serializer,
     maxBytesInFlight: Long)
-  extends Iterator[(BlockId, Try[Iterator[Any]])] with Logging {
+  extends Iterator[(BlockId, Try[InputStream])] with Logging {
 
   import ShuffleBlockFetcherIterator._
 
@@ -79,11 +78,11 @@ final class ShuffleBlockFetcherIterator(
   private[this] val localBlocks = new ArrayBuffer[BlockId]()
 
   /** Remote blocks to fetch, excluding zero-sized blocks. */
-  private[this] val remoteBlocks = new HashSet[BlockId]()
+  private[this] val remoteBlocks = new mutable.HashSet[BlockId]()
 
   /**
    * A queue to hold our results. This turns the asynchronous model provided by
-   * [[BlockTransferService]] into a synchronous model (iterator).
+   * [[org.apache.spark.network.BlockTransferService]] into a synchronous model (iterator).
    */
   private[this] val results = new LinkedBlockingQueue[FetchResult]
 
@@ -97,14 +96,12 @@ final class ShuffleBlockFetcherIterator(
    * Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
    * the number of bytes in flight is limited to maxBytesInFlight.
    */
-  private[this] val fetchRequests = new Queue[FetchRequest]
+  private[this] val fetchRequests = new mutable.Queue[FetchRequest]
 
   /** Current bytes in flight from our requests */
   private[this] var bytesInFlight = 0L
 
-  private[this] val shuffleMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
-
-  private[this] val serializerInstance: SerializerInstance = serializer.newInstance()
+  private[this] val shuffleMetrics = context.taskMetrics().createShuffleReadMetricsForDependency()
 
   /**
    * Whether the iterator is still active. If isZombie is true, the callback interface will no
@@ -114,17 +111,23 @@ final class ShuffleBlockFetcherIterator(
 
   initialize()
 
-  /**
-   * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
-   */
-  private[this] def cleanup() {
-    isZombie = true
+  // Decrements the buffer reference count.
+  // The currentResult is set to null to prevent releasing the buffer again on cleanup()
+  private[storage] def releaseCurrentResultBuffer(): Unit = {
     // Release the current buffer if necessary
     currentResult match {
       case SuccessFetchResult(_, _, buf) => buf.release()
       case _ =>
     }
+    currentResult = null
+  }
 
+  /**
+   * Mark the iterator as zombie, and release all buffers that haven't been deserialized yet.
+   */
+  private[this] def cleanup() {
+    isZombie = true
+    releaseCurrentResultBuffer()
     // Release buffers in the results queue
     val iter = results.iterator()
     while (iter.hasNext) {
@@ -272,7 +275,7 @@ final class ShuffleBlockFetcherIterator(
 
   override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch
 
-  override def next(): (BlockId, Try[Iterator[Any]]) = {
+  override def next(): (BlockId, Try[InputStream]) = {
     numBlocksProcessed += 1
     val startFetchWait = System.currentTimeMillis()
     currentResult = results.take()
@@ -290,29 +293,51 @@ final class ShuffleBlockFetcherIterator(
       sendRequest(fetchRequests.dequeue())
     }
 
-    val iteratorTry: Try[Iterator[Any]] = result match {
+    val iteratorTry: Try[InputStream] = result match {
       case FailureFetchResult(_, e) =>
         Failure(e)
       case SuccessFetchResult(blockId, _, buf) =>
         // There is a chance that createInputStream can fail (e.g. fetching a local file that does
         // not exist, SPARK-4085). In that case, we should propagate the right exception so
         // the scheduler gets a FetchFailedException.
-        Try(buf.createInputStream()).map { is0 =>
-          val is = blockManager.wrapForCompression(blockId, is0)
-          val iter = serializerInstance.deserializeStream(is).asKeyValueIterator
-          CompletionIterator[Any, Iterator[Any]](iter, {
-            // Once the iterator is exhausted, release the buffer and set currentResult to null
-            // so we don't release it again in cleanup.
-            currentResult = null
-            buf.release()
-          })
+        Try(buf.createInputStream()).map { inputStream =>
+          new WrappedInputStream(inputStream, this)
         }
     }
 
     (result.blockId, iteratorTry)
   }
 }
 
+// Helper class that ensures a ManagedBuffer is released upon InputStream.close()
+private class WrappedInputStream(delegate: InputStream, iterator: ShuffleBlockFetcherIterator)
+  extends InputStream {
+  private var closed = false
+
+  override def read(): Int = delegate.read()
+
+  override def close(): Unit = {
+    if (!closed) {
+      delegate.close()
+      iterator.releaseCurrentResultBuffer()
+      closed = true
+    }
+  }
+
+  override def available(): Int = delegate.available()
+
+  override def mark(readlimit: Int): Unit = delegate.mark(readlimit)
+
+  override def skip(n: Long): Long = delegate.skip(n)
+
+  override def markSupported(): Boolean = delegate.markSupported()
+
+  override def read(b: Array[Byte]): Int = delegate.read(b)
+
+  override def read(b: Array[Byte], off: Int, len: Int): Int = delegate.read(b, off, len)
+
+  override def reset(): Unit = delegate.reset()
+}
 
 private[storage]
 object ShuffleBlockFetcherIterator {
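For context, a minimal caller-side sketch of how the new (BlockId, Try[InputStream]) contract might be consumed. This is not part of the diff: the object and method names, the package placement, the assumption that the constructor takes the TaskContext before the parameters shown above, and the blockManager.shuffleClient accessor are all illustrative. The wrapForCompression and deserializeStream(...).asKeyValueIterator calls are taken from the code removed in this change; they now run on the caller side, and closing the returned stream releases the underlying ManagedBuffer via WrappedInputStream.close() and releaseCurrentResultBuffer().

package org.apache.spark.shuffle

import java.io.InputStream

import org.apache.spark.TaskContext
import org.apache.spark.serializer.SerializerInstance
import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator}

// Hypothetical consumer of the new iterator contract; not part of this diff.
private[spark] object ShuffleStreamReadExample {
  def readBlocks(
      context: TaskContext,
      blockManager: BlockManager,
      serializerInstance: SerializerInstance,
      blocksByAddress: Seq[(BlockManagerId, Seq[(BlockId, Long)])],
      maxBytesInFlight: Long): Iterator[(Any, Any)] = {
    // Assumes the constructor takes the TaskContext first, followed by the parameters
    // visible in this file (shuffleClient, blockManager, blocksByAddress, maxBytesInFlight).
    val blockFetcherItr = new ShuffleBlockFetcherIterator(
      context,
      blockManager.shuffleClient,
      blockManager,
      blocksByAddress,
      maxBytesInFlight)

    blockFetcherItr.flatMap { case (blockId, streamTry) =>
      // Try.get rethrows a failed fetch so the scheduler still sees the original exception.
      val rawStream: InputStream = streamTry.get
      // The compression wrapping and deserialization that the removed code did inside the
      // iterator now happen on the caller side.
      val wrappedStream = blockManager.wrapForCompression(blockId, rawStream)
      // Closing the stream (e.g. when the deserialization stream is closed after being fully
      // read) calls WrappedInputStream.close(), which releases the underlying ManagedBuffer.
      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
    }
  }
}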