Skip to content

Commit

Permalink
feat!: update quinn to 0.9
Browse files Browse the repository at this point in the history
  • Loading branch information
b-zee authored and bochaco committed Dec 12, 2022
1 parent 8a2c455 commit 891b45e
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 88 deletions.
5 changes: 2 additions & 3 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ name = "p2p_node"
bincode = "1.2.1"
bytes = { version = "1.0.1", features = ["serde"] }
futures = "~0.3.8"
quinn = { version = "0.8.0", default-features = false, features = ["tls-rustls", "ring"] }
quinn-proto = "0.8.0"
quinn = { version = "0.9", default-features = false, features = ["tls-rustls", "ring", "runtime-tokio"] }
quinn-proto = { version = "0.9", default-features = false, features = ["tls-rustls", "ring"] }
rcgen = "~0.9"
serde = { version = "1.0.117", features = ["derive"] }
thiserror = "1.0.23"
Expand All @@ -36,4 +36,3 @@ tiny-keccak = { version = "2.0.2", features = ["sha3"] }
tokio = { version = "1.12.0", features = ["macros", "rt-multi-thread"] }
tracing-subscriber = "0.2.19"
tracing-test = "0.1.0"
quinn = { version = "0.8.0", default-features = false, features = ["tls-rustls", "native-certs"] }
2 changes: 1 addition & 1 deletion src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ impl InternalConfig {
server.transport = transport.clone();

let mut client = quinn::ClientConfig::new(Arc::new(client_crypto));
client.transport = transport;
let _ = client.transport_config(transport);

Ok(Self {
client,
Expand Down
198 changes: 115 additions & 83 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@ pub struct Connection {
impl Connection {
pub(crate) fn new(
endpoint: quinn::Endpoint,
connection: quinn::NewConnection,
connection: quinn::Connection,
) -> (Connection, ConnectionIncoming) {
// this channel serves to keep the background message listener alive so long as one side of
// the connection API is alive.
let (alive_tx, alive_rx) = watch::channel(());
let alive_tx = Arc::new(alive_tx);
let peer_address = connection.connection.remote_address();
let peer_address = connection.remote_address();
let conn = Self {
inner: connection.connection,
inner: connection,
_alive_tx: Arc::clone(&alive_tx),
};
let conn_id = conn.id();
Expand All @@ -60,8 +60,7 @@ impl Connection {
endpoint,
conn_id,
peer_address,
connection.uni_streams,
connection.bi_streams,
connection,
alive_tx,
alive_rx,
),
Expand Down Expand Up @@ -304,8 +303,7 @@ impl ConnectionIncoming {
endpoint: quinn::Endpoint,
conn_id: String,
peer_addr: SocketAddr,
uni_streams: quinn::IncomingUniStreams,
bi_streams: quinn::IncomingBiStreams,
connection: quinn::Connection,
alive_tx: Arc<watch::Sender<()>>,
alive_rx: watch::Receiver<()>,
) -> Self {
Expand All @@ -314,13 +312,7 @@ impl ConnectionIncoming {
// offload the actual message handling to a background task - the task will exit when
// `alive_tx` is dropped, which would be when both sides of the connection are dropped.
start_message_listeners(
endpoint,
conn_id,
peer_addr,
uni_streams,
bi_streams,
alive_rx,
message_tx,
endpoint, conn_id, peer_addr, connection, alive_rx, message_tx,
);

Self {
Expand Down Expand Up @@ -355,26 +347,25 @@ fn start_message_listeners(
endpoint: quinn::Endpoint,
conn_id: String,
peer_addr: SocketAddr,
uni_streams: quinn::IncomingUniStreams,
bi_streams: quinn::IncomingBiStreams,
connection: quinn::Connection,
alive_rx: watch::Receiver<()>,
message_tx: mpsc::Sender<Result<(UsrMsgBytes, Option<ResponseStream>), RecvError>>,
) {
let _ = tokio::spawn(listen_on_uni_streams(
peer_addr,
uni_streams,
connection.clone(),
alive_rx.clone(),
message_tx.clone(),
));

let _ = tokio::spawn(listen_on_bi_streams(
endpoint, conn_id, peer_addr, bi_streams, alive_rx, message_tx,
endpoint, conn_id, peer_addr, connection, alive_rx, message_tx,
));
}

async fn listen_on_uni_streams(
peer_addr: SocketAddr,
uni_streams: quinn::IncomingUniStreams,
connection: quinn::Connection,
mut alive_rx: watch::Receiver<()>,
message_tx: mpsc::Sender<Result<(UsrMsgBytes, Option<ResponseStream>), RecvError>>,
) {
Expand All @@ -383,25 +374,38 @@ async fn listen_on_uni_streams(
peer_addr
);

let mut uni_messages = Box::pin(
uni_streams
.map_ok(|recv_stream| {
trace!("Handling incoming uni-stream from {}", peer_addr);
// Turn the `accept_uni` method into a stream that yields an `Some(Err(ConnectionError))` before `None`
let uni_streams = futures::stream::unfold(Some(connection), |connection| async move {
let connection = match connection {
Some(c) => c,
None => return None,
};

stream::try_unfold(recv_stream, |mut recv_stream| async move {
WireMsg::read_from_stream(&mut recv_stream)
.await
.and_then(|msg| match msg {
Some(WireMsg::UserMsg(msg)) => Ok(Some((msg, recv_stream))),
Some(other_msg) => {
Err(RecvError::UnexpectedMsgReceived(other_msg.to_string()))
}
None => Ok(None),
})
})
match connection.accept_uni().await {
Ok(recv) => Some((Ok(recv), Some(connection))),
Err(err) => {
let err: ConnectionError = err.into();
Some((Err(err), None))
}
}
});

let uni_messages = uni_streams
.map_ok(|recv_stream| {
trace!("Handling incoming uni-stream from {}", peer_addr);

stream::try_unfold(recv_stream, |mut recv_stream| async move {
WireMsg::read_from_stream(&mut recv_stream)
.await
.and_then(|msg| match msg {
Some(WireMsg::UserMsg(msg)) => Ok(Some((msg, recv_stream))),
Some(msg) => Err(RecvError::UnexpectedMsgReceived(msg.to_string())),
None => Ok(None),
})
})
.try_flatten(),
);
})
.try_flatten();
let mut uni_messages = Box::pin(uni_messages);

// it's a shame to allocate, but there are `Pin` errors otherwise – and we should only be doing
// this once (per connection).
Expand Down Expand Up @@ -451,11 +455,28 @@ async fn listen_on_bi_streams(
endpoint: quinn::Endpoint,
conn_id: String,
peer_addr: SocketAddr,
bi_streams: quinn::IncomingBiStreams,
connection: quinn::Connection,
mut alive_rx: watch::Receiver<()>,
message_tx: mpsc::Sender<Result<(UsrMsgBytes, Option<ResponseStream>), RecvError>>,
) {
trace!("Started listener for incoming bi-streams from {peer_addr}");

// Turn the `accept_bi` method into a stream that yields an `Some(Err(ConnectionError))` before `None`
let bi_streams = futures::stream::unfold(Some(connection), |connection| async move {
let connection = match connection {
Some(c) => c,
None => return None,
};

match connection.accept_bi().await {
Ok(recv) => Some((Ok(recv), Some(connection))),
Err(err) => {
let err: ConnectionError = err.into();
Some((Err(err), None))
}
}
});

let streaming = bi_streams.try_for_each_concurrent(None, |(send_stream, mut recv_stream)| {
let endpoint = &endpoint;
let message_tx = &message_tx;
Expand Down Expand Up @@ -509,6 +530,7 @@ async fn listen_on_bi_streams(
// it's a shame to allocate, but there are `Pin` errors otherwise – and we should only be doing
// this once.
let mut alive = Box::pin(alive_rx.changed());
let streaming = Box::pin(streaming);

match future::select(streaming, &mut alive).await {
future::Either::Left((Ok(()), _)) => {
Expand Down Expand Up @@ -553,10 +575,10 @@ async fn handle_endpoint_verification(
.map_err(ConnectionError::from)?
.await?;

let (mut send_stream, mut recv_stream) = connection.connection.open_bi().await?;
let (mut send_stream, mut recv_stream) = connection.open_bi().await?;
trace!(
"EndpointVerificationReq: sending EndpointEchoReq to {addr} over connection {}",
connection.connection.stable_id()
connection.stable_id()
);
WireMsg::EndpointEchoReq
.write_to_stream(&mut send_stream)
Expand Down Expand Up @@ -600,7 +622,7 @@ mod tests {
};
use bytes::Bytes;
use color_eyre::eyre::{bail, Result};
use futures::{StreamExt, TryStreamExt};
use futures::future::OptionFuture;
use quinn::Endpoint as QuinnEndpoint;
use std::time::Duration;

Expand All @@ -609,24 +631,26 @@ mod tests {
async fn basic_usage() -> Result<()> {
let config = InternalConfig::try_from_config(Default::default())?;

let (mut peer1, _peer1_incoming) =
QuinnEndpoint::server(config.server.clone(), local_addr())?;
let mut peer1 = QuinnEndpoint::server(config.server.clone(), local_addr())?;
peer1.set_default_client_config(config.client);

let (peer2, peer2_incoming) = QuinnEndpoint::server(config.server.clone(), local_addr())?;
let peer2 = QuinnEndpoint::server(config.server.clone(), local_addr())?;

{
let (p1_tx, mut p1_rx) = Connection::new(
peer1.clone(),
peer1.connect(peer2.local_addr()?, SERVER_NAME)?.await?,
);

let (p2_tx, mut p2_rx) =
if let Some(connection) = timeout(peer2_incoming.then(|c| c).try_next()).await?? {
Connection::new(peer2.clone(), connection)
} else {
bail!("did not receive incoming connection when one was expected");
};
let (p2_tx, mut p2_rx) = if let Some(connection) =
timeout(OptionFuture::from(peer2.accept().await))
.await?
.and_then(|c| c.ok())
{
Connection::new(peer2.clone(), connection)
} else {
bail!("did not receive incoming connection when one was expected");
};

p1_tx
.open_uni()
Expand Down Expand Up @@ -668,24 +692,26 @@ mod tests {
..Default::default()
})?;

let (mut peer1, _peer1_incoming) =
QuinnEndpoint::server(config.server.clone(), local_addr())?;
let mut peer1 = QuinnEndpoint::server(config.server.clone(), local_addr())?;
peer1.set_default_client_config(config.client);

let (peer2, peer2_incoming) = QuinnEndpoint::server(config.server.clone(), local_addr())?;
let peer2 = QuinnEndpoint::server(config.server.clone(), local_addr())?;

// open a connection between the two peers
let (p1_tx, _) = Connection::new(
peer1.clone(),
peer1.connect(peer2.local_addr()?, SERVER_NAME)?.await?,
);

let (_, mut p2_rx) =
if let Some(connection) = timeout(peer2_incoming.then(|c| c).try_next()).await?? {
Connection::new(peer2.clone(), connection)
} else {
bail!("did not receive incoming connection when one was expected");
};
let (_, mut p2_rx) = if let Some(connection) =
timeout(OptionFuture::from(peer2.accept().await))
.await?
.and_then(|c| c.ok())
{
Connection::new(peer2.clone(), connection)
} else {
bail!("did not receive incoming connection when one was expected");
};

// let 2 * idle timeout pass
tokio::time::sleep(Duration::from_secs(2)).await;
Expand All @@ -712,11 +738,10 @@ mod tests {
async fn test_endpoint_echo() -> Result<()> {
let config = InternalConfig::try_from_config(Config::default())?;

let (mut peer1, _peer1_incoming) =
QuinnEndpoint::server(config.server.clone(), local_addr())?;
let mut peer1 = QuinnEndpoint::server(config.server.clone(), local_addr())?;
peer1.set_default_client_config(config.client);

let (peer2, peer2_incoming) = QuinnEndpoint::server(config.server.clone(), local_addr())?;
let peer2 = QuinnEndpoint::server(config.server.clone(), local_addr())?;

{
let (p1_tx, _) = Connection::new(
Expand All @@ -725,12 +750,15 @@ mod tests {
);

// we need to accept the connection on p2, or the message won't be processed
let _p2_handle =
if let Some(connection) = timeout(peer2_incoming.then(|c| c).try_next()).await?? {
Connection::new(peer2.clone(), connection)
} else {
bail!("did not receive incoming connection when one was expected");
};
let _p2_handle = if let Some(connection) =
timeout(OptionFuture::from(peer2.accept().await))
.await?
.and_then(|c| c.ok())
{
Connection::new(peer2.clone(), connection)
} else {
bail!("did not receive incoming connection when one was expected");
};

let (mut send_stream, mut recv_stream) = p1_tx.open_bi().await?;
send_stream.send_wire_msg(WireMsg::EndpointEchoReq).await?;
Expand Down Expand Up @@ -761,12 +789,10 @@ mod tests {
async fn endpoint_verification() -> Result<()> {
let config = InternalConfig::try_from_config(Default::default())?;

let (mut peer1, peer1_incoming) =
QuinnEndpoint::server(config.server.clone(), local_addr())?;
let mut peer1 = QuinnEndpoint::server(config.server.clone(), local_addr())?;
peer1.set_default_client_config(config.client.clone());

let (mut peer2, peer2_incoming) =
QuinnEndpoint::server(config.server.clone(), local_addr())?;
let mut peer2 = QuinnEndpoint::server(config.server.clone(), local_addr())?;
peer2.set_default_client_config(config.client);

{
Expand All @@ -776,25 +802,31 @@ mod tests {
);

// we need to accept the connection on p2, or the message won't be processed
let _p2_handle =
if let Some(connection) = timeout(peer2_incoming.then(|c| c).try_next()).await?? {
Connection::new(peer2.clone(), connection)
} else {
bail!("did not receive incoming connection when one was expected");
};
let _p2_handle = if let Some(connection) =
timeout(OptionFuture::from(peer2.accept().await))
.await?
.and_then(|c| c.ok())
{
Connection::new(peer2.clone(), connection)
} else {
bail!("did not receive incoming connection when one was expected");
};

let (mut send_stream, mut recv_stream) = p1_tx.open_bi().await?;
send_stream
.send_wire_msg(WireMsg::EndpointVerificationReq(peer1.local_addr()?))
.await?;

// we need to accept the connection on p1, or the message won't be processed
let _p1_handle =
if let Some(connection) = timeout(peer1_incoming.then(|c| c).try_next()).await?? {
Connection::new(peer1.clone(), connection)
} else {
bail!("did not receive incoming connection when one was expected");
};
let _p1_handle = if let Some(connection) =
timeout(OptionFuture::from(peer1.accept().await))
.await?
.and_then(|c| c.ok())
{
Connection::new(peer1.clone(), connection)
} else {
bail!("did not receive incoming connection when one was expected");
};

if let Some(msg) = timeout(recv_stream.next_wire_msg()).await?? {
if let WireMsg::EndpointVerificationResp(true) = msg {
Expand Down
Loading

0 comments on commit 891b45e

Please sign in to comment.