@@ -24,8 +24,10 @@ import org.mockito.Matchers.any
24
24
25
25
import java .nio .ByteBuffer
26
26
27
+ import scala .collection .mutable .ArrayBuffer
27
28
import scala .concurrent .future
28
29
import scala .concurrent .ExecutionContext .Implicits .global
30
+
29
31
import org .apache .spark ._
30
32
import org .apache .spark .storage .BlockFetcherIterator ._
31
33
import org .apache .spark .network .{ConnectionManager , ConnectionManagerId ,
@@ -34,30 +36,29 @@ import org.apache.spark.network.{ConnectionManager, ConnectionManagerId,
34
36
class BlockFetcherIteratorSuite extends FunSuite with Matchers {
35
37
36
38
test(" block fetch from remote fails using BasicBlockFetcherIterator" ) {
37
- val conf = new SparkConf
38
39
val blockManager = mock(classOf [BlockManager ])
39
40
val connManager = mock(classOf [ConnectionManager ])
40
- val message = Message .createBufferMessage(0 )
41
- message.hasError = true
42
- val someMessage = Some (message)
41
+ when(blockManager.connectionManager).thenReturn(connManager)
43
42
44
43
val f = future {
44
+ val message = Message .createBufferMessage(0 )
45
+ message.hasError = true
46
+ val someMessage = Some (message)
45
47
someMessage
46
48
}
47
- when(blockManager.connectionManager).thenReturn(connManager)
48
49
when(connManager.sendMessageReliably(any(),
49
50
any())).thenReturn(f)
50
51
when(blockManager.futureExecContext).thenReturn(global)
52
+
51
53
when(blockManager.blockManagerId).thenReturn(
52
54
BlockManagerId (" test-client" , " test-client" , 1 , 0 ))
53
55
when(blockManager.maxBytesInFlight).thenReturn(48 * 1024 * 1024 )
54
56
55
- val dummyBlId1 = ShuffleBlockId (0 ,0 ,0 )
56
- val dummyBlId2 = ShuffleBlockId (0 ,1 ,0 )
57
+ val blId1 = ShuffleBlockId (0 ,0 ,0 )
58
+ val blId2 = ShuffleBlockId (0 ,1 ,0 )
57
59
val bmId = BlockManagerId (" test-server" , " test-server" ,1 , 0 )
58
60
val blocksByAddress = Seq [(BlockManagerId , Seq [(BlockId , Long )])](
59
- (bmId, Seq ((dummyBlId1, 1 ))),
60
- (bmId, Seq ((dummyBlId2, 1 )))
61
+ (bmId, Seq ((blId1, 1 ), (blId2, 1 )))
61
62
)
62
63
63
64
val iterator = new BasicBlockFetcherIterator (blockManager,
@@ -71,4 +72,58 @@ class BlockFetcherIteratorSuite extends FunSuite with Matchers {
71
72
}
72
73
}
73
74
75
+ test(" block fetch from remote succeed using BasicBlockFetcherIterator" ) {
76
+ val blockManager = mock(classOf [BlockManager ])
77
+ val connManager = mock(classOf [ConnectionManager ])
78
+ when(blockManager.connectionManager).thenReturn(connManager)
79
+
80
+ val blId1 = ShuffleBlockId (0 ,0 ,0 )
81
+ val blId2 = ShuffleBlockId (0 ,1 ,0 )
82
+ val buf1 = ByteBuffer .allocate(4 )
83
+ val buf2 = ByteBuffer .allocate(4 )
84
+ buf1.putInt(1 )
85
+ buf1.flip()
86
+ buf2.putInt(1 )
87
+ buf2.flip()
88
+ val blockMessage1 = BlockMessage .fromGotBlock(GotBlock (blId1, buf1))
89
+ val blockMessage2 = BlockMessage .fromGotBlock(GotBlock (blId2, buf2))
90
+ val blockMessageArray = new BlockMessageArray (
91
+ Seq (blockMessage1, blockMessage2))
92
+
93
+ val bufferMessage = blockMessageArray.toBufferMessage
94
+ val buffer = ByteBuffer .allocate(bufferMessage.size)
95
+ val arrayBuffer = new ArrayBuffer [ByteBuffer ]
96
+ bufferMessage.buffers.foreach{ b =>
97
+ buffer.put(b)
98
+ }
99
+ buffer.flip
100
+ arrayBuffer += buffer
101
+
102
+ val someMessage = Some (Message .createBufferMessage(arrayBuffer))
103
+
104
+ val f = future {
105
+ someMessage
106
+ }
107
+ when(connManager.sendMessageReliably(any(),
108
+ any())).thenReturn(f)
109
+ when(blockManager.futureExecContext).thenReturn(global)
110
+
111
+ when(blockManager.blockManagerId).thenReturn(
112
+ BlockManagerId (" test-client" , " test-client" , 1 , 0 ))
113
+ when(blockManager.maxBytesInFlight).thenReturn(48 * 1024 * 1024 )
114
+
115
+ val bmId = BlockManagerId (" test-server" , " test-server" ,1 , 0 )
116
+ val blocksByAddress = Seq [(BlockManagerId , Seq [(BlockId , Long )])](
117
+ (bmId, Seq ((blId1, 1 ), (blId2, 1 )))
118
+ )
119
+
120
+ val iterator = new BasicBlockFetcherIterator (blockManager,
121
+ blocksByAddress, null )
122
+ iterator.initialize()
123
+ iterator.foreach{
124
+ case (_, r) => {
125
+ (r.isDefined) should be(true )
126
+ }
127
+ }
128
+ }
74
129
}
0 commit comments