Skip to content

Commit d6ba136

Browse files
committed
feat(h2): implement CONNECT support (fixes #2508)
1 parent d631b8c commit d6ba136

File tree

5 files changed

+263
-23
lines changed

5 files changed

+263
-23
lines changed

Cargo.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ include = [
2222
[lib]
2323
crate-type = ["lib", "staticlib", "cdylib"]
2424

25+
[patch.crates-io]
26+
h2 = { git = "https://github.com/hyperium/h2.git", branch = "master" }
27+
2528
[dependencies]
2629
bytes = "1"
2730
futures-core = { version = "0.3", default-features = false }
@@ -31,7 +34,7 @@ http = "0.2"
3134
http-body = "0.4"
3235
httpdate = "1.0"
3336
httparse = "1.4"
34-
h2 = { version = "0.3", optional = true }
37+
h2 = { version = "0.3.2", optional = true }
3538
itoa = "0.4.1"
3639
tracing = { version = "0.1", default-features = false, features = ["std"] }
3740
pin-project = "1.0"

src/proto/h2/mod.rs

Lines changed: 137 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,20 @@
1-
use bytes::Buf;
2-
use h2::SendStream;
1+
use bytes::{Buf, Bytes};
2+
use h2::{RecvStream, SendStream};
33
use http::header::{
44
HeaderName, CONNECTION, PROXY_AUTHENTICATE, PROXY_AUTHORIZATION, TE, TRAILER,
55
TRANSFER_ENCODING, UPGRADE,
66
};
77
use http::HeaderMap;
88
use pin_project::pin_project;
9+
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
910
use std::error::Error as StdError;
10-
use std::io::IoSlice;
11+
use std::io::{self, Cursor, IoSlice};
12+
use std::task::Context;
1113

1214
use crate::body::{DecodedLength, HttpBody};
1315
use crate::common::{task, Future, Pin, Poll};
1416
use crate::headers::content_length_parse_all;
17+
use crate::proto::h2::ping::Recorder;
1518

1619
pub(crate) mod ping;
1720

@@ -172,7 +175,7 @@ where
172175
is_eos,
173176
);
174177

175-
let buf = SendBuf(Some(chunk));
178+
let buf = SendBuf::Buf(chunk);
176179
me.body_tx
177180
.send_data(buf, is_eos)
178181
.map_err(crate::Error::new_body_write)?;
@@ -243,32 +246,155 @@ impl<B: Buf> SendStreamExt for SendStream<SendBuf<B>> {
243246

244247
fn send_eos_frame(&mut self) -> crate::Result<()> {
245248
trace!("send body eos");
246-
self.send_data(SendBuf(None), true)
249+
self.send_data(SendBuf::None, true)
247250
.map_err(crate::Error::new_body_write)
248251
}
249252
}
250253

251-
struct SendBuf<B>(Option<B>);
254+
enum SendBuf<B> {
255+
Buf(B),
256+
Cursor(Cursor<Box<[u8]>>),
257+
None,
258+
}
252259

253260
impl<B: Buf> Buf for SendBuf<B> {
254261
#[inline]
255262
fn remaining(&self) -> usize {
256-
self.0.as_ref().map(|b| b.remaining()).unwrap_or(0)
263+
match *self {
264+
Self::Buf(ref b) => b.remaining(),
265+
Self::Cursor(ref c) => c.remaining(),
266+
Self::None => 0,
267+
}
257268
}
258269

259270
#[inline]
260271
fn chunk(&self) -> &[u8] {
261-
self.0.as_ref().map(|b| b.chunk()).unwrap_or(&[])
272+
match *self {
273+
Self::Buf(ref b) => b.chunk(),
274+
Self::Cursor(ref c) => c.chunk(),
275+
Self::None => &[],
276+
}
262277
}
263278

264279
#[inline]
265280
fn advance(&mut self, cnt: usize) {
266-
if let Some(b) = self.0.as_mut() {
267-
b.advance(cnt)
281+
match *self {
282+
Self::Buf(ref mut b) => b.advance(cnt),
283+
Self::Cursor(ref mut c) => c.advance(cnt),
284+
Self::None => {},
268285
}
269286
}
270287

271288
fn chunks_vectored<'a>(&'a self, dst: &mut [IoSlice<'a>]) -> usize {
272-
self.0.as_ref().map(|b| b.chunks_vectored(dst)).unwrap_or(0)
289+
match *self {
290+
Self::Buf(ref b) => b.chunks_vectored(dst),
291+
Self::Cursor(ref c) => c.chunks_vectored(dst),
292+
Self::None => 0,
293+
}
294+
}
295+
}
296+
297+
// FIXME(nox): Should this type be public? I'm asking this because
298+
// the HTTP/2 RFC says that a proxy that encounters a TCP error with the
299+
// upstream peer should send back to the client a stream error with reason
300+
// CONNECT_ERROR, so we need *something* to send that, but all the user
301+
// gets is a hyper::upgrade::Upgraded, so you can't send anything but a
302+
// data frame back.
303+
struct H2Upgraded<B>
304+
where
305+
B: Buf,
306+
{
307+
ping: Recorder,
308+
send_stream: SendStream<SendBuf<B>>,
309+
recv_stream: RecvStream,
310+
buf: Bytes,
311+
}
312+
313+
impl<B> AsyncRead for H2Upgraded<B>
314+
where
315+
B: Buf,
316+
{
317+
fn poll_read(
318+
mut self: Pin<&mut Self>,
319+
cx: &mut Context<'_>,
320+
read_buf: &mut ReadBuf<'_>,
321+
) -> Poll<Result<(), io::Error>> {
322+
if self.buf.is_empty() {
323+
self.buf = match ready!(self.recv_stream.poll_data(cx)) {
324+
None => return Poll::Ready(Ok(())),
325+
Some(Ok(buf)) => {
326+
self.ping.record_data(buf.len());
327+
buf
328+
}
329+
Some(Err(e)) => {
330+
return Poll::Ready(Err(h2_to_io_error(e)));
331+
}
332+
};
333+
}
334+
let cnt = std::cmp::min(self.buf.len(), read_buf.remaining());
335+
read_buf.put_slice(&self.buf[..cnt]);
336+
self.buf.advance(cnt);
337+
let _ = self.recv_stream.flow_control().release_capacity(cnt);
338+
Poll::Ready(Ok(()))
339+
}
340+
}
341+
342+
impl<B> AsyncWrite for H2Upgraded<B>
343+
where
344+
B: Buf,
345+
{
346+
fn poll_write(
347+
mut self: Pin<&mut Self>,
348+
cx: &mut Context<'_>,
349+
buf: &[u8],
350+
) -> Poll<Result<usize, io::Error>> {
351+
if buf.is_empty() {
352+
return Poll::Ready(Ok(0));
353+
}
354+
// FIXME(nox): PipeToSendStream does some weird stuff, first reserving
355+
// one byte and then polling reset if the capacity is 0, should we do
356+
// that here too? Should we poll reset somewhere?
357+
self.send_stream.reserve_capacity(buf.len());
358+
Poll::Ready(match ready!(self.send_stream.poll_capacity(cx)) {
359+
None => Ok(0),
360+
Some(Ok(cnt)) => self.write(&buf[..cnt], false).map(|()| cnt),
361+
Some(Err(e)) => {
362+
// FIXME(nox): Should all H2 errors be returned as is with a
363+
// ErrorKind::Other, or should some be special-cased, say for
364+
// example, CANCEL?
365+
Err(h2_to_io_error(e))
366+
},
367+
})
368+
}
369+
370+
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
371+
Poll::Ready(Ok(()))
372+
}
373+
374+
fn poll_shutdown(
375+
mut self: Pin<&mut Self>,
376+
_cx: &mut Context<'_>,
377+
) -> Poll<Result<(), io::Error>> {
378+
Poll::Ready(self.write(&[], true))
379+
}
380+
}
381+
382+
impl<B> H2Upgraded<B>
383+
where
384+
B: Buf,
385+
{
386+
fn write(&mut self, buf: &[u8], end_of_stream: bool) -> Result<(), io::Error> {
387+
let send_buf = SendBuf::Cursor(Cursor::new(buf.into()));
388+
self.send_stream
389+
.send_data(send_buf, end_of_stream)
390+
.map_err(h2_to_io_error)
391+
}
392+
}
393+
394+
fn h2_to_io_error(e: h2::Error) -> io::Error {
395+
if e.is_io() {
396+
e.into_io().unwrap()
397+
} else {
398+
io::Error::new(io::ErrorKind::Other, e)
273399
}
274400
}

src/proto/h2/server.rs

Lines changed: 51 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@ use std::marker::Unpin;
33
#[cfg(feature = "runtime")]
44
use std::time::Duration;
55

6+
use bytes::Bytes;
67
use h2::server::{Connection, Handshake, SendResponse};
7-
use h2::Reason;
8+
use h2::{Reason, RecvStream};
9+
use http::{Method, Request};
810
use pin_project::pin_project;
911
use tokio::io::{AsyncRead, AsyncWrite};
1012

@@ -13,9 +15,12 @@ use crate::body::HttpBody;
1315
use crate::common::exec::ConnStreamExec;
1416
use crate::common::{date, task, Future, Pin, Poll};
1517
use crate::headers;
18+
use crate::proto::h2::ping::Recorder;
19+
use crate::proto::h2::H2Upgraded;
1620
use crate::proto::Dispatched;
1721
use crate::service::HttpService;
1822

23+
use crate::upgrade::{OnUpgrade, Pending, Upgraded};
1924
use crate::{Body, Response};
2025

2126
// Our defaults are chosen for the "majority" case, which usually are not
@@ -269,8 +274,28 @@ where
269274
// Record the headers received
270275
ping.record_non_data();
271276

272-
let req = req.map(|stream| crate::Body::h2(stream, content_length, ping));
273-
let fut = H2Stream::new(service.call(req), respond);
277+
let is_connect = req.method() == Method::CONNECT;
278+
let (mut parts, stream) = req.into_parts();
279+
let (req, connect_parts) = if !is_connect {
280+
(
281+
Request::from_parts(
282+
parts,
283+
crate::Body::h2(stream, content_length, ping),
284+
),
285+
None,
286+
)
287+
} else {
288+
// FIXME(nox): What happens to the request body? Should we check `content_length`?
289+
let (pending, upgrade) = crate::upgrade::pending();
290+
debug_assert!(parts.extensions.get::<OnUpgrade>().is_none());
291+
parts.extensions.insert(upgrade);
292+
(
293+
Request::from_parts(parts, crate::Body::empty()),
294+
Some((pending, ping, stream)),
295+
)
296+
};
297+
298+
let fut = H2Stream::new(service.call(req), connect_parts, respond);
274299
exec.execute_h2stream(fut);
275300
}
276301
Some(Err(e)) => {
@@ -333,18 +358,22 @@ enum H2StreamState<F, B>
333358
where
334359
B: HttpBody,
335360
{
336-
Service(#[pin] F),
361+
Service(#[pin] F, Option<(Pending, Recorder, RecvStream)>),
337362
Body(#[pin] PipeToSendStream<B>),
338363
}
339364

340365
impl<F, B> H2Stream<F, B>
341366
where
342367
B: HttpBody,
343368
{
344-
fn new(fut: F, respond: SendResponse<SendBuf<B::Data>>) -> H2Stream<F, B> {
369+
fn new(
370+
fut: F,
371+
connect_parts: Option<(Pending, Recorder, RecvStream)>,
372+
respond: SendResponse<SendBuf<B::Data>>,
373+
) -> H2Stream<F, B> {
345374
H2Stream {
346375
reply: respond,
347-
state: H2StreamState::Service(fut),
376+
state: H2StreamState::Service(fut, connect_parts),
348377
}
349378
}
350379
}
@@ -374,7 +403,7 @@ where
374403
let mut me = self.project();
375404
loop {
376405
let next = match me.state.as_mut().project() {
377-
H2StreamStateProj::Service(h) => {
406+
H2StreamStateProj::Service(h, connect_parts) => {
378407
let res = match h.poll(cx) {
379408
Poll::Ready(Ok(r)) => r,
380409
Poll::Pending => {
@@ -405,6 +434,21 @@ where
405434
.entry(::http::header::DATE)
406435
.or_insert_with(date::update_and_header_value);
407436

437+
if let Some((pending, ping, recv_stream)) = connect_parts.take() {
438+
// FIXME(nox): What do we do about the response body? AFAIK h1 returns an error.
439+
let send_stream = reply!(me, res, false);
440+
pending.fulfill(Upgraded::new(
441+
H2Upgraded {
442+
ping,
443+
recv_stream,
444+
send_stream,
445+
buf: Bytes::new(),
446+
},
447+
Bytes::new(),
448+
));
449+
return Poll::Ready(Ok(()));
450+
}
451+
408452
// automatically set Content-Length from body...
409453
if let Some(len) = body.size_hint().exact() {
410454
headers::set_content_length_if_missing(res.headers_mut(), len);

src/upgrade.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,12 @@ pub fn on<T: sealed::CanUpgrade>(msg: T) -> OnUpgrade {
6262
msg.on_upgrade()
6363
}
6464

65-
#[cfg(feature = "http1")]
65+
#[cfg(any(feature = "http1", feature = "http2"))]
6666
pub(super) struct Pending {
6767
tx: oneshot::Sender<crate::Result<Upgraded>>,
6868
}
6969

70-
#[cfg(feature = "http1")]
70+
#[cfg(any(feature = "http1", feature = "http2"))]
7171
pub(super) fn pending() -> (Pending, OnUpgrade) {
7272
let (tx, rx) = oneshot::channel();
7373
(Pending { tx }, OnUpgrade { rx: Some(rx) })
@@ -76,7 +76,7 @@ pub(super) fn pending() -> (Pending, OnUpgrade) {
7676
// ===== impl Upgraded =====
7777

7878
impl Upgraded {
79-
#[cfg(any(feature = "http1", test))]
79+
#[cfg(any(feature = "http1", feature = "http2", test))]
8080
pub(super) fn new<T>(io: T, read_buf: Bytes) -> Self
8181
where
8282
T: AsyncRead + AsyncWrite + Unpin + Send + 'static,
@@ -187,13 +187,14 @@ impl fmt::Debug for OnUpgrade {
187187

188188
// ===== impl Pending =====
189189

190-
#[cfg(feature = "http1")]
190+
#[cfg(any(feature = "http1", feature = "http2"))]
191191
impl Pending {
192192
pub(super) fn fulfill(self, upgraded: Upgraded) {
193193
trace!("pending upgrade fulfill");
194194
let _ = self.tx.send(Ok(upgraded));
195195
}
196196

197+
#[cfg(feature = "http1")]
197198
/// Don't fulfill the pending Upgrade, but instead signal that
198199
/// upgrades are handled manually.
199200
pub(super) fn manual(self) {

0 commit comments

Comments
 (0)