@@ -17,16 +17,22 @@
 
 package org.apache.spark.shuffle.hash
 
-import java.io.{File, FileWriter}
+import java.io._
+import java.nio.ByteBuffer
 
 import scala.language.reflectiveCalls
 
-import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkEnv, SparkFunSuite}
-import org.apache.spark.executor.ShuffleWriteMetrics
+import org.mockito.Matchers.any
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+
+import org.apache.spark._
+import org.apache.spark.executor.{ShuffleReadMetrics, TaskMetrics, ShuffleWriteMetrics}
 import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
-import org.apache.spark.serializer.JavaSerializer
-import org.apache.spark.shuffle.FileShuffleBlockResolver
-import org.apache.spark.storage.{ShuffleBlockId, FileSegment}
+import org.apache.spark.serializer._
+import org.apache.spark.shuffle.{BaseShuffleHandle, FileShuffleBlockResolver}
+import org.apache.spark.storage.{BlockId, BlockManager, ShuffleBlockId, FileSegment}
 
 class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext {
   private val testConf = new SparkConf(false)
@@ -107,4 +113,100 @@ class HashShuffleManagerSuite extends SparkFunSuite with LocalSparkContext {
     for (i <- 0 until numBytes) writer.write(i)
     writer.close()
   }
+
+  test("HashShuffleReader.read() releases resources and tracks metrics") {
+    val shuffleId = 1
+    val numMaps = 2
+    val numKeyValuePairs = 10
+
+    val mockContext = mock(classOf[TaskContext])
+
+    val mockTaskMetrics = mock(classOf[TaskMetrics])
+    val mockReadMetrics = mock(classOf[ShuffleReadMetrics])
+    when(mockTaskMetrics.createShuffleReadMetricsForDependency()).thenReturn(mockReadMetrics)
+    when(mockContext.taskMetrics()).thenReturn(mockTaskMetrics)
+
+    val mockShuffleFetcher = mock(classOf[BlockStoreShuffleFetcher])
+
+    val mockDep = mock(classOf[ShuffleDependency[_, _, _]])
+    when(mockDep.keyOrdering).thenReturn(None)
+    when(mockDep.aggregator).thenReturn(None)
+    when(mockDep.serializer).thenReturn(Some(new Serializer {
+      override def newInstance(): SerializerInstance = new SerializerInstance {
+
+        override def deserializeStream(s: InputStream): DeserializationStream =
+          new DeserializationStream {
+            override def readObject[T: ClassManifest](): T = null.asInstanceOf[T]
+
+            override def close(): Unit = s.close()
+
+            private val values = {
+              for (i <- 0 to numKeyValuePairs * 2) yield i
+            }.iterator
+
+            private def getValueOrEOF(): Int = {
+              if (values.hasNext) {
+                values.next()
+              } else {
+                throw new EOFException("End of the file: mock deserializeStream")
+              }
+            }
+
+            // NOTE: the readKey and readValue methods are called by asKeyValueIterator()
+            // which is wrapped in a NextIterator
+            override def readKey[T: ClassManifest](): T = getValueOrEOF().asInstanceOf[T]
+
+            override def readValue[T: ClassManifest](): T = getValueOrEOF().asInstanceOf[T]
+          }
+
+        override def deserialize[T: ClassManifest](bytes: ByteBuffer, loader: ClassLoader): T =
+          null.asInstanceOf[T]
+
+        override def serialize[T: ClassManifest](t: T): ByteBuffer = ByteBuffer.allocate(0)
+
+        override def serializeStream(s: OutputStream): SerializationStream =
+          null.asInstanceOf[SerializationStream]
+
+        override def deserialize[T: ClassManifest](bytes: ByteBuffer): T = null.asInstanceOf[T]
+      }
+    }))
+
+    val mockBlockManager = {
+      // Create a block manager that isn't configured for compression, just returns input stream
+      val blockManager = mock(classOf[BlockManager])
+      when(blockManager.wrapForCompression(any[BlockId](), any[InputStream]()))
+        .thenAnswer(new Answer[InputStream] {
+          override def answer(invocation: InvocationOnMock): InputStream = {
+            val blockId = invocation.getArguments()(0).asInstanceOf[BlockId]
+            val inputStream = invocation.getArguments()(1).asInstanceOf[InputStream]
+            inputStream
+          }
+        })
+      blockManager
+    }
+
+    val mockInputStream = mock(classOf[InputStream])
+    when(mockShuffleFetcher.fetchBlockStreams(any[Int](), any[Int](), any[TaskContext]()))
+      .thenReturn(Iterator.single((mock(classOf[BlockId]), mockInputStream)))
+
+    val shuffleHandle = new BaseShuffleHandle(shuffleId, numMaps, mockDep)
+
+    val reader = new HashShuffleReader(shuffleHandle, 0, 1,
+      mockContext, mockBlockManager, mockShuffleFetcher)
+
+    val values = reader.read()
+    // Verify that we're reading the correct values
+    var numValuesRead = 0
+    for (((key: Int, value: Int), i) <- values.zipWithIndex) {
+      assert(key == i * 2)
+      assert(value == i * 2 + 1)
+      numValuesRead += 1
+    }
+    // Verify that we read the correct number of values
+    assert(numKeyValuePairs == numValuesRead)
+    // Verify that our input stream was closed
+    verify(mockInputStream, times(1)).close()
+    // Verify that we collected metrics for each key/value pair
+    verify(mockReadMetrics, times(numKeyValuePairs)).incRecordsRead(1)
+  }
 }
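The subtlest part of the mocked serializer above is how iteration terminates: as the in-code NOTE says, `readKey`/`readValue` are driven by `asKeyValueIterator()`, which is wrapped in a `NextIterator` that treats `EOFException` as end-of-stream. Below is a minimal, self-contained sketch of that pattern, not Spark's actual `NextIterator`; the names `read`, `fetch`, and `KeyValuePairIteratorSketch` are illustrative only.

```scala
import java.io.EOFException

// Sketch of the read-until-EOF pattern the mocked DeserializationStream
// relies on: keep deserializing (key, value) pairs until a read throws
// EOFException, then end the iteration cleanly.
object KeyValuePairIteratorSketch {
  def main(args: Array[String]): Unit = {
    val numKeyValuePairs = 10
    // Same shape as the test's value source: 0 to 2*n inclusive (21 values),
    // so the 11th key read succeeds but its paired value read hits EOF.
    val values = (0 to numKeyValuePairs * 2).iterator
    def read(): Int =
      if (values.hasNext) values.next()
      else throw new EOFException("end of mock stream")

    // Pre-fetch one pair at a time; EOF on either read finishes the iterator.
    val pairs = new Iterator[(Int, Int)] {
      private def fetch(): Option[(Int, Int)] =
        try Some((read(), read())) catch { case _: EOFException => None }
      private var nextPair = fetch()
      override def hasNext: Boolean = nextPair.isDefined
      override def next(): (Int, Int) = {
        val pair = nextPair.get
        nextPair = fetch()
        pair
      }
    }

    var count = 0
    for (((key, value), i) <- pairs.zipWithIndex) {
      assert(key == i * 2 && value == i * 2 + 1)
      count += 1
    }
    assert(count == numKeyValuePairs) // exactly 10 pairs, as the test expects
  }
}
```

This is also why the test's value source is `0 to numKeyValuePairs * 2` inclusive rather than `until`: the one extra value exercises the case where EOF arrives mid-pair, and the partial pair is discarded rather than emitted.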