Skip to content

Commit

Permalink
fix: read multiple messages from a single stream
Browse files Browse the repository at this point in the history
  • Loading branch information
madadam committed Mar 11, 2021
1 parent 015333a commit 949fc4b
Show file tree
Hide file tree
Showing 2 changed files with 139 additions and 114 deletions.
229 changes: 125 additions & 104 deletions src/connections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,11 @@ use super::{
wire_msg::WireMsg,
};
use bytes::Bytes;
use futures::stream::StreamExt;
use futures::{future, stream::StreamExt};
use log::{error, trace, warn};
use std::net::SocketAddr;
use tokio::sync::mpsc::UnboundedSender;
use tokio::{
select,
time::{timeout, Duration},
};
use tokio::time::{timeout, Duration};

/// Connection instance to a node which can be used to send messages to it
#[derive(Clone)]
Expand Down Expand Up @@ -192,122 +189,146 @@ pub(super) fn listen_for_incoming_messages(
) {
let src = *remover.remote_addr();
let _ = tokio::spawn(async move {
loop {
let message: Option<Bytes> = select! {
bytes = next_on_uni_streams(&mut uni_streams) => bytes,
bytes = next_on_bi_streams(&mut bi_streams, src) => bytes,
};
if let Some(message) = message {
// When the message in handled internally we return Bytes::new() to prevent
// connection termination
if !message.is_empty() {
let _ = message_tx.send((src, message));
}
} else {
log::trace!("The connection has been terminated.");
let _ = disconnection_tx.send(*remover.remote_addr());
remover.remove();
break;
}
}
Ok::<_, Error>(())
let _ = future::join(
read_on_uni_streams(&mut uni_streams, src, message_tx.clone()),
read_on_bi_streams(&mut bi_streams, src, message_tx),
)
.await;

log::trace!("The connection has been terminated.");
let _ = disconnection_tx.send(src);
remover.remove();
});
}

