@@ -28,10 +28,10 @@ import org.mockito.Mockito._
28
28
import org .mockito .invocation .InvocationOnMock
29
29
import org .mockito .stubbing .Answer
30
30
31
+ import org .apache .spark .{SparkFunSuite , TaskContextImpl }
31
32
import org .apache .spark .network ._
32
33
import org .apache .spark .network .buffer .ManagedBuffer
33
34
import org .apache .spark .network .shuffle .BlockFetchingListener
34
- import org .apache .spark .{SparkFunSuite , TaskContextImpl }
35
35
36
36
37
37
class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
@@ -61,11 +61,7 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
61
61
// Create a mock managed buffer for testing
62
62
def createMockManagedBuffer (): ManagedBuffer = {
63
63
val mockManagedBuffer = mock(classOf [ManagedBuffer ])
64
- when(mockManagedBuffer.createInputStream()).thenAnswer(new Answer [InputStream ] {
65
- override def answer (invocation : InvocationOnMock ): InputStream = {
66
- mock(classOf [InputStream ])
67
- }
68
- })
64
+ when(mockManagedBuffer.createInputStream()).thenReturn(mock(classOf [InputStream ]))
69
65
mockManagedBuffer
70
66
}
71
67
@@ -76,19 +72,18 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
76
72
77
73
// Make sure blockManager.getBlockData would return the blocks
78
74
val localBlocks = Map [BlockId , ManagedBuffer ](
79
- ShuffleBlockId (0 , 0 , 0 ) -> mock( classOf [ ManagedBuffer ] ),
80
- ShuffleBlockId (0 , 1 , 0 ) -> mock( classOf [ ManagedBuffer ] ),
81
- ShuffleBlockId (0 , 2 , 0 ) -> mock( classOf [ ManagedBuffer ] ))
75
+ ShuffleBlockId (0 , 0 , 0 ) -> createMockManagedBuffer( ),
76
+ ShuffleBlockId (0 , 1 , 0 ) -> createMockManagedBuffer( ),
77
+ ShuffleBlockId (0 , 2 , 0 ) -> createMockManagedBuffer( ))
82
78
localBlocks.foreach { case (blockId, buf) =>
83
79
doReturn(buf).when(blockManager).getBlockData(meq(blockId))
84
80
}
85
81
86
82
// Make sure remote blocks would return
87
83
val remoteBmId = BlockManagerId (" test-client-1" , " test-client-1" , 2 )
88
84
val remoteBlocks = Map [BlockId , ManagedBuffer ](
89
- ShuffleBlockId (0 , 3 , 0 ) -> mock(classOf [ManagedBuffer ]),
90
- ShuffleBlockId (0 , 4 , 0 ) -> mock(classOf [ManagedBuffer ])
91
- )
85
+ ShuffleBlockId (0 , 3 , 0 ) -> createMockManagedBuffer(),
86
+ ShuffleBlockId (0 , 4 , 0 ) -> createMockManagedBuffer())
92
87
93
88
val transfer = createMockTransfer(remoteBlocks)
94
89
@@ -109,13 +104,13 @@ class ShuffleBlockFetcherIteratorSuite extends SparkFunSuite {
109
104
110
105
for (i <- 0 until 5 ) {
111
106
assert(iterator.hasNext, s " iterator should have 5 elements but actually has $i elements " )
112
- val (blockId, subIterator ) = iterator.next()
113
- assert(subIterator .isSuccess,
107
+ val (blockId, inputStream ) = iterator.next()
108
+ assert(inputStream .isSuccess,
114
109
s " iterator should have 5 elements defined but actually has $i elements " )
115
110
116
111
// Make sure we release buffers when a wrapped input stream is closed.
117
112
val mockBuf = localBlocks.getOrElse(blockId, remoteBlocks(blockId))
118
- val wrappedInputStream = new BufferReleasingInputStream (mock( classOf [ InputStream ]) , iterator)
113
+ val wrappedInputStream = new BufferReleasingInputStream (inputStream.get , iterator)
119
114
verify(mockBuf, times(0 )).release()
120
115
wrappedInputStream.close()
121
116
verify(mockBuf, times(1 )).release()
0 commit comments