Skip to content

Commit a68864d

Browse files
committed
feat(h2): implement CONNECT support (fixes #2508)
1 parent bff1dde commit a68864d

File tree

7 files changed

+544
-35
lines changed

7 files changed

+544
-35
lines changed

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ http = "0.2"
3131
http-body = "0.4"
3232
httpdate = "1.0"
3333
httparse = "1.4"
34-
h2 = { version = "0.3", optional = true }
34+
h2 = { version = "0.3.3", optional = true }
3535
itoa = "0.4.1"
3636
tracing = { version = "0.1", default-features = false, features = ["std"] }
3737
pin-project = "1.0"

src/body/length.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,17 @@ use std::fmt;
33
#[derive(Clone, Copy, PartialEq, Eq)]
44
pub(crate) struct DecodedLength(u64);
55

6+
#[cfg(any(feature = "http1", feature = "http2"))]
7+
impl From<Option<u64>> for DecodedLength {
8+
fn from(len: Option<u64>) -> Self {
9+
len.and_then(|len| {
10+
// If the length is u64::MAX, oh well, just reported chunked.
11+
Self::checked_new(len).ok()
12+
})
13+
.unwrap_or(DecodedLength::CHUNKED)
14+
}
15+
}
16+
617
#[cfg(any(feature = "http1", feature = "http2", test))]
718
const MAX_LEN: u64 = std::u64::MAX - 2;
819

src/error.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ pub(super) enum User {
9090
/// User tried to send a certain header in an unexpected context.
9191
///
9292
/// For example, sending both `content-length` and `transfer-encoding`.
93-
#[cfg(feature = "http1")]
93+
#[cfg(any(feature = "http1", feature = "http2"))]
9494
#[cfg(feature = "server")]
9595
UnexpectedHeader,
9696
/// User tried to create a Request with bad version.
@@ -279,7 +279,7 @@ impl Error {
279279
Error::new(Kind::User(user))
280280
}
281281

282-
#[cfg(feature = "http1")]
282+
#[cfg(any(feature = "http1", feature = "http2"))]
283283
#[cfg(feature = "server")]
284284
pub(super) fn new_user_header() -> Error {
285285
Error::new_user(User::UnexpectedHeader)
@@ -394,7 +394,7 @@ impl Error {
394394
Kind::User(User::MakeService) => "error from user's MakeService",
395395
#[cfg(any(feature = "http1", feature = "http2"))]
396396
Kind::User(User::Service) => "error from user's Service",
397-
#[cfg(feature = "http1")]
397+
#[cfg(any(feature = "http1", feature = "http2"))]
398398
#[cfg(feature = "server")]
399399
Kind::User(User::UnexpectedHeader) => "user sent unexpected header",
400400
#[cfg(any(feature = "http1", feature = "http2"))]

src/proto/h2/mod.rs

Lines changed: 135 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
1-
use bytes::Buf;
2-
use h2::SendStream;
1+
use bytes::{Buf, Bytes};
2+
use h2::{RecvStream, SendStream};
33
use http::header::{
44
HeaderName, CONNECTION, PROXY_AUTHENTICATE, PROXY_AUTHORIZATION, TE, TRAILER,
55
TRANSFER_ENCODING, UPGRADE,
66
};
77
use http::HeaderMap;
88
use pin_project::pin_project;
9+
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
910
use std::error::Error as StdError;
10-
use std::io::IoSlice;
11+
use std::io::{self, Cursor, IoSlice};
12+
use std::task::Context;
1113

1214
use crate::body::{DecodedLength, HttpBody};
1315
use crate::common::{task, Future, Pin, Poll};
1416
use crate::headers::content_length_parse_all;
17+
use crate::proto::h2::ping::Recorder;
1518

1619
pub(crate) mod ping;
1720

@@ -84,12 +87,7 @@ fn strip_connection_headers(headers: &mut HeaderMap, is_request: bool) {
8487
}
8588

8689
fn decode_content_length(headers: &HeaderMap) -> DecodedLength {
87-
if let Some(len) = content_length_parse_all(headers) {
88-
// If the length is u64::MAX, oh well, just reported chunked.
89-
DecodedLength::checked_new(len).unwrap_or_else(|_| DecodedLength::CHUNKED)
90-
} else {
91-
DecodedLength::CHUNKED
92-
}
90+
content_length_parse_all(headers).into()
9391
}
9492

9593
// body adapters used by both Client and Server
@@ -172,7 +170,7 @@ where
172170
is_eos,
173171
);
174172

175-
let buf = SendBuf(Some(chunk));
173+
let buf = SendBuf::Buf(chunk);
176174
me.body_tx
177175
.send_data(buf, is_eos)
178176
.map_err(crate::Error::new_body_write)?;
@@ -243,32 +241,152 @@ impl<B: Buf> SendStreamExt for SendStream<SendBuf<B>> {
243241

244242
fn send_eos_frame(&mut self) -> crate::Result<()> {
245243
trace!("send body eos");
246-
self.send_data(SendBuf(None), true)
244+
self.send_data(SendBuf::None, true)
247245
.map_err(crate::Error::new_body_write)
248246
}
249247
}
250248

251-
struct SendBuf<B>(Option<B>);
249+
enum SendBuf<B> {
250+
Buf(B),
251+
Cursor(Cursor<Box<[u8]>>),
252+
None,
253+
}
252254

253255
impl<B: Buf> Buf for SendBuf<B> {
254256
#[inline]
255257
fn remaining(&self) -> usize {
256-
self.0.as_ref().map(|b| b.remaining()).unwrap_or(0)
258+
match *self {
259+
Self::Buf(ref b) => b.remaining(),
260+
Self::Cursor(ref c) => c.remaining(),
261+
Self::None => 0,
262+
}
257263
}
258264

259265
#[inline]
260266
fn chunk(&self) -> &[u8] {
261-
self.0.as_ref().map(|b| b.chunk()).unwrap_or(&[])
267+
match *self {
268+
Self::Buf(ref b) => b.chunk(),
269+
Self::Cursor(ref c) => c.chunk(),
270+
Self::None => &[],
271+
}
262272
}
263273

264274
#[inline]
265275
fn advance(&mut self, cnt: usize) {
266-
if let Some(b) = self.0.as_mut() {
267-
b.advance(cnt)
276+
match *self {
277+
Self::Buf(ref mut b) => b.advance(cnt),
278+
Self::Cursor(ref mut c) => c.advance(cnt),
279+
Self::None => {},
268280
}
269281
}
270282

271283
fn chunks_vectored<'a>(&'a self, dst: &mut [IoSlice<'a>]) -> usize {
272-
self.0.as_ref().map(|b| b.chunks_vectored(dst)).unwrap_or(0)
284+
match *self {
285+
Self::Buf(ref b) => b.chunks_vectored(dst),
286+
Self::Cursor(ref c) => c.chunks_vectored(dst),
287+
Self::None => 0,
288+
}
289+
}
290+
}
291+
292+
struct H2Upgraded<B>
293+
where
294+
B: Buf,
295+
{
296+
ping: Recorder,
297+
send_stream: SendStream<SendBuf<B>>,
298+
recv_stream: RecvStream,
299+
buf: Bytes,
300+
}
301+
302+
impl<B> AsyncRead for H2Upgraded<B>
303+
where
304+
B: Buf,
305+
{
306+
fn poll_read(
307+
mut self: Pin<&mut Self>,
308+
cx: &mut Context<'_>,
309+
read_buf: &mut ReadBuf<'_>,
310+
) -> Poll<Result<(), io::Error>> {
311+
if self.buf.is_empty() {
312+
self.buf = loop {
313+
match ready!(self.recv_stream.poll_data(cx)) {
314+
None => return Poll::Ready(Ok(())),
315+
Some(Ok(buf)) if buf.is_empty() && !self.recv_stream.is_end_stream() => continue,
316+
Some(Ok(buf)) => {
317+
self.ping.record_data(buf.len());
318+
break buf;
319+
}
320+
Some(Err(e)) => {
321+
return Poll::Ready(Err(h2_to_io_error(e)));
322+
}
323+
}
324+
};
325+
}
326+
let cnt = std::cmp::min(self.buf.len(), read_buf.remaining());
327+
read_buf.put_slice(&self.buf[..cnt]);
328+
self.buf.advance(cnt);
329+
let _ = self.recv_stream.flow_control().release_capacity(cnt);
330+
Poll::Ready(Ok(()))
331+
}
332+
}
333+
334+
impl<B> AsyncWrite for H2Upgraded<B>
335+
where
336+
B: Buf,
337+
{
338+
fn poll_write(
339+
mut self: Pin<&mut Self>,
340+
cx: &mut Context<'_>,
341+
buf: &[u8],
342+
) -> Poll<Result<usize, io::Error>> {
343+
if buf.is_empty() {
344+
return Poll::Ready(Ok(0));
345+
}
346+
// FIXME(nox): PipeToSendStream does some weird stuff, first reserving
347+
// one byte and then polling reset if the capacity is 0, should we do
348+
// that here too? Should we poll reset somewhere?
349+
self.send_stream.reserve_capacity(buf.len());
350+
Poll::Ready(match ready!(self.send_stream.poll_capacity(cx)) {
351+
None => Ok(0),
352+
Some(Ok(cnt)) => self.write(&buf[..cnt], false).map(|()| cnt),
353+
Some(Err(e)) => {
354+
// FIXME(nox): Should all H2 errors be returned as is with a
355+
// ErrorKind::Other, or should some be special-cased, say for
356+
// example, CANCEL?
357+
Err(h2_to_io_error(e))
358+
},
359+
})
360+
}
361+
362+
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
363+
Poll::Ready(Ok(()))
364+
}
365+
366+
fn poll_shutdown(
367+
mut self: Pin<&mut Self>,
368+
_cx: &mut Context<'_>,
369+
) -> Poll<Result<(), io::Error>> {
370+
Poll::Ready(self.write(&[], true))
371+
}
372+
}
373+
374+
impl<B> H2Upgraded<B>
375+
where
376+
B: Buf,
377+
{
378+
fn write(&mut self, buf: &[u8], end_of_stream: bool) -> Result<(), io::Error> {
379+
let send_buf = SendBuf::Cursor(Cursor::new(buf.into()));
380+
self.send_stream
381+
.send_data(send_buf, end_of_stream)
382+
.map_err(h2_to_io_error)
383+
}
384+
}
385+
386+
fn h2_to_io_error(e: h2::Error) -> io::Error {
387+
if e.is_io() {
388+
e.into_io().unwrap()
389+
} else {
390+
io::Error::new(io::ErrorKind::Other, e)
273391
}
274392
}

src/proto/h2/server.rs

Lines changed: 64 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,19 +3,24 @@ use std::marker::Unpin;
33
#[cfg(feature = "runtime")]
44
use std::time::Duration;
55

6+
use bytes::Bytes;
67
use h2::server::{Connection, Handshake, SendResponse};
7-
use h2::Reason;
8+
use h2::{Reason, RecvStream};
9+
use http::{Method, Request};
810
use pin_project::pin_project;
911
use tokio::io::{AsyncRead, AsyncWrite};
1012

11-
use super::{decode_content_length, ping, PipeToSendStream, SendBuf};
13+
use super::{ping, PipeToSendStream, SendBuf};
1214
use crate::body::HttpBody;
1315
use crate::common::exec::ConnStreamExec;
1416
use crate::common::{date, task, Future, Pin, Poll};
1517
use crate::headers;
18+
use crate::proto::h2::ping::Recorder;
19+
use crate::proto::h2::H2Upgraded;
1620
use crate::proto::Dispatched;
1721
use crate::service::HttpService;
1822

23+
use crate::upgrade::{OnUpgrade, Pending, Upgraded};
1924
use crate::{Body, Response};
2025

2126
// Our defaults are chosen for the "majority" case, which usually are not
@@ -257,9 +262,9 @@ where
257262

258263
// When the service is ready, accepts an incoming request.
259264
match ready!(self.conn.poll_accept(cx)) {
260-
Some(Ok((req, respond))) => {
265+
Some(Ok((req, mut respond))) => {
261266
trace!("incoming request");
262-
let content_length = decode_content_length(req.headers());
267+
let content_length = headers::content_length_parse_all(req.headers());
263268
let ping = self
264269
.ping
265270
.as_ref()
@@ -269,8 +274,32 @@ where
269274
// Record the headers received
270275
ping.record_non_data();
271276

272-
let req = req.map(|stream| crate::Body::h2(stream, content_length, ping));
273-
let fut = H2Stream::new(service.call(req), respond);
277+
let is_connect = req.method() == Method::CONNECT;
278+
let (mut parts, stream) = req.into_parts();
279+
let (req, connect_parts) = if !is_connect {
280+
(
281+
Request::from_parts(
282+
parts,
283+
crate::Body::h2(stream, content_length.into(), ping),
284+
),
285+
None,
286+
)
287+
} else {
288+
if content_length.map_or(false, |len| len != 0) {
289+
warn!("h2 connect request with non-zero body not supported");
290+
respond.send_reset(h2::Reason::INTERNAL_ERROR);
291+
return Poll::Ready(Ok(()));
292+
}
293+
let (pending, upgrade) = crate::upgrade::pending();
294+
debug_assert!(parts.extensions.get::<OnUpgrade>().is_none());
295+
parts.extensions.insert(upgrade);
296+
(
297+
Request::from_parts(parts, crate::Body::empty()),
298+
Some((pending, ping, stream)),
299+
)
300+
};
301+
302+
let fut = H2Stream::new(service.call(req), connect_parts, respond);
274303
exec.execute_h2stream(fut);
275304
}
276305
Some(Err(e)) => {
@@ -333,18 +362,22 @@ enum H2StreamState<F, B>
333362
where
334363
B: HttpBody,
335364
{
336-
Service(#[pin] F),
365+
Service(#[pin] F, Option<(Pending, Recorder, RecvStream)>),
337366
Body(#[pin] PipeToSendStream<B>),
338367
}
339368

340369
impl<F, B> H2Stream<F, B>
341370
where
342371
B: HttpBody,
343372
{
344-
fn new(fut: F, respond: SendResponse<SendBuf<B::Data>>) -> H2Stream<F, B> {
373+
fn new(
374+
fut: F,
375+
connect_parts: Option<(Pending, Recorder, RecvStream)>,
376+
respond: SendResponse<SendBuf<B::Data>>,
377+
) -> H2Stream<F, B> {
345378
H2Stream {
346379
reply: respond,
347-
state: H2StreamState::Service(fut),
380+
state: H2StreamState::Service(fut, connect_parts),
348381
}
349382
}
350383
}
@@ -374,7 +407,7 @@ where
374407
let mut me = self.project();
375408
loop {
376409
let next = match me.state.as_mut().project() {
377-
H2StreamStateProj::Service(h) => {
410+
H2StreamStateProj::Service(h, connect_parts) => {
378411
let res = match h.poll(cx) {
379412
Poll::Ready(Ok(r)) => r,
380413
Poll::Pending => {
@@ -405,6 +438,27 @@ where
405438
.entry(::http::header::DATE)
406439
.or_insert_with(date::update_and_header_value);
407440

441+
if let Some((pending, ping, recv_stream)) = connect_parts.take() {
442+
if res.status().is_success() {
443+
if headers::content_length_parse_all(res.headers()).map_or(false, |len| len != 0) {
444+
warn!("h2 successful response to CONNECT request with body not supported");
445+
me.reply.send_reset(h2::Reason::INTERNAL_ERROR);
446+
return Poll::Ready(Err(crate::Error::new_user_header()));
447+
}
448+
let send_stream = reply!(me, res, false);
449+
pending.fulfill(Upgraded::new(
450+
H2Upgraded {
451+
ping,
452+
recv_stream,
453+
send_stream,
454+
buf: Bytes::new(),
455+
},
456+
Bytes::new(),
457+
));
458+
return Poll::Ready(Ok(()));
459+
}
460+
}
461+
408462
// automatically set Content-Length from body...
409463
if let Some(len) = body.size_hint().exact() {
410464
headers::set_content_length_if_missing(res.headers_mut(), len);

0 commit comments

Comments
 (0)