
Commit d0a1b39

Merge pull request #1 from kayousterhout/massie_shuffle-api-cleanup
Proposal for different unit test
2 parents f98a1b9 + 290f1eb commit d0a1b39

File tree: 4 files changed (+170, -121 lines)


core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala

Lines changed: 9 additions & 8 deletions
@@ -20,25 +20,26 @@ package org.apache.spark.shuffle.hash
 import java.io.InputStream
 
 import scala.collection.mutable.{ArrayBuffer, HashMap}
-import scala.util.{Failure, Success, Try}
+import scala.util.{Failure, Success}
 
 import org.apache.spark._
 import org.apache.spark.shuffle.FetchFailedException
-import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId}
-
-private[hash] class BlockStoreShuffleFetcher extends Logging {
+import org.apache.spark.storage.{BlockId, BlockManager, BlockManagerId, ShuffleBlockFetcherIterator,
+  ShuffleBlockId}
 
+private[hash] object BlockStoreShuffleFetcher extends Logging {
   def fetchBlockStreams(
       shuffleId: Int,
       reduceId: Int,
-      context: TaskContext)
+      context: TaskContext,
+      blockManager: BlockManager,
+      mapOutputTracker: MapOutputTracker)
     : Iterator[(BlockId, InputStream)] =
   {
     logDebug("Fetching outputs for shuffle %d, reduce %d".format(shuffleId, reduceId))
-    val blockManager = SparkEnv.get.blockManager
 
     val startTime = System.currentTimeMillis
-    val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, reduceId)
+    val statuses = mapOutputTracker.getServerStatuses(shuffleId, reduceId)
     logDebug("Fetching map output location for shuffle %d, reduce %d took %d ms".format(
       shuffleId, reduceId, System.currentTimeMillis - startTime))
 
@@ -54,7 +55,7 @@ private[hash] class BlockStoreShuffleFetcher extends Logging {
 
     val blockFetcherItr = new ShuffleBlockFetcherIterator(
       context,
-      SparkEnv.get.blockManager.shuffleClient,
+      blockManager.shuffleClient,
       blockManager,
       blocksByAddress,
       // Note: we use getSizeAsMb when no suffix is provided for backwards compatibility
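
The refactor above turns BlockStoreShuffleFetcher into an object and threads the BlockManager and MapOutputTracker through as parameters instead of reading them from SparkEnv.get, so the caller decides which instances are used. A minimal sketch of that dependency-injection pattern, with hypothetical Tracker/Fetcher names rather than Spark's real API:

    // Illustrative sketch only (hypothetical names, not Spark's classes): collaborators arrive
    // as parameters, so a test can pass a stub instead of relying on a process-wide environment.
    trait Tracker {
      def serverStatuses(shuffleId: Int, reduceId: Int): Seq[(String, Long)]
    }

    object Fetcher {
      // Before the refactor this would have read the tracker from a global (SparkEnv.get-style) lookup.
      def fetchStatuses(shuffleId: Int, reduceId: Int, tracker: Tracker): Seq[(String, Long)] =
        tracker.serverStatuses(shuffleId, reduceId)
    }

    object FetcherExample extends App {
      // A test double: no SparkEnv, no cluster, just canned output.
      val stubTracker = new Tracker {
        def serverStatuses(shuffleId: Int, reduceId: Int) = Seq(("host-a", 16L), ("host-b", 32L))
      }
      assert(Fetcher.fetchStatuses(shuffleId = 1, reduceId = 0, stubTracker).size == 2)
    }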

core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala

Lines changed: 5 additions & 5 deletions
@@ -17,10 +17,10 @@
 
 package org.apache.spark.shuffle.hash
 
-import org.apache.spark.storage.BlockManager
-import org.apache.spark.{InterruptibleIterator, SparkEnv, TaskContext}
+import org.apache.spark.{InterruptibleIterator, MapOutputTracker, SparkEnv, TaskContext}
 import org.apache.spark.serializer.Serializer
 import org.apache.spark.shuffle.{BaseShuffleHandle, ShuffleReader}
+import org.apache.spark.storage.BlockManager
 import org.apache.spark.util.CompletionIterator
 import org.apache.spark.util.collection.ExternalSorter
 
@@ -30,7 +30,7 @@ private[spark] class HashShuffleReader[K, C](
     endPartition: Int,
     context: TaskContext,
     blockManager: BlockManager = SparkEnv.get.blockManager,
-    blockStoreShuffleFetcher: BlockStoreShuffleFetcher = new BlockStoreShuffleFetcher)
+    mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker)
   extends ShuffleReader[K, C]
 {
   require(endPartition == startPartition + 1,
@@ -40,8 +40,8 @@
 
   /** Read the combined key-values for this reduce task */
   override def read(): Iterator[Product2[K, C]] = {
-    val blockStreams = blockStoreShuffleFetcher.fetchBlockStreams(
-      handle.shuffleId, startPartition, context)
+    val blockStreams = BlockStoreShuffleFetcher.fetchBlockStreams(
+      handle.shuffleId, startPartition, context, blockManager, mapOutputTracker)
 
     // Wrap the streams for compression based on configuration
     val wrappedStreams = blockStreams.map { case (blockId, inputStream) =>
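
Here the injected collaborators are given SparkEnv-backed default arguments, so production call sites of HashShuffleReader stay unchanged while a test can pass mocks explicitly, as the new suite below does. A small sketch of that default-argument seam, again with hypothetical names rather than Spark's API:

    // Illustrative sketch only (hypothetical names): default constructor arguments keep
    // production call sites unchanged while giving tests an injection point.
    trait Tracker {
      def serverStatuses(shuffleId: Int, reduceId: Int): Seq[(String, Long)]
    }

    object DefaultEnv {
      // Stands in for SparkEnv.get: the collaborator used when nothing is injected.
      val tracker: Tracker = new Tracker {
        def serverStatuses(shuffleId: Int, reduceId: Int) = Seq(("default-host", 0L))
      }
    }

    class Reader(shuffleId: Int, reduceId: Int, tracker: Tracker = DefaultEnv.tracker) {
      def read(): Seq[(String, Long)] = tracker.serverStatuses(shuffleId, reduceId)
    }

    object ReaderExample extends App {
      val production = new Reader(1, 0)            // uses DefaultEnv.tracker
      val inTest = new Reader(1, 0, new Tracker {   // a test injects a stub instead
        def serverStatuses(shuffleId: Int, reduceId: Int) = Seq(("stub-host", 8L))
      })
      assert(production.read().head._1 == "default-host")
      assert(inTest.read().head._1 == "stub-host")
    }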

core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleManagerSuite.scala

Lines changed: 6 additions & 108 deletions
@@ -17,22 +17,16 @@
 
 package org.apache.spark.shuffle.hash
 
-import java.io._
-import java.nio.ByteBuffer
+import java.io.{File, FileWriter}
 
 import scala.language.reflectiveCalls
 
-import org.mockito.Matchers.any
-import org.mockito.Mockito._
-import org.mockito.invocation.InvocationOnMock
-import org.mockito.stubbing.Answer
-
-import org.apache.spark._
-import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics, ShuffleWriteMetrics}
+import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite}
+import org.apache.spark.executor.ShuffleWriteMetrics
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
-import org.apache.spark.serializer._
-import org.apache.spark.shuffle.{BaseShuffleHandle, FileShuffleBlockResolver}
-import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockId, FileSegment}
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.shuffle.FileShuffleBlockResolver
+import org.apache.spark.storage.{ShuffleBlockId, FileSegment}
 
 class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext {
   private val testConf = new SparkConf(false)
@@ -113,100 +107,4 @@ class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext {
     for (i <- 0 until numBytes) writer.write(i)
     writer.close()
   }
-
-  test("HashShuffleReader.read() releases resources and tracks metrics") {
-    val shuffleId = 1
-    val numMaps = 2
-    val numKeyValuePairs = 10
-
-    val mockContext = mock(classOf[TaskContext])
-
-    val mockTaskMetrics = mock(classOf[TaskMetrics])
-    val mockReadMetrics = mock(classOf[ShuffleReadMetrics])
-    when(mockTaskMetrics.createShuffleReadMetricsForDependency()).thenReturn(mockReadMetrics)
-    when(mockContext.taskMetrics()).thenReturn(mockTaskMetrics)
-
-    val mockShuffleFetcher = mock(classOf[BlockStoreShuffleFetcher])
-
-    val mockDep = mock(classOf[ShuffleDependency[_, _, _]])
-    when(mockDep.keyOrdering).thenReturn(None)
-    when(mockDep.aggregator).thenReturn(None)
-    when(mockDep.serializer).thenReturn(Some(new Serializer {
-      override def newInstance(): SerializerInstance = new SerializerInstance {
-
-        override def deserializeStream(s: InputStream): DeserializationStream =
-          new DeserializationStream {
-            override def readObject[T: ClassManifest](): T = null.asInstanceOf[T]
-
-            override def close(): Unit = s.close()
-
-            private val values = {
-              for (i <- 0 to numKeyValuePairs * 2) yield i
-            }.iterator
-
-            private def getValueOrEOF(): Int = {
-              if (values.hasNext) {
-                values.next()
-              } else {
-                throw new EOFException("End of the file: mock deserializeStream")
-              }
-            }
-
-            // NOTE: the readKey and readValue methods are called by asKeyValueIterator()
-            // which is wrapped in a NextIterator
-            override def readKey[T: ClassManifest](): T = getValueOrEOF().asInstanceOf[T]
-
-            override def readValue[T: ClassManifest](): T = getValueOrEOF().asInstanceOf[T]
-          }
-
-        override def deserialize[T: ClassManifest](bytes: ByteBuffer, loader: ClassLoader): T =
-          null.asInstanceOf[T]
-
-        override def serialize[T: ClassManifest](t: T): ByteBuffer = ByteBuffer.allocate(0)
-
-        override def serializeStream(s: OutputStream): SerializationStream =
-          null.asInstanceOf[SerializationStream]
-
-        override def deserialize[T: ClassManifest](bytes: ByteBuffer): T = null.asInstanceOf[T]
-      }
-    }))
-
-    val mockBlockManager = {
-      // Create a block manager that isn't configured for compression, just returns input stream
-      val blockManager = mock(classOf[BlockManager])
-      when(blockManager.wrapForCompression(any[BlockId](), any[InputStream]()))
-        .thenAnswer(new Answer[InputStream] {
-          override def answer(invocation: InvocationOnMock): InputStream = {
-            val blockId = invocation.getArguments()(0).asInstanceOf[BlockId]
-            val inputStream = invocation.getArguments()(1).asInstanceOf[InputStream]
-            inputStream
-          }
-        })
-      blockManager
-    }
-
-    val mockInputStream = mock(classOf[InputStream])
-    when(mockShuffleFetcher.fetchBlockStreams(any[Int](), any[Int](), any[TaskContext]()))
-      .thenReturn(Iterator.single((mock(classOf[BlockId]), mockInputStream)))
-
-    val shuffleHandle = new BaseShuffleHandle(shuffleId, numMaps, mockDep)
-
-    val reader = new HashShuffleReader(shuffleHandle, 0, 1,
-      mockContext, mockBlockManager, mockShuffleFetcher)
-
-    val values = reader.read()
-    // Verify that we're reading the correct values
-    var numValuesRead = 0
-    for (((key: Int, value: Int), i) <- values.zipWithIndex) {
-      assert(key == i * 2)
-      assert(value == i * 2 + 1)
-      numValuesRead += 1
-    }
-    // Verify that we read the correct number of values
-    assert(numKeyValuePairs == numValuesRead)
-    // Verify that our input stream was closed
-    verify(mockInputStream, times(1)).close()
-    // Verify that we collected metrics for each key/value pair
-    verify(mockReadMetrics, times(numKeyValuePairs)).incRecordsRead(1)
-  }
 }

core/src/test/scala/org/apache/spark/shuffle/hash/HashShuffleReaderSuite.scala

Lines changed: 150 additions & 0 deletions
@@ -0,0 +1,150 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.hash
+
+import java.io.{ByteArrayOutputStream, InputStream}
+import java.nio.ByteBuffer
+
+import org.mockito.Matchers.{eq => meq, _}
+import org.mockito.Mockito.{mock, when}
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+
+import org.apache.spark._
+import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
+import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.shuffle.BaseShuffleHandle
+import org.apache.spark.storage.{BlockManager, BlockManagerId, ShuffleBlockId}
+
+/**
+ * Wrapper for a managed buffer that keeps track of how many times retain and release are called.
+ *
+ * We need to define this class ourselves instead of using a spy because the NioManagedBuffer class
+ * is final (final classes cannot be spied on).
+ */
+class RecordingManagedBuffer(underlyingBuffer: NioManagedBuffer) extends ManagedBuffer {
+  var callsToRetain = 0
+  var callsToRelease = 0
+
+  override def size() = underlyingBuffer.size()
+  override def nioByteBuffer() = underlyingBuffer.nioByteBuffer()
+  override def createInputStream() = underlyingBuffer.createInputStream()
+  override def convertToNetty() = underlyingBuffer.convertToNetty()
+
+  override def retain(): ManagedBuffer = {
+    callsToRetain += 1
+    underlyingBuffer.retain()
+  }
+  override def release(): ManagedBuffer = {
+    callsToRelease += 1
+    underlyingBuffer.release()
+  }
+}
+
+class HashShuffleReaderSuite extends SparkFunSuite with LocalSparkContext {
+
+  /**
+   * This test makes sure that, when data is read from a HashShuffleReader, the underlying
+   * ManagedBuffers that contain the data are eventually released.
+   */
+  test("read() releases resources on completion") {
+    val testConf = new SparkConf(false)
+    // Create a SparkContext as a convenient way of setting SparkEnv (needed because some of the
+    // shuffle code calls SparkEnv.get()).
+    sc = new SparkContext("local", "test", testConf)
+
+    val reduceId = 15
+    val shuffleId = 22
+    val numMaps = 6
+    val keyValuePairsPerMap = 10
+    val serializer = new JavaSerializer(testConf)
+
+    // Make a mock BlockManager that will return RecordingManagedByteBuffers of data, so that we
+    // can ensure retain() and release() are properly called.
+    val blockManager = mock(classOf[BlockManager])
+
+    // Create a return function to use for the mocked wrapForCompression method that just returns
+    // the original input stream.
+    val dummyCompressionFunction = new Answer[InputStream] {
+      override def answer(invocation: InvocationOnMock) =
+        invocation.getArguments()(1).asInstanceOf[InputStream]
+    }
+
+    // Create a buffer with some randomly generated key-value pairs to use as the shuffle data
+    // from each mappers (all mappers return the same shuffle data).
+    val byteOutputStream = new ByteArrayOutputStream()
+    val serializationStream = serializer.newInstance().serializeStream(byteOutputStream)
+    (0 until keyValuePairsPerMap).foreach { i =>
+      serializationStream.writeKey(i)
+      serializationStream.writeValue(2*i)
+    }
+
+    // Setup the mocked BlockManager to return RecordingManagedBuffers.
+    val localBlockManagerId = BlockManagerId("test-client", "test-client", 1)
+    when(blockManager.blockManagerId).thenReturn(localBlockManagerId)
+    val buffers = (0 until numMaps).map { mapId =>
+      // Create a ManagedBuffer with the shuffle data.
+      val nioBuffer = new NioManagedBuffer(ByteBuffer.wrap(byteOutputStream.toByteArray))
+      val managedBuffer = new RecordingManagedBuffer(nioBuffer)
+
+      // Setup the blockManager mock so the buffer gets returned when the shuffle code tries to
+      // fetch shuffle data.
+      val shuffleBlockId = ShuffleBlockId(shuffleId, mapId, reduceId)
+      when(blockManager.getBlockData(shuffleBlockId)).thenReturn(managedBuffer)
+      when(blockManager.wrapForCompression(meq(shuffleBlockId), isA(classOf[InputStream])))
+        .thenAnswer(dummyCompressionFunction)
+
+      managedBuffer
+    }
+
+    // Make a mocked MapOutputTracker for the shuffle reader to use to determine what
+    // shuffle data to read.
+    val mapOutputTracker = mock(classOf[MapOutputTracker])
+    // Test a scenario where all data is local, just to avoid creating a bunch of additional mocks
+    // for the code to read data over the network.
+    val statuses: Array[(BlockManagerId, Long)] =
+      Array.fill(numMaps)((localBlockManagerId, byteOutputStream.size()))
+    when(mapOutputTracker.getServerStatuses(shuffleId, reduceId)).thenReturn(statuses)
+
+    // Create a mocked shuffle handle to pass into HashShuffleReader.
+    val shuffleHandle = {
+      val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]])
+      when(dependency.serializer).thenReturn(Some(serializer))
+      when(dependency.aggregator).thenReturn(None)
+      when(dependency.keyOrdering).thenReturn(None)
+      new BaseShuffleHandle(shuffleId, numMaps, dependency)
+    }
+
+    val shuffleReader = new HashShuffleReader(
+      shuffleHandle,
+      reduceId,
+      reduceId + 1,
+      new TaskContextImpl(0, 0, 0, 0, null),
+      blockManager,
+      mapOutputTracker)
+
+    assert(shuffleReader.read().length === keyValuePairsPerMap * numMaps)
+
+    // Calling .length above will have exhausted the iterator; make sure that exhausting the
+    // iterator caused retain and release to be called on each buffer.
+    buffers.foreach { buffer =>
+      assert(buffer.callsToRetain === 1)
+      assert(buffer.callsToRelease === 1)
+    }
+  }
+}
