Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(server): add AutoConnection #11

Merged
merged 14 commits into from
Sep 16, 2023
12 changes: 9 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@ edition = "2018"
publish = false # no accidents while in dev

[dependencies]
hyper = "=1.0.0-rc.1"
hyper = "1.0.0-rc.2"
futures-channel = "0.3"
futures-util = { version = "0.3", default-features = false }
http = "0.2"
http-body = "1.0.0-rc.2"
bytes = "1"

# Necessary to overcome msrv check of rust 1.49, as 1.15.0 failed
once_cell = "=1.14"
Expand All @@ -31,16 +33,20 @@ tower-service = "0.3"
tower = { version = "0.4", features = ["util"] }

[dev-dependencies]
hyper = { version = "1.0.0-rc.2", features = ["full"] }
http-body-util = "0.1.0-rc.2"
tokio = { version = "1", features = ["macros", "test-util"] }
tokio-test = "0.4"

[target.'cfg(any(target_os = "linux", target_os = "macos"))'.dev-dependencies]
pnet_datalink = "0.27.2"

[features]
runtime = []
tcp = []
http1 = []
http2 = []
http1 = ["hyper/http1"]
http2 = ["hyper/http2"]
auto = ["hyper/server", "http1", "http2"]

# internal features used in CI
__internal_happy_eyeballs_tests = []
1 change: 1 addition & 0 deletions src/common/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@ macro_rules! ready {
pub(crate) use ready;
pub(crate) mod exec;
pub(crate) mod never;
pub(crate) mod rewind;

pub(crate) use never::Never;
161 changes: 161 additions & 0 deletions src/common/rewind.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
use std::marker::Unpin;
use std::{cmp, io};

use bytes::{Buf, Bytes};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};

use std::{
pin::Pin,
task::{self, Poll},
};

/// Combine a buffer with an IO, rewinding reads to use the buffer.
#[derive(Debug)]
pub(crate) struct Rewind<T> {
pre: Option<Bytes>,
inner: T,
}

impl<T> Rewind<T> {
#[cfg(test)]
pub(crate) fn new(io: T) -> Self {
Rewind {
pre: None,
inner: io,
}
}

#[allow(dead_code)]
pub(crate) fn new_buffered(io: T, buf: Bytes) -> Self {
Rewind {
pre: Some(buf),
inner: io,
}
}

#[cfg(test)]
pub(crate) fn rewind(&mut self, bs: Bytes) {
debug_assert!(self.pre.is_none());
self.pre = Some(bs);
}

// pub(crate) fn into_inner(self) -> (T, Bytes) {
// (self.inner, self.pre.unwrap_or_else(Bytes::new))
// }

// pub(crate) fn get_mut(&mut self) -> &mut T {
// &mut self.inner
// }
}

impl<T> AsyncRead for Rewind<T>
where
T: AsyncRead + Unpin,
{
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &mut ReadBuf<'_>,
) -> Poll<io::Result<()>> {
if let Some(mut prefix) = self.pre.take() {
// If there are no remaining bytes, let the bytes get dropped.
if !prefix.is_empty() {
let copy_len = cmp::min(prefix.len(), buf.remaining());
// TODO: There should be a way to do following two lines cleaner...
buf.put_slice(&prefix[..copy_len]);
prefix.advance(copy_len);
programatik29 marked this conversation as resolved.
Show resolved Hide resolved
// Put back what's left
if !prefix.is_empty() {
self.pre = Some(prefix);
}

return Poll::Ready(Ok(()));
}
}
Pin::new(&mut self.inner).poll_read(cx, buf)
}
}

impl<T> AsyncWrite for Rewind<T>
where
T: AsyncWrite + Unpin,
{
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_write(cx, buf)
}

fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut task::Context<'_>,
bufs: &[io::IoSlice<'_>],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.inner).poll_write_vectored(cx, bufs)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_flush(cx)
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.inner).poll_shutdown(cx)
}

fn is_write_vectored(&self) -> bool {
self.inner.is_write_vectored()
}
}

#[cfg(test)]
mod tests {
// FIXME: re-implement tests with `async/await`, this import should
// trigger a warning to remind us
use super::Rewind;
use bytes::Bytes;
use tokio::io::AsyncReadExt;

#[cfg(not(miri))]
#[tokio::test]
async fn partial_rewind() {
let underlying = [104, 101, 108, 108, 111];

let mock = tokio_test::io::Builder::new().read(&underlying).build();

let mut stream = Rewind::new(mock);

// Read off some bytes, ensure we filled o1
let mut buf = [0; 2];
stream.read_exact(&mut buf).await.expect("read1");

// Rewind the stream so that it is as if we never read in the first place.
stream.rewind(Bytes::copy_from_slice(&buf[..]));

let mut buf = [0; 5];
stream.read_exact(&mut buf).await.expect("read1");

// At this point we should have read everything that was in the MockStream
assert_eq!(&buf, &underlying);
}

#[cfg(not(miri))]
#[tokio::test]
async fn full_rewind() {
let underlying = [104, 101, 108, 108, 111];

let mock = tokio_test::io::Builder::new().read(&underlying).build();

let mut stream = Rewind::new(mock);

let mut buf = [0; 5];
stream.read_exact(&mut buf).await.expect("read1");

// Rewind the stream so that it is as if we never read in the first place.
stream.rewind(Bytes::copy_from_slice(&buf[..]));

let mut buf = [0; 5];
stream.read_exact(&mut buf).await.expect("read1");
}
}
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
#![deny(missing_docs)]
#![cfg_attr(docsrs, feature(doc_auto_cfg, doc_cfg))]

//! hyper utilities
pub use crate::error::{GenericError, Result};

pub mod client;
pub mod common;
pub mod rt;
pub mod server;

mod error;
2 changes: 1 addition & 1 deletion src/rt/tokio_executor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use std::future::Future;

/// Future executor that utilises `tokio` threads.
#[non_exhaustive]
#[derive(Default, Debug)]
#[derive(Default, Debug, Clone, Copy)]
pub struct TokioExecutor {}

impl<Fut> Executor<Fut> for TokioExecutor
Expand Down
Loading