@@ -2,17 +2,21 @@ use std::error::Error as StdError;
2
2
#[ cfg( feature = "runtime" ) ]
3
3
use std:: time:: Duration ;
4
4
5
+ use bytes:: { Bytes } ;
5
6
use futures_channel:: { mpsc, oneshot} ;
6
7
use futures_util:: future:: { self , Either , FutureExt as _, TryFutureExt as _} ;
7
8
use futures_util:: stream:: StreamExt as _;
8
9
use h2:: client:: { Builder , SendRequest } ;
10
+ use http:: Method ;
9
11
use tokio:: io:: { AsyncRead , AsyncWrite } ;
10
12
11
- use super :: { decode_content_length , ping , PipeToSendStream , SendBuf } ;
13
+ use super :: { ping , H2Upgraded , PipeToSendStream , SendBuf } ;
12
14
use crate :: body:: HttpBody ;
13
15
use crate :: common:: { exec:: Exec , task, Future , Never , Pin , Poll } ;
14
16
use crate :: headers;
15
17
use crate :: proto:: Dispatched ;
18
+ use crate :: proto:: h2:: UpgradedSendStream ;
19
+ use crate :: upgrade:: Upgraded ;
16
20
use crate :: { Body , Request , Response } ;
17
21
18
22
type ClientRx < B > = crate :: client:: dispatch:: Receiver < Request < B > , Response < Body > > ;
@@ -233,8 +237,20 @@ where
233
237
headers:: set_content_length_if_missing ( req. headers_mut ( ) , len) ;
234
238
}
235
239
}
240
+
241
+ let is_connect = req. method ( ) == Method :: CONNECT ;
236
242
let eos = body. is_end_stream ( ) ;
237
- let ( fut, body_tx) = match self . h2_tx . send_request ( req, eos) {
243
+ let ping = self . ping . clone ( ) ;
244
+
245
+ if is_connect {
246
+ if headers:: content_length_parse_all ( req. headers ( ) ) . map_or ( false , |len| len != 0 ) {
247
+ warn ! ( "h2 connect request with non-zero body not supported" ) ;
248
+ cb. send ( Err ( ( crate :: Error :: new_h2 ( h2:: Reason :: INTERNAL_ERROR . into ( ) ) , None ) ) ) ;
249
+ continue ;
250
+ }
251
+ }
252
+
253
+ let ( fut, body_tx) = match self . h2_tx . send_request ( req, !is_connect && eos) {
238
254
Ok ( ok) => ok,
239
255
Err ( err) => {
240
256
debug ! ( "client send request error: {}" , err) ;
@@ -243,45 +259,76 @@ where
243
259
}
244
260
} ;
245
261
246
- let ping = self . ping . clone ( ) ;
247
- if !eos {
248
- let mut pipe = Box :: pin ( PipeToSendStream :: new ( body, body_tx) ) . map ( |res| {
249
- if let Err ( e) = res {
250
- debug ! ( "client request body error: {}" , e) ;
251
- }
252
- } ) ;
253
-
254
- // eagerly see if the body pipe is ready and
255
- // can thus skip allocating in the executor
256
- match Pin :: new ( & mut pipe) . poll ( cx) {
257
- Poll :: Ready ( _) => ( ) ,
258
- Poll :: Pending => {
259
- let conn_drop_ref = self . conn_drop_ref . clone ( ) ;
260
- // keep the ping recorder's knowledge of an
261
- // "open stream" alive while this body is
262
- // still sending...
263
- let ping = ping. clone ( ) ;
264
- let pipe = pipe. map ( move |x| {
265
- drop ( conn_drop_ref) ;
266
- drop ( ping) ;
267
- x
262
+ let send_stream = if !is_connect {
263
+ if !eos {
264
+ let mut pipe =
265
+ Box :: pin ( PipeToSendStream :: new ( body, body_tx) ) . map ( |res| {
266
+ if let Err ( e) = res {
267
+ debug ! ( "client request body error: {}" , e) ;
268
+ }
268
269
} ) ;
269
- self . executor . execute ( pipe) ;
270
+
271
+ // eagerly see if the body pipe is ready and
272
+ // can thus skip allocating in the executor
273
+ match Pin :: new ( & mut pipe) . poll ( cx) {
274
+ Poll :: Ready ( _) => ( ) ,
275
+ Poll :: Pending => {
276
+ let conn_drop_ref = self . conn_drop_ref . clone ( ) ;
277
+ // keep the ping recorder's knowledge of an
278
+ // "open stream" alive while this body is
279
+ // still sending...
280
+ let ping = ping. clone ( ) ;
281
+ let pipe = pipe. map ( move |x| {
282
+ drop ( conn_drop_ref) ;
283
+ drop ( ping) ;
284
+ x
285
+ } ) ;
286
+ self . executor . execute ( pipe) ;
287
+ }
270
288
}
271
289
}
272
- }
290
+
291
+ None
292
+ } else {
293
+ Some ( body_tx)
294
+ } ;
273
295
274
296
let fut = fut. map ( move |result| match result {
275
297
Ok ( res) => {
276
298
// record that we got the response headers
277
299
ping. record_non_data ( ) ;
278
300
279
- let content_length = decode_content_length ( res. headers ( ) ) ;
280
- let res = res. map ( |stream| {
281
- let ping = ping. for_stream ( & stream) ;
282
- crate :: Body :: h2 ( stream, content_length, ping)
283
- } ) ;
284
- Ok ( res)
301
+ let content_length = headers:: content_length_parse_all ( res. headers ( ) ) ;
302
+ if let Some ( mut send_stream) = send_stream {
303
+ if content_length. map_or ( false , |len| len != 0 ) {
304
+ warn ! ( "h2 connect response with non-zero body not supported" ) ;
305
+
306
+ send_stream. send_reset ( h2:: Reason :: INTERNAL_ERROR ) ;
307
+ return Err ( ( crate :: Error :: new_h2 ( h2:: Reason :: INTERNAL_ERROR . into ( ) ) , None ) ) ;
308
+ }
309
+ let ( parts, recv_stream) = res. into_parts ( ) ;
310
+ let mut res = Response :: from_parts ( parts, Body :: empty ( ) ) ;
311
+
312
+ let ( pending, on_upgrade) = crate :: upgrade:: pending ( ) ;
313
+ let io = H2Upgraded {
314
+ ping,
315
+ send_stream : unsafe { UpgradedSendStream :: new ( send_stream) } ,
316
+ recv_stream,
317
+ buf : Bytes :: new ( ) ,
318
+ } ;
319
+ let upgraded = Upgraded :: new ( io, Bytes :: new ( ) ) ;
320
+
321
+ pending. fulfill ( upgraded) ;
322
+ res. extensions_mut ( ) . insert ( on_upgrade) ;
323
+
324
+ Ok ( res)
325
+ } else {
326
+ let res = res. map ( |stream| {
327
+ let ping = ping. for_stream ( & stream) ;
328
+ crate :: Body :: h2 ( stream, content_length. into ( ) , ping)
329
+ } ) ;
330
+ Ok ( res)
331
+ }
285
332
}
286
333
Err ( err) => {
287
334
ping. ensure_not_timed_out ( ) . map_err ( |e| ( e, None ) ) ?;
0 commit comments