// Returns next message sent by peer in an unidirectional stream.
async fn next_on_uni_streams(uni_streams: &mut quinn::IncomingUniStreams) -> Option<Bytes> {
match uni_streams.next().await {
None => None,
Some(Err(quinn::ConnectionError::ApplicationClosed { .. })) => {
trace!("Connection terminated by peer.");
None
}
Some(Err(err)) => {
warn!("Failed to read incoming message on uni-stream: {}", err);
None
}
Some(Ok(mut recv)) => match read_bytes(&mut recv).await {
Ok(WireMsg::UserMsg(bytes)) => Some(bytes),
Ok(msg) => {
error!("Unexpected message type: {:?}", msg);
Some(Bytes::new())
// Read messages sent by peer in an unidirectional stream.
async fn read_on_uni_streams(
uni_streams: &mut quinn::IncomingUniStreams,
peer_addr: SocketAddr,
message_tx: UnboundedSender<(SocketAddr, Bytes)>,
) {
while let Some(result) = uni_streams.next().await {
match result {
Err(quinn::ConnectionError::ApplicationClosed { .. }) => {
trace!("Connection terminated by peer.");
break;
}
Err(err) => {
error!("{}", err);
Some(Bytes::new())
warn!("Failed to read incoming message on uni-stream: {}", err);
break;
}
},
Ok(mut recv) => loop {
match read_bytes(&mut recv).await {
Ok(WireMsg::UserMsg(bytes)) => {
let _ = message_tx.send((peer_addr, bytes));
}
Ok(msg) => error!("Unexpected message type: {:?}", msg),
Err(Error::StreamRead(quinn::ReadExactError::FinishedEarly)) => break,
Err(err) => {
error!("Failed reading from a uni-stream: {}", err);
break;
}
}
},
}
}
}

// Returns next message sent by peer in a bidirectional stream.
async fn next_on_bi_streams(
// Read messages sent by peer in a bidirectional stream.
async fn read_on_bi_streams(
bi_streams: &mut quinn::IncomingBiStreams,
peer_addr: SocketAddr,
) -> Option<Bytes> {
match bi_streams.next().await {
None => None,
Some(Err(quinn::ConnectionError::ApplicationClosed { .. })) => {
trace!("Connection terminated by peer.");
None
}
Some(Err(err)) => {
warn!("Failed to read incoming message on bi-stream: {}", err);
None
}
Some(Ok((mut send, mut recv))) => match read_bytes(&mut recv).await {
Ok(WireMsg::UserMsg(bytes)) => Some(bytes),
Ok(WireMsg::EndpointEchoReq) => {
trace!("Received Echo Request");
let message = WireMsg::EndpointEchoResp(peer_addr);
message.write_to_stream(&mut send).await.ok()?;
trace!("Responded to Echo request");
Some(Bytes::new())
}
Ok(WireMsg::EndpointVerificationReq(address_sent)) => {
trace!(
"Received Endpoint verification request {:?} from {:?}",
address_sent,
peer_addr
);
// Verify if the peer's endpoint is reachable via EchoServiceReq
let qp2p = QuicP2p::with_config(Default::default(), &[], false).ok()?;
let (temporary_endpoint, _, _, _) = qp2p.new_endpoint().await.ok()?;
let (mut temp_send, mut temp_recv) = temporary_endpoint
.open_bidirectional_stream(&address_sent)
.await
.ok()?;
let message = WireMsg::EndpointEchoReq;
message
.write_to_stream(&mut temp_send.quinn_send_stream)
.await
.ok()?;
let verified = matches!(
timeout(
Duration::from_secs(30),
WireMsg::read_from_stream(&mut temp_recv.quinn_recv_stream)
)
.await,
Ok(Ok(WireMsg::EndpointEchoResp(_)))
);

let message = WireMsg::EndpointVerficationResp(verified);
message.write_to_stream(&mut send).await.ok()?;
trace!("Responded to Endpoint verification request");
Some(Bytes::new())
}
Ok(msg) => {
error!("Unexpected message type: {:?}", msg);
Some(Bytes::new())
message_tx: UnboundedSender<(SocketAddr, Bytes)>,
) {
while let Some(result) = bi_streams.next().await {
match result {
Err(quinn::ConnectionError::ApplicationClosed { .. }) => {
trace!("Connection terminated by peer.");
break;
}
Err(err) => {
error!("{}", err);
Some(Bytes::new())
warn!("Failed to read incoming message on bi-stream: {}", err);
break;
}
},
Ok((mut send, mut recv)) => loop {
match read_bytes(&mut recv).await {
Ok(WireMsg::UserMsg(bytes)) => {
let _ = message_tx.send((peer_addr, bytes));
}
Ok(WireMsg::EndpointEchoReq) => {
if let Err(error) = handle_endpoint_echo_req(peer_addr, &mut send).await {
error!("Failed to handle Echo Request: {}", error);
}
}
Ok(WireMsg::EndpointVerificationReq(address_sent)) => {
if let Err(error) =
handle_endpoint_verification_req(peer_addr, address_sent, &mut send)
.await
{
error!("Failed to handle Endpoint verification request: {}", error);
}
}
Ok(msg) => {
error!("Unexpected message type: {:?}", msg);
}
Err(Error::StreamRead(quinn::ReadExactError::FinishedEarly)) => break,
Err(err) => {
error!("Failed reading from a bi-stream: {}", err);
break;
}
}
},
}
}
}

async fn handle_endpoint_echo_req(
peer_addr: SocketAddr,
send_stream: &mut quinn::SendStream,
) -> Result<()> {
trace!("Received Echo Request");
let message = WireMsg::EndpointEchoResp(peer_addr);
message.write_to_stream(send_stream).await?;
trace!("Responded to Echo request");
Ok(())
}

async fn handle_endpoint_verification_req(
peer_addr: SocketAddr,
addr_sent: SocketAddr,
send_stream: &mut quinn::SendStream,
) -> Result<()> {
trace!(
"Received Endpoint verification request {:?} from {:?}",
addr_sent,
peer_addr
);
// Verify if the peer's endpoint is reachable via EchoServiceReq
let qp2p = QuicP2p::with_config(Default::default(), &[], false)?;
let (temporary_endpoint, _, _, _) = qp2p.new_endpoint().await?;
let (mut temp_send, mut temp_recv) = temporary_endpoint
.open_bidirectional_stream(&addr_sent)
.await?;
let message = WireMsg::EndpointEchoReq;
message
.write_to_stream(&mut temp_send.quinn_send_stream)
.await?;
let verified = matches!(
timeout(
Duration::from_secs(30),
WireMsg::read_from_stream(&mut temp_recv.quinn_recv_stream)
)
.await,
Ok(Ok(WireMsg::EndpointEchoResp(_)))
);

let message = WireMsg::EndpointVerficationResp(verified);
message.write_to_stream(send_stream).await?;
trace!("Responded to Endpoint verification request");

Ok(())
}

#[cfg(test)]
mod tests {
use anyhow::anyhow;
Expand Down
24 changes: 14 additions & 10 deletions src/tests/common.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::{new_qp2p, random_msg};
use crate::utils;
use anyhow::{anyhow, format_err, Result};
use anyhow::{anyhow, Result};
use futures::future;
use std::time::Duration;
use tokio::time::timeout;
Expand Down Expand Up @@ -364,7 +364,7 @@ async fn many_messages() -> Result<()> {

utils::init_logging();

let num_messages: usize = 101;
let num_messages: usize = 10_000;

let qp2p = new_qp2p()?;
let (send_endpoint, _, _, _) = qp2p.new_endpoint().await?;
Expand All @@ -388,21 +388,25 @@ async fn many_messages() -> Result<()> {
endpoint.send_message(msg, &recv_addr).await?;
log::info!("sent {}", id);

Ok(())
Ok::<_, anyhow::Error>(())
}
}));
}

// Receiver
tasks.push(tokio::spawn({
async move {
for _ in 0..num_messages {
if let Some((src, msg)) = recv_incoming_messages.next().await {
let id = usize::from_le_bytes(msg[..].try_into().unwrap());
assert_eq!(src, send_addr);
log::info!("received {}", id);
} else {
return Err(format_err!("incoming messages stream closed unexpectedly"));
let mut num_received = 0;

while let Some((src, msg)) = recv_incoming_messages.next().await {
let id = usize::from_le_bytes(msg[..].try_into().unwrap());
assert_eq!(src, send_addr);
log::info!("received {}", id);

num_received += 1;

if num_received >= num_messages {
break;
}
}

Expand Down

0 comments on commit 949fc4b

Please sign in to comment.