Skip to content

Commit 817e111

Browse files
committed
feat(h2): implement CONNECT support (fixes #2508)
1 parent d631b8c commit 817e111

File tree

7 files changed

+285
-35
lines changed

7 files changed

+285
-35
lines changed

Cargo.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ include = [
2222
[lib]
2323
crate-type = ["lib", "staticlib", "cdylib"]
2424

25+
[patch.crates-io]
26+
h2 = { git = "https://github.com/hyperium/h2.git", branch = "master" }
27+
2528
[dependencies]
2629
bytes = "1"
2730
futures-core = { version = "0.3", default-features = false }
@@ -31,7 +34,7 @@ http = "0.2"
3134
http-body = "0.4"
3235
httpdate = "1.0"
3336
httparse = "1.4"
34-
h2 = { version = "0.3", optional = true }
37+
h2 = { version = "0.3.2", optional = true }
3538
itoa = "0.4.1"
3639
tracing = { version = "0.1", default-features = false, features = ["std"] }
3740
pin-project = "1.0"

src/body/length.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,17 @@ use std::fmt;
33
#[derive(Clone, Copy, PartialEq, Eq)]
44
pub(crate) struct DecodedLength(u64);
55

6+
#[cfg(any(feature = "http1", feature = "http2"))]
7+
impl From<Option<u64>> for DecodedLength {
8+
fn from(len: Option<u64>) -> Self {
9+
len.and_then(|len| {
10+
// If the length is u64::MAX, oh well, just reported chunked.
11+
Self::checked_new(len).ok()
12+
})
13+
.unwrap_or(DecodedLength::CHUNKED)
14+
}
15+
}
16+
617
#[cfg(any(feature = "http1", feature = "http2", test))]
718
const MAX_LEN: u64 = std::u64::MAX - 2;
819

src/error.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ pub(super) enum User {
9090
/// User tried to send a certain header in an unexpected context.
9191
///
9292
/// For example, sending both `content-length` and `transfer-encoding`.
93-
#[cfg(feature = "http1")]
93+
#[cfg(any(feature = "http1", feature = "http2"))]
9494
#[cfg(feature = "server")]
9595
UnexpectedHeader,
9696
/// User tried to create a Request with bad version.
@@ -279,7 +279,7 @@ impl Error {
279279
Error::new(Kind::User(user))
280280
}
281281

282-
#[cfg(feature = "http1")]
282+
#[cfg(any(feature = "http1", feature = "http2"))]
283283
#[cfg(feature = "server")]
284284
pub(super) fn new_user_header() -> Error {
285285
Error::new_user(User::UnexpectedHeader)
@@ -394,7 +394,7 @@ impl Error {
394394
Kind::User(User::MakeService) => "error from user's MakeService",
395395
#[cfg(any(feature = "http1", feature = "http2"))]
396396
Kind::User(User::Service) => "error from user's Service",
397-
#[cfg(feature = "http1")]
397+
#[cfg(any(feature = "http1", feature = "http2"))]
398398
#[cfg(feature = "server")]
399399
Kind::User(User::UnexpectedHeader) => "user sent unexpected header",
400400
#[cfg(any(feature = "http1", feature = "http2"))]

src/proto/h2/mod.rs

Lines changed: 132 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
1-
use bytes::Buf;
2-
use h2::SendStream;
1+
use bytes::{Buf, Bytes};
2+
use h2::{RecvStream, SendStream};
33
use http::header::{
44
HeaderName, CONNECTION, PROXY_AUTHENTICATE, PROXY_AUTHORIZATION, TE, TRAILER,
55
TRANSFER_ENCODING, UPGRADE,
66
};
77
use http::HeaderMap;
88
use pin_project::pin_project;
9+
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
910
use std::error::Error as StdError;
10-
use std::io::IoSlice;
11+
use std::io::{self, Cursor, IoSlice};
12+
use std::task::Context;
1113

1214
use crate::body::{DecodedLength, HttpBody};
1315
use crate::common::{task, Future, Pin, Poll};
1416
use crate::headers::content_length_parse_all;
17+
use crate::proto::h2::ping::Recorder;
1518

1619
pub(crate) mod ping;
1720

@@ -84,12 +87,7 @@ fn strip_connection_headers(headers: &mut HeaderMap, is_request: bool) {
8487
}
8588

8689
fn decode_content_length(headers: &HeaderMap) -> DecodedLength {
87-
if let Some(len) = content_length_parse_all(headers) {
88-
// If the length is u64::MAX, oh well, just reported chunked.
89-
DecodedLength::checked_new(len).unwrap_or_else(|_| DecodedLength::CHUNKED)
90-
} else {
91-
DecodedLength::CHUNKED
92-
}
90+
content_length_parse_all(headers).into()
9391
}
9492

9593
// body adapters used by both Client and Server
@@ -172,7 +170,7 @@ where
172170
is_eos,
173171
);
174172

175-
let buf = SendBuf(Some(chunk));
173+
let buf = SendBuf::Buf(chunk);
176174
me.body_tx
177175
.send_data(buf, is_eos)
178176
.map_err(crate::Error::new_body_write)?;
@@ -243,32 +241,149 @@ impl<B: Buf> SendStreamExt for SendStream<SendBuf<B>> {
243241

244242
fn send_eos_frame(&mut self) -> crate::Result<()> {
245243
trace!("send body eos");
246-
self.send_data(SendBuf(None), true)
244+
self.send_data(SendBuf::None, true)
247245
.map_err(crate::Error::new_body_write)
248246
}
249247
}
250248

251-
struct SendBuf<B>(Option<B>);
249+
enum SendBuf<B> {
250+
Buf(B),
251+
Cursor(Cursor<Box<[u8]>>),
252+
None,
253+
}
252254

253255
impl<B: Buf> Buf for SendBuf<B> {
254256
#[inline]
255257
fn remaining(&self) -> usize {
256-
self.0.as_ref().map(|b| b.remaining()).unwrap_or(0)
258+
match *self {
259+
Self::Buf(ref b) => b.remaining(),
260+
Self::Cursor(ref c) => c.remaining(),
261+
Self::None => 0,
262+
}
257263
}
258264

259265
#[inline]
260266
fn chunk(&self) -> &[u8] {
261-
self.0.as_ref().map(|b| b.chunk()).unwrap_or(&[])
267+
match *self {
268+
Self::Buf(ref b) => b.chunk(),
269+
Self::Cursor(ref c) => c.chunk(),
270+
Self::None => &[],
271+
}
262272
}
263273

264274
#[inline]
265275
fn advance(&mut self, cnt: usize) {
266-
if let Some(b) = self.0.as_mut() {
267-
b.advance(cnt)
276+
match *self {
277+
Self::Buf(ref mut b) => b.advance(cnt),
278+
Self::Cursor(ref mut c) => c.advance(cnt),
279+
Self::None => {},
268280
}
269281
}
270282

271283
fn chunks_vectored<'a>(&'a self, dst: &mut [IoSlice<'a>]) -> usize {
272-
self.0.as_ref().map(|b| b.chunks_vectored(dst)).unwrap_or(0)
284+
match *self {
285+
Self::Buf(ref b) => b.chunks_vectored(dst),
286+
Self::Cursor(ref c) => c.chunks_vectored(dst),
287+
Self::None => 0,
288+
}
289+
}
290+
}
291+
292+
struct H2Upgraded<B>
293+
where
294+
B: Buf,
295+
{
296+
ping: Recorder,
297+
send_stream: SendStream<SendBuf<B>>,
298+
recv_stream: RecvStream,
299+
buf: Bytes,
300+
}
301+
302+
impl<B> AsyncRead for H2Upgraded<B>
303+
where
304+
B: Buf,
305+
{
306+
fn poll_read(
307+
mut self: Pin<&mut Self>,
308+
cx: &mut Context<'_>,
309+
read_buf: &mut ReadBuf<'_>,
310+
) -> Poll<Result<(), io::Error>> {
311+
if self.buf.is_empty() {
312+
self.buf = match ready!(self.recv_stream.poll_data(cx)) {
313+
None => return Poll::Ready(Ok(())),
314+
Some(Ok(buf)) => {
315+
self.ping.record_data(buf.len());
316+
buf
317+
}
318+
Some(Err(e)) => {
319+
return Poll::Ready(Err(h2_to_io_error(e)));
320+
}
321+
};
322+
}
323+
let cnt = std::cmp::min(self.buf.len(), read_buf.remaining());
324+
read_buf.put_slice(&self.buf[..cnt]);
325+
self.buf.advance(cnt);
326+
let _ = self.recv_stream.flow_control().release_capacity(cnt);
327+
Poll::Ready(Ok(()))
328+
}
329+
}
330+
331+
impl<B> AsyncWrite for H2Upgraded<B>
332+
where
333+
B: Buf,
334+
{
335+
fn poll_write(
336+
mut self: Pin<&mut Self>,
337+
cx: &mut Context<'_>,
338+
buf: &[u8],
339+
) -> Poll<Result<usize, io::Error>> {
340+
if buf.is_empty() {
341+
return Poll::Ready(Ok(0));
342+
}
343+
// FIXME(nox): PipeToSendStream does some weird stuff, first reserving
344+
// one byte and then polling reset if the capacity is 0, should we do
345+
// that here too? Should we poll reset somewhere?
346+
self.send_stream.reserve_capacity(buf.len());
347+
Poll::Ready(match ready!(self.send_stream.poll_capacity(cx)) {
348+
None => Ok(0),
349+
Some(Ok(cnt)) => self.write(&buf[..cnt], false).map(|()| cnt),
350+
Some(Err(e)) => {
351+
// FIXME(nox): Should all H2 errors be returned as is with a
352+
// ErrorKind::Other, or should some be special-cased, say for
353+
// example, CANCEL?
354+
Err(h2_to_io_error(e))
355+
},
356+
})
357+
}
358+
359+
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
360+
Poll::Ready(Ok(()))
361+
}
362+
363+
fn poll_shutdown(
364+
mut self: Pin<&mut Self>,
365+
_cx: &mut Context<'_>,
366+
) -> Poll<Result<(), io::Error>> {
367+
Poll::Ready(self.write(&[], true))
368+
}
369+
}
370+
371+
impl<B> H2Upgraded<B>
372+
where
373+
B: Buf,
374+
{
375+
fn write(&mut self, buf: &[u8], end_of_stream: bool) -> Result<(), io::Error> {
376+
let send_buf = SendBuf::Cursor(Cursor::new(buf.into()));
377+
self.send_stream
378+
.send_data(send_buf, end_of_stream)
379+
.map_err(h2_to_io_error)
380+
}
381+
}
382+
383+
fn h2_to_io_error(e: h2::Error) -> io::Error {
384+
if e.is_io() {
385+
e.into_io().unwrap()
386+
} else {
387+
io::Error::new(io::ErrorKind::Other, e)
273388
}
274389
}

0 commit comments

Comments
 (0)