
Commit 4abb855

Consolidate metric code. Make it clear why InterruptibleIterator is needed.
There is also some Scala style cleanup in this commit.
1 parent 5c30405 commit 4abb855
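The commit message's second point deserves a gloss: an InterruptibleIterator polls the task's kill flag while records are being pulled, so a cancelled task stops consuming shuffle data between records instead of draining every block. A minimal sketch of that pattern, assuming a plain cancellation flag; Spark's actual class wraps a TaskContext and throws TaskKilledException, so treat the names below as illustrative:

// Minimal sketch of the InterruptibleIterator pattern, assuming a plain
// cancellation flag; Spark's real class takes a TaskContext and throws
// TaskKilledException instead of InterruptedException.
class InterruptibleIterator[T](isCancelled: () => Boolean, delegate: Iterator[T])
    extends Iterator[T] {
  override def hasNext: Boolean = {
    // Polling here means cancellation takes effect between records,
    // without the delegate iterator having to know about it.
    if (isCancelled()) throw new InterruptedException("task was killed")
    delegate.hasNext
  }
  override def next(): T = delegate.next()
}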

File tree: 2 files changed, +24 -32 lines

core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala (+14 -17)
@@ -50,42 +50,39 @@ private[spark] class HashShuffleReader[K, C](
     val serializerInstance = ser.newInstance()

     // Create a key/value iterator for each stream
-    val recordIterator = wrappedStreams.flatMap { wrappedStream =>
+    val recordIter = wrappedStreams.flatMap { wrappedStream =>
       serializerInstance.deserializeStream(wrappedStream).asKeyValueIterator
     }

+    // Update the context task metrics for each record read.
     val readMetrics = context.taskMetrics.createShuffleReadMetricsForDependency()
-    // Update read metrics for each record materialized
-    val metricIter = new InterruptibleIterator[(Any, Any)](context, recordIterator) {
-      override def next(): (Any, Any) = {
+    val metricIter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](
+      recordIter.map(record => {
         readMetrics.incRecordsRead(1)
-        delegate.next()
-      }
-    }
+        record
+      }),
+      context.taskMetrics().updateShuffleReadMetrics())

-    val iter = CompletionIterator[(Any, Any), Iterator[(Any, Any)]](metricIter, {
-      context.taskMetrics().updateShuffleReadMetrics()
-    })
+    // An interruptible iterator must be used here in order to support task cancellation
+    val interruptibleIter = new InterruptibleIterator[(Any, Any)](context, metricIter)

     val aggregatedIter: Iterator[Product2[K, C]] = if (dep.aggregator.isDefined) {
       if (dep.mapSideCombine) {
         // We are reading values that are already combined
-        val combinedKeyValuesIterator = iter.asInstanceOf[Iterator[(K, C)]]
-        new InterruptibleIterator(context,
-          dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context))
+        val combinedKeyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, C)]]
+        dep.aggregator.get.combineCombinersByKey(combinedKeyValuesIterator, context)
       } else {
         // We don't know the value type, but also don't care -- the dependency *should*
         // have made sure its compatible w/ this aggregator, which will convert the value
         // type to the combined type C
-        val keyValuesIterator = iter.asInstanceOf[Iterator[(K, Nothing)]]
-        new InterruptibleIterator(context,
-          dep.aggregator.get.combineValuesByKey(keyValuesIterator, context))
+        val keyValuesIterator = interruptibleIter.asInstanceOf[Iterator[(K, Nothing)]]
+        dep.aggregator.get.combineValuesByKey(keyValuesIterator, context)
       }
     } else {
       require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")

       // Convert the Product2s to pairs since this is what downstream RDDs currently expect
-      iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2))
+      interruptibleIter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2))
     }

     // Sort the output if there is a sort ordering defined.
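
The shape of the new code above: the per-record metric update rides along in a map, a CompletionIterator flushes the aggregate exactly once when the stream is exhausted, and a single InterruptibleIterator is layered on top instead of one per aggregation branch. A runnable sketch of that chain; this CompletionIterator is a simplified stand-in for Spark's internal org.apache.spark.util.CompletionIterator, and the plain counter stands in for ShuffleReadMetrics:

// Simplified stand-in for Spark's CompletionIterator: runs `completion`
// exactly once, after the delegate is fully consumed.
class CompletionIterator[A](delegate: Iterator[A], completion: => Unit)
    extends Iterator[A] {
  private[this] var completed = false
  override def hasNext: Boolean = {
    val more = delegate.hasNext
    if (!more && !completed) { completed = true; completion }
    more
  }
  override def next(): A = delegate.next()
}

object MetricIterDemo extends App {
  var recordsRead = 0L
  val records = Iterator("a" -> 1, "b" -> 2, "c" -> 3)

  // The metric update rides along with each record, replacing the old
  // InterruptibleIterator subclass that overrode next() just to count.
  val metricIter = new CompletionIterator[(String, Int)](
    records.map { record => recordsRead += 1; record },
    println(s"metrics flushed: $recordsRead records read"))

  metricIter.foreach(println) // prints the three pairs, then the flush message
}

Folding the count into map is what lets the old anonymous InterruptibleIterator subclass disappear, and wrapping in InterruptibleIterator once up front is why each aggregation branch can now return its combined iterator directly.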

core/src/test/scala/org/apache/spark/storage/ShuffleBlockFetcherIteratorSuite.scala (+10 -15)
@@ -28,10 +28,10 @@ import org.mockito.Mockito._
 import org.mockito.invocation.InvocationOnMock
 import org.mockito.stubbing.Answer

+import org.apache.spark.{SparkFunSuite, TaskContextImpl}
 import org.apache.spark.network._
 import org.apache.spark.network.buffer.ManagedBuffer
 import org.apache.spark.network.shuffle.BlockFetchingListener
-import org.apache.spark.{SparkFunSuite, TaskContextImpl}


 class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
@@ -61,11 +61,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
   // Create a mock managed buffer for testing
   def createMockManagedBuffer(): ManagedBuffer = {
     val mockManagedBuffer = mock(classOf[ManagedBuffer])
-    when(mockManagedBuffer.createInputStream()).thenAnswer(new Answer[InputStream] {
-      override def answer(invocation: InvocationOnMock): InputStream = {
-        mock(classOf[InputStream])
-      }
-    })
+    when(mockManagedBuffer.createInputStream()).thenReturn(mock(classOf[InputStream]))
     mockManagedBuffer
   }

@@ -76,19 +72,18 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {

     // Make sure blockManager.getBlockData would return the blocks
     val localBlocks = Map[BlockId, ManagedBuffer](
-      ShuffleBlockId(0, 0, 0) -> mock(classOf[ManagedBuffer]),
-      ShuffleBlockId(0, 1, 0) -> mock(classOf[ManagedBuffer]),
-      ShuffleBlockId(0, 2, 0) -> mock(classOf[ManagedBuffer]))
+      ShuffleBlockId(0, 0, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 1, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 2, 0) -> createMockManagedBuffer())
     localBlocks.foreach { case (blockId, buf) =>
       doReturn(buf).when(blockManager).getBlockData(meq(blockId))
     }

     // Make sure remote blocks would return
     val remoteBmId = BlockManagerId("test-client-1", "test-client-1", 2)
     val remoteBlocks = Map[BlockId, ManagedBuffer](
-      ShuffleBlockId(0, 3, 0) -> mock(classOf[ManagedBuffer]),
-      ShuffleBlockId(0, 4, 0) -> mock(classOf[ManagedBuffer])
-    )
+      ShuffleBlockId(0, 3, 0) -> createMockManagedBuffer(),
+      ShuffleBlockId(0, 4, 0) -> createMockManagedBuffer())

     val transfer = createMockTransfer(remoteBlocks)

@@ -109,13 +104,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {

     for (i <- 0 until 5) {
       assert(iterator.hasNext, s"iterator should have 5 elements but actually has $i elements")
-      val (blockId, subIterator) = iterator.next()
-      assert(subIterator.isSuccess,
+      val (blockId, inputStream) = iterator.next()
+      assert(inputStream.isSuccess,
         s"iterator should have 5 elements defined but actually has $i elements")

       // Make sure we release buffers when a wrapped input stream is closed.
       val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId))
-      val wrappedInputStream = new BufferReleasingInputStream(mock(classOf[InputStream]), iterator)
+      val wrappedInputStream = new BufferReleasingInputStream(inputStream.get, iterator)
       verify(mockBuf, times(0)).release()
       wrappedInputStream.close()
       verify(mockBuf, times(1)).release()