@@ -19,6 +19,7 @@ package org.apache.spark.shuffle.hash
 
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
+import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.ExternalSorter
 import org.apache.spark.{InterruptibleIterator, SparkEnv, TaskContext}
 
@@ -38,7 +39,7 @@ private[spark] class HashShuffleReader[K, C](
   /** Read the combined key-values for this reduce task */
   override def read(): Iterator[Product2[K, C]] = {
     val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams(
-      handle.shuffleId, startPartition, context)
+      handle.shuffleId, startPartition, context)
 
     // Wrap the streams for compression based on configuration
     val wrappedStreams = blockStreams.map { case (blockId, inputStream) =>
@@ -50,7 +51,11 @@ private[spark] class HashShuffleReader[K, C](
 
     // Create a key/value iterator for each stream
     val recordIterator = wrappedStreams.flatMap { wrappedStream =>
-      serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
+      val kvIter = serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
+      CompletionIterator[(Any, Any), Iterator[(Any, Any)]](kvIter, {
+        // Close the stream once all the records have been read from it
+        wrappedStream.close()
+      })
     }
 
     // Update read metrics for each record materialized
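For context, a minimal sketch of the completion-callback pattern this diff relies on: org.apache.spark.util.CompletionIterator wraps an iterator and runs a cleanup block once the last element has been consumed, so the wrapped stream stays open only as long as records are still being read from it. The sketch below is a simplified, self-contained re-implementation (the real class is private[spark]); the helper name `completing` and the ByteArrayInputStream are illustrative only, not part of the change.

import java.io.ByteArrayInputStream

object CompletionIteratorSketch {
  // Wrap `sub` so that `onComplete` runs exactly once, right after the last
  // element has been handed out (mirrors the contract used in the diff above).
  def completing[A](sub: Iterator[A])(onComplete: => Unit): Iterator[A] = new Iterator[A] {
    private var completed = false
    override def hasNext: Boolean = {
      val more = sub.hasNext
      if (!more && !completed) {
        completed = true
        onComplete // e.g. close the underlying shuffle block stream
      }
      more
    }
    override def next(): A = sub.next()
  }

  def main(args: Array[String]): Unit = {
    // Stand-in for a wrapped shuffle stream (hypothetical data).
    val in = new ByteArrayInputStream("a b c".getBytes("UTF-8"))
    val records = Iterator("a" -> 1, "b" -> 2, "c" -> 3)
    // The stream is closed only after ("c", 3) has been consumed.
    val it = completing(records) { in.close() }
    it.foreach(println)
  }
}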