20
20
import static io .netty .handler .codec .http2 .Http2TestUtil .runInChannel ;
21
21
import static io .netty .util .CharsetUtil .UTF_8 ;
22
22
import static java .util .concurrent .TimeUnit .SECONDS ;
23
+ import static org .junit .Assert .assertArrayEquals ;
23
24
import static org .junit .Assert .assertEquals ;
24
25
import static org .junit .Assert .assertTrue ;
25
26
import static org .mockito .Matchers .any ;
26
27
import static org .mockito .Matchers .anyInt ;
27
28
import static org .mockito .Matchers .eq ;
29
+ import static org .mockito .Mockito .doAnswer ;
28
30
import static org .mockito .Mockito .times ;
29
31
import static org .mockito .Mockito .verify ;
30
32
import io .netty .bootstrap .Bootstrap ;
41
43
import io .netty .channel .socket .nio .NioServerSocketChannel ;
42
44
import io .netty .channel .socket .nio .NioSocketChannel ;
43
45
import io .netty .handler .codec .http2 .Http2TestUtil .Http2Runnable ;
44
- import io .netty .util .CharsetUtil ;
45
46
import io .netty .util .NetUtil ;
46
47
import io .netty .util .concurrent .Future ;
47
48
49
+ import java .io .ByteArrayOutputStream ;
48
50
import java .net .InetSocketAddress ;
51
+ import java .util .ArrayList ;
49
52
import java .util .List ;
50
53
import java .util .Random ;
51
54
import java .util .concurrent .CountDownLatch ;
54
57
import org .junit .After ;
55
58
import org .junit .Before ;
56
59
import org .junit .Test ;
57
- import org .mockito .ArgumentCaptor ;
58
60
import org .mockito .Mock ;
61
+ import org .mockito .Mockito ;
59
62
import org .mockito .MockitoAnnotations ;
63
+ import org .mockito .invocation .InvocationOnMock ;
64
+ import org .mockito .stubbing .Answer ;
60
65
61
66
/**
62
67
* Tests the full HTTP/2 framing stack including the connection and preface handlers.
@@ -144,7 +149,16 @@ public void flowControlProperlyChunksLargeMessage() throws Exception {
144
149
final byte [] bytes = new byte [length ];
145
150
new Random ().nextBytes (bytes );
146
151
final ByteBuf data = Unpooled .wrappedBuffer (bytes );
147
- List <ByteBuf > capturedData = null ;
152
+ final ByteArrayOutputStream out = new ByteArrayOutputStream (length );
153
+ doAnswer (new Answer <Void >() {
154
+ @ Override
155
+ public Void answer (InvocationOnMock in ) throws Throwable {
156
+ ByteBuf buf = (ByteBuf ) in .getArguments ()[2 ];
157
+ buf .readBytes (out , buf .readableBytes ());
158
+ return null ;
159
+ }
160
+ }).when (serverListener ).onDataRead (any (ChannelHandlerContext .class ), eq (3 ),
161
+ any (ByteBuf .class ), eq (0 ), Mockito .anyBoolean ());
148
162
try {
149
163
// Initialize the data latch based on the number of bytes expected.
150
164
requestLatch (new CountDownLatch (2 ));
@@ -163,18 +177,18 @@ public void run() {
163
177
assertTrue (dataLatch .await (5 , TimeUnit .SECONDS ));
164
178
165
179
// Verify that headers were received and only one DATA frame was received with endStream set.
166
- final ArgumentCaptor <ByteBuf > dataCaptor = ArgumentCaptor .forClass (ByteBuf .class );
167
180
verify (serverListener ).onHeadersRead (any (ChannelHandlerContext .class ), eq (3 ), eq (headers ), eq (0 ),
168
181
eq ((short ) 16 ), eq (false ), eq (0 ), eq (false ));
169
- verify (serverListener ).onDataRead (any (ChannelHandlerContext .class ), eq (3 ), dataCaptor . capture ( ), eq (0 ),
182
+ verify (serverListener ).onDataRead (any (ChannelHandlerContext .class ), eq (3 ), any ( ByteBuf . class ), eq (0 ),
170
183
eq (true ));
171
184
172
185
// Verify we received all the bytes.
173
- capturedData = dataCaptor .getAllValues ();
174
- assertEquals (data , capturedData .get (0 ));
186
+ out .flush ();
187
+ byte [] received = out .toByteArray ();
188
+ assertArrayEquals (bytes , received );
175
189
} finally {
176
190
data .release ();
177
- release ( capturedData );
191
+ out . close ( );
178
192
}
179
193
}
180
194
@@ -183,48 +197,56 @@ public void stressTest() throws Exception {
183
197
final Http2Headers headers = dummyHeaders ();
184
198
final String text = "hello world" ;
185
199
final String pingMsg = "12345678" ;
186
- final ByteBuf data = Unpooled .copiedBuffer (text .getBytes ());
187
- final ByteBuf pingData = Unpooled .copiedBuffer (pingMsg .getBytes ());
188
- List <ByteBuf > capturedData = null ;
189
- List <ByteBuf > capturedPingData = null ;
200
+ final ByteBuf data = Unpooled .copiedBuffer (text , UTF_8 );
201
+ final ByteBuf pingData = Unpooled .copiedBuffer (pingMsg , UTF_8 );
202
+ final List <String > receivedPingBuffers = new ArrayList <String >(NUM_STREAMS );
203
+ doAnswer (new Answer <Void >() {
204
+ @ Override
205
+ public Void answer (InvocationOnMock in ) throws Throwable {
206
+ receivedPingBuffers .add (((ByteBuf ) in .getArguments ()[1 ]).toString (UTF_8 ));
207
+ return null ;
208
+ }
209
+ }).when (serverListener ).onPingRead (any (ChannelHandlerContext .class ), eq (pingData ));
210
+ final List <String > receivedDataBuffers = new ArrayList <String >();
211
+ doAnswer (new Answer <Void >() {
212
+ @ Override
213
+ public Void answer (InvocationOnMock in ) throws Throwable {
214
+ receivedDataBuffers .add (((ByteBuf ) in .getArguments ()[2 ]).toString (UTF_8 ));
215
+ return null ;
216
+ }
217
+ }).when (serverListener ).onDataRead (any (ChannelHandlerContext .class ), anyInt (), eq (data ),
218
+ eq (0 ), eq (true ));
190
219
try {
191
220
runInChannel (clientChannel , new Http2Runnable () {
192
221
@ Override
193
222
public void run () {
194
223
for (int i = 0 , nextStream = 3 ; i < NUM_STREAMS ; ++i , nextStream += 2 ) {
195
224
http2Client .writeHeaders (ctx (), nextStream , headers , 0 , (short ) 16 , false , 0 , false ,
196
225
newPromise ());
197
- http2Client .writePing (ctx (), pingData .retain (), newPromise ());
198
- http2Client .writeData (ctx (), nextStream , data .retain (), 0 , true , newPromise ());
226
+ http2Client .writePing (ctx (), pingData .slice (). retain (), newPromise ());
227
+ http2Client .writeData (ctx (), nextStream , data .slice (). retain (), 0 , true , newPromise ());
199
228
}
200
229
}
201
230
});
202
231
// Wait for all frames to be received.
203
232
assertTrue (requestLatch .await (STRESS_TIMEOUT_SECONDS , SECONDS ));
204
233
verify (serverListener , times (NUM_STREAMS )).onHeadersRead (any (ChannelHandlerContext .class ), anyInt (),
205
234
eq (headers ), eq (0 ), eq ((short ) 16 ), eq (false ), eq (0 ), eq (false ));
206
- final ArgumentCaptor <ByteBuf > dataCaptor = ArgumentCaptor .forClass (ByteBuf .class );
207
- final ArgumentCaptor <ByteBuf > pingDataCaptor = ArgumentCaptor .forClass (ByteBuf .class );
208
235
verify (serverListener , times (NUM_STREAMS )).onPingRead (any (ChannelHandlerContext .class ),
209
- pingDataCaptor .capture ());
210
- capturedPingData = pingDataCaptor .getAllValues ();
211
- verify (serverListener , times (NUM_STREAMS )).onDataRead (any (ChannelHandlerContext .class ), anyInt (),
212
- dataCaptor .capture (), eq (0 ), eq (true ));
213
- capturedData = dataCaptor .getAllValues ();
214
- data .resetReaderIndex ();
215
- pingData .resetReaderIndex ();
216
- int i ;
217
- for (i = 0 ; i < capturedPingData .size (); ++i ) {
218
- assertEquals (pingData , capturedPingData .get (i ));
236
+ any (ByteBuf .class ));
237
+ verify (serverListener , times (NUM_STREAMS )).onDataRead (any (ChannelHandlerContext .class ),
238
+ anyInt (), any (ByteBuf .class ), eq (0 ), eq (true ));
239
+ assertEquals (NUM_STREAMS , receivedPingBuffers .size ());
240
+ assertEquals (NUM_STREAMS , receivedDataBuffers .size ());
241
+ for (String receivedData : receivedDataBuffers ) {
242
+ assertEquals (text , receivedData );
219
243
}
220
- for (i = 0 ; i < capturedData . size (); ++ i ) {
221
- assertEquals (capturedData . get ( i ). toString ( CharsetUtil . UTF_8 ), data , capturedData . get ( i ) );
244
+ for (String receivedPing : receivedPingBuffers ) {
245
+ assertEquals (pingMsg , receivedPing );
222
246
}
223
247
} finally {
224
248
data .release ();
225
249
pingData .release ();
226
- release (capturedData );
227
- release (capturedPingData );
228
250
}
229
251
}
230
252
@@ -250,14 +272,6 @@ private ChannelPromise newPromise() {
250
272
return ctx ().newPromise ();
251
273
}
252
274
253
- private static void release (List <ByteBuf > capturedData ) {
254
- if (capturedData != null ) {
255
- for (int i = 0 ; i < capturedData .size (); ++i ) {
256
- capturedData .get (i ).release ();
257
- }
258
- }
259
- }
260
-
261
275
private Http2Headers dummyHeaders () {
262
276
return new DefaultHttp2Headers ().method (as ("GET" )).scheme (as ("https" ))
263
277
.authority (as ("example.org" )).path (as ("/some/path/resource2" )).add (randomString (), randomString ());
0 commit comments