
Commit f98a1b9

Add test to ensure HashShuffleReader is freeing resources
1 parent a011bfa commit f98a1b9

3 files changed, 115 insertions(+), 10 deletions(-)

core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala

Lines changed: 2 additions & 1 deletion
@@ -26,7 +26,8 @@ import org.apache.spark._
 import org.apache.spark.shuffle.FetchFailedException
 import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId}
 
-private[hash] object BlockStoreShuffleFetcher extends Logging {
+private[hash] class BlockStoreShuffleFetcher extends Logging {
+
   def fetchBlockStreams(
       shuffleId: Int,
       reduceId: Int,

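Turning BlockStoreShuffleFetcher from a singleton object into a class is what makes it replaceable in tests: an instance can now be handed to HashShuffleReader (next file) and stubbed with Mockito instead of hitting the real block-transfer machinery. A minimal sketch of that stubbing, assuming it sits inside a test body in the org.apache.spark.shuffle.hash package; the empty stream and the shuffle block id below are illustrative values, not code from this commit:

import java.io.{ByteArrayInputStream, InputStream}

import org.mockito.Matchers.any
import org.mockito.Mockito._

import org.apache.spark.TaskContext
import org.apache.spark.storage.{BlockId, ShuffleBlockId}

// Sketch: stub the fetcher so a unit test never touches the network or disk.
val fetcher = mock(classOf[BlockStoreShuffleFetcher])
val emptyStream: InputStream = new ByteArrayInputStream(new Array[Byte](0))
when(fetcher.fetchBlockStreams(any[Int](), any[Int](), any[TaskContext]()))
  .thenReturn(Iterator.single((ShuffleBlockId(0, 0, 0): BlockId, emptyStream)))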
core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala

Lines changed: 5 additions & 3 deletions
@@ -17,6 +17,7 @@
 
 package org.apache.spark.shuffle.hash
 
+import org.apache.spark.storage.BlockManager
 import org.apache.spark.{InterruptibleIterator, SparkEnv, TaskContext}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
@@ -27,18 +28,19 @@ private[spark] class HashShuffleReader[K, C](
     handle: BaseShuffleHandle[K, _, C],
     startPartition: Int,
     endPartition: Int,
-    context: TaskContext)
+    context: TaskContext,
+    blockManager: BlockManager = SparkEnv.get.blockManager,
+    blockStoreShuffleFetcher: BlockStoreShuffleFetcher = new BlockStoreShuffleFetcher)
   extends ShuffleReader[K, C]
 {
   require(endPartition == startPartition + 1,
     "Hash shuffle currently only supports fetching one partition")
 
   private val dep = handle.dependency
-  private val blockManager = SparkEnv.get.blockManager
 
   /** Read the combined key-values for this reduce task */
   override def read(): Iterator[Product2[K, C]] = {
-    val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams(
+    val blockStreams = blockStoreShuffleFetcher.fetchBlockStreams(
       handle.shuffleId, startPartition, context)
 
     // Wrap the streams for compression based on configuration

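The two new constructor parameters default to the production collaborators (SparkEnv.get.blockManager and a fresh BlockStoreShuffleFetcher), so existing call sites keep working unchanged while a test can inject mocks for both. A self-contained sketch of the same default-argument injection pattern, with illustrative names that are not part of this commit:

// Illustrative only: the injection pattern the commit applies to HashShuffleReader,
// shrunk to a standalone example.
class Fetcher {
  def fetch(id: Int): String = s"real data for $id"   // stands in for the real block fetch
}

class Reader(id: Int, fetcher: Fetcher = new Fetcher) {
  def read(): String = fetcher.fetch(id)              // production callers never see the new parameter
}

object InjectionSketch {
  def main(args: Array[String]): Unit = {
    val production = new Reader(1)                    // default argument: uses the real Fetcher
    val stubbed = new Reader(1, new Fetcher {         // a test injects a replacement instead
      override def fetch(id: Int) = "stubbed"
    })
    println(production.read())                        // prints: real data for 1
    println(stubbed.read())                           // prints: stubbed
  }
}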
core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala

Lines changed: 108 additions & 6 deletions
@@ -17,16 +17,22 @@
 
 package org.apache.spark.shuffle.hash
 
-import java.io.{File, FileWriter}
+import java.io._
+import java.nio.ByteBuffer
 
 import scala.language.reflectiveCalls
 
-import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite}
-import org.apache.spark.executor.ShuffleWriteMetrics
+import org.mockito.Matchers.any
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+
+import org.apache.spark._
+import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics, ShuffleWriteMetrics}
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
-import org.apache.spark.serializer.JavaSerializer
-import org.apache.spark.shuffle.FileShuffleBlockResolver
-import org.apache.spark.storage.{ShuffleBlockId, FileSegment}
+import org.apache.spark.serializer._
+import org.apache.spark.shuffle.{BaseShuffleHandle, FileShuffleBlockResolver}
+import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockId, FileSegment}
 
 class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext {
   private val testConf = new SparkConf(false)
@@ -107,4 +113,100 @@ class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext {
     for (i <- 0 until numBytes) writer.write(i)
     writer.close()
   }
+
+  test("HashShuffleReader.read() releases resources and tracks metrics") {
+    val shuffleId = 1
+    val numMaps = 2
+    val numKeyValuePairs = 10
+
+    val mockContext = mock(classOf[TaskContext])
+
+    val mockTaskMetrics = mock(classOf[TaskMetrics])
+    val mockReadMetrics = mock(classOf[ShuffleReadMetrics])
+    when(mockTaskMetrics.createShuffleReadMetricsForDependency()).thenReturn(mockReadMetrics)
+    when(mockContext.taskMetrics()).thenReturn(mockTaskMetrics)
+
+    val mockShuffleFetcher = mock(classOf[BlockStoreShuffleFetcher])
+
+    val mockDep = mock(classOf[ShuffleDependency[_, _, _]])
+    when(mockDep.keyOrdering).thenReturn(None)
+    when(mockDep.aggregator).thenReturn(None)
+    when(mockDep.serializer).thenReturn(Some(new Serializer {
+      override def newInstance(): SerializerInstance = new SerializerInstance {
+
+        override def deserializeStream(s: InputStream): DeserializationStream =
+          new DeserializationStream {
+            override def readObject[T: ClassManifest](): T = null.asInstanceOf[T]
+
+            override def close(): Unit = s.close()
+
+            private val values = {
+              for (i <- 0 to numKeyValuePairs * 2) yield i
+            }.iterator
+
+            private def getValueOrEOF(): Int = {
+              if (values.hasNext) {
+                values.next()
+              } else {
+                throw new EOFException("End of the file: mock deserializeStream")
+              }
+            }
+
+            // NOTE: the readKey and readValue methods are called by asKeyValueIterator()
+            // which is wrapped in a NextIterator
+            override def readKey[T: ClassManifest](): T = getValueOrEOF().asInstanceOf[T]
+
+            override def readValue[T: ClassManifest](): T = getValueOrEOF().asInstanceOf[T]
+          }
+
+        override def deserialize[T: ClassManifest](bytes: ByteBuffer, loader: ClassLoader): T =
+          null.asInstanceOf[T]
+
+        override def serialize[T: ClassManifest](t: T): ByteBuffer = ByteBuffer.allocate(0)
+
+        override def serializeStream(s: OutputStream): SerializationStream =
+          null.asInstanceOf[SerializationStream]
+
+        override def deserialize[T: ClassManifest](bytes: ByteBuffer): T = null.asInstanceOf[T]
+      }
+    }))
+
+    val mockBlockManager = {
+      // Create a block manager that isn't configured for compression, just returns input stream
+      val blockManager = mock(classOf[BlockManager])
+      when(blockManager.wrapForCompression(any[BlockId](), any[InputStream]()))
+        .thenAnswer(new Answer[InputStream] {
+          override def answer(invocation: InvocationOnMock): InputStream = {
+            val blockId = invocation.getArguments()(0).asInstanceOf[BlockId]
+            val inputStream = invocation.getArguments()(1).asInstanceOf[InputStream]
+            inputStream
+          }
+        })
+      blockManager
+    }
+
+    val mockInputStream = mock(classOf[InputStream])
+    when(mockShuffleFetcher.fetchBlockStreams(any[Int](), any[Int](), any[TaskContext]()))
+      .thenReturn(Iterator.single((mock(classOf[BlockId]), mockInputStream)))
+
+    val shuffleHandle = new BaseShuffleHandle(shuffleId, numMaps, mockDep)
+
+    val reader = new HashShuffleReader(shuffleHandle, 0, 1,
+      mockContext, mockBlockManager, mockShuffleFetcher)
+
+    val values = reader.read()
+    // Verify that we're reading the correct values
+    var numValuesRead = 0
+    for (((key: Int, value: Int), i) <- values.zipWithIndex) {
+      assert(key == i * 2)
+      assert(value == i * 2 + 1)
+      numValuesRead += 1
+    }
+    // Verify that we read the correct number of values
+    assert(numKeyValuePairs == numValuesRead)
+    // Verify that our input stream was closed
+    verify(mockInputStream, times(1)).close()
+    // Verify that we collected metrics for each key/value pair
+    verify(mockReadMetrics, times(numKeyValuePairs)).incRecordsRead(1)
+  }
 }
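The assertions at the end are the point of the commit: the mock deserialization stream throws EOFException after the 10 key/value pairs, and the test then checks that the reader closed the underlying InputStream exactly once and incremented the read metrics once per record. A simplified, self-contained illustration of that close-on-exhaustion contract, which is what the test pins down rather than Spark's actual iterator implementation:

import java.io.{ByteArrayInputStream, InputStream}

// Illustration only: once the iterator over a block is exhausted, the stream
// backing it should be closed so the resource is released promptly.
object CloseOnExhaustionSketch {
  def recordsFrom(in: InputStream): Iterator[Int] = new Iterator[Int] {
    private var nextByte = in.read()
    def hasNext: Boolean = nextByte != -1
    def next(): Int = {
      val current = nextByte
      nextByte = in.read()
      if (nextByte == -1) in.close()   // release the stream as soon as EOF is reached
      current
    }
  }

  def main(args: Array[String]): Unit = {
    val in = new ByteArrayInputStream(Array[Byte](1, 2, 3))
    println(recordsFrom(in).toList)    // List(1, 2, 3); the stream is closed afterwards
  }
}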
