Skip to content

websocket: Fix connection stability on decrypt messages #393

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 22, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#![allow(clippy::single_match)]
#![allow(clippy::result_large_err)]
#![allow(clippy::large_enum_variant)]
#![allow(clippy::redundant_pattern_matching)]
#![allow(clippy::type_complexity)]
#![allow(clippy::result_unit_err)]
Expand Down Expand Up @@ -80,7 +81,7 @@ pub mod yamux;

mod bandwidth;
mod multistream_select;
mod utils;
pub mod utils;

#[cfg(test)]
mod mock;
Expand Down
2 changes: 1 addition & 1 deletion src/multistream_select/negotiated.rs
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,6 @@ impl From<NegotiationError> for io::Error {
if let NegotiationError::ProtocolError(e) = err {
return e.into();
}
io::Error::new(io::ErrorKind::Other, err)
io::Error::other(err)
}
}
140 changes: 26 additions & 114 deletions src/transport/websocket/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
//! Stream implementation for `tokio_tungstenite::WebSocketStream` that implements
//! `AsyncRead + AsyncWrite`

use bytes::{Buf, Bytes, BytesMut};
use bytes::{Buf, Bytes};
use futures::{SinkExt, StreamExt};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_tungstenite::{tungstenite::Message, WebSocketStream};
Expand All @@ -31,121 +31,66 @@ use std::{
task::{Context, Poll},
};

const DEFAULT_BUF_SIZE: usize = 8 * 1024;

/// Send state.
#[derive(Debug)]
enum State {
/// State is poisoned.
Poisoned,

/// Sink is accepting input.
ReadyToSend,

/// Flush is pending for the sink.
FlushPending,
}
const LOG_TARGET: &str = "litep2p::transport::websocket::stream";

/// Buffered stream which implements `AsyncRead + AsyncWrite`
#[derive(Debug)]
pub(super) struct BufferedStream<S: AsyncRead + AsyncWrite + Unpin> {
/// Write buffer.
write_buffer: BytesMut,

/// Read buffer.
///
/// The buffer is taken directly from the WebSocket stream.
read_buffer: Bytes,

/// Underlying WebSocket stream.
stream: WebSocketStream<S>,

/// Read state.
state: State,
}

impl<S: AsyncRead + AsyncWrite + Unpin> BufferedStream<S> {
/// Create new [`BufferedStream`].
pub(super) fn new(stream: WebSocketStream<S>) -> Self {
Self {
write_buffer: BytesMut::with_capacity(DEFAULT_BUF_SIZE),
read_buffer: Bytes::new(),
stream,
state: State::ReadyToSend,
}
}
}

impl<S: AsyncRead + AsyncWrite + Unpin> futures::AsyncWrite for BufferedStream<S> {
fn poll_write(
mut self: Pin<&mut Self>,
_cx: &mut Context<'_>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<std::io::Result<usize>> {
self.write_buffer.extend_from_slice(buf);

Poll::Ready(Ok(buf.len()))
}
match futures::ready!(self.stream.poll_ready_unpin(cx)) {
Ok(()) => {
let message = Message::Binary(Bytes::copy_from_slice(buf));

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
if self.write_buffer.is_empty() {
return self
.stream
.poll_ready_unpin(cx)
.map_err(|_| std::io::ErrorKind::UnexpectedEof.into());
}

loop {
match std::mem::replace(&mut self.state, State::Poisoned) {
State::ReadyToSend => {
match self.stream.poll_ready_unpin(cx) {
Poll::Ready(Ok(())) => {}
Poll::Ready(Err(_error)) =>
return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into())),
Poll::Pending => {
self.state = State::ReadyToSend;
return Poll::Pending;
}
}

let message = std::mem::take(&mut self.write_buffer);
match self.stream.start_send_unpin(Message::Binary(message.freeze())) {
Ok(()) => {}
Err(_error) =>
return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into())),
}

// Transition to flush pending state.
self.state = State::FlushPending;
continue;
if let Err(err) = self.stream.start_send_unpin(message) {
tracing::debug!(target: LOG_TARGET, "Error during start send: {:?}", err);
return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into()));
}

State::FlushPending => {
match self.stream.poll_flush_unpin(cx) {
Poll::Ready(Ok(())) => {}
Poll::Ready(Err(_error)) =>
return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into())),
Poll::Pending => {
self.state = State::FlushPending;
return Poll::Pending;
}
}

self.state = State::ReadyToSend;
self.write_buffer = BytesMut::with_capacity(DEFAULT_BUF_SIZE);
return Poll::Ready(Ok(()));
}
State::Poisoned =>
return Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into())),
Poll::Ready(Ok(buf.len()))
}
Err(err) => {
tracing::debug!(target: LOG_TARGET, "Error during poll ready: {:?}", err);
Poll::Ready(Err(std::io::ErrorKind::UnexpectedEof.into()))
}
}
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
self.stream.poll_flush_unpin(cx).map_err(|err| {
tracing::debug!(target: LOG_TARGET, "Error during poll flush: {:?}", err);
std::io::ErrorKind::UnexpectedEof.into()
})
}

fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
match futures::ready!(self.stream.poll_close_unpin(cx)) {
Ok(_) => Poll::Ready(Ok(())),
Err(_) => Poll::Ready(Err(std::io::ErrorKind::PermissionDenied.into())),
}
self.stream.poll_close_unpin(cx).map_err(|err| {
tracing::debug!(target: LOG_TARGET, "Error during poll close: {:?}", err);
std::io::ErrorKind::PermissionDenied.into()
})
}
}

Expand Down Expand Up @@ -183,7 +128,7 @@ impl<S: AsyncRead + AsyncWrite + Unpin> futures::AsyncRead for BufferedStream<S>
#[cfg(test)]
mod tests {
use super::*;
use futures::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use futures::{AsyncRead, AsyncReadExt, AsyncWriteExt};
use tokio::io::DuplexStream;
use tokio_tungstenite::{tungstenite::protocol::Role, WebSocketStream};

Expand All @@ -203,7 +148,6 @@ mod tests {

let bytes_written = stream.write(data).await.unwrap();
assert_eq!(bytes_written, data.len());
assert_eq!(&stream.write_buffer[..], data);
}

#[tokio::test]
Expand Down Expand Up @@ -253,38 +197,6 @@ mod tests {
};
}

#[tokio::test]
async fn test_poisoned_state() {
let (mut stream, server) = create_test_stream().await;
drop(server);

stream.state = State::Poisoned;

let mut buffer = [0u8; 10];
let result = stream.read(&mut buffer).await;
match result {
Err(error) => if error.kind() == std::io::ErrorKind::UnexpectedEof {},
state => panic!("Unexpected state {state:?}"),
};

let mut cx = std::task::Context::from_waker(futures::task::noop_waker_ref());
let mut pin_stream = Pin::new(&mut stream);

// Messages are buffered internally, the socket is not touched.
match pin_stream.as_mut().poll_write(&mut cx, &mut buffer) {
Poll::Ready(Ok(10)) => {}
state => panic!("Unexpected state {state:?}"),
}
// Socket is poisoned, the flush will fail.
match pin_stream.poll_flush(&mut cx) {
Poll::Ready(Err(error)) =>
if error.kind() == std::io::ErrorKind::UnexpectedEof {
return;
},
state => panic!("Unexpected state {state:?}"),
}
}

#[tokio::test]
async fn test_read_poll_pending() {
let (mut stream, mut _server) = create_test_stream().await;
Expand Down
2 changes: 2 additions & 0 deletions tests/connection/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ use crate::common::{add_transport, Transport};

#[cfg(test)]
mod protocol_dial_invalid_address;
#[cfg(test)]
mod stability;

#[tokio::test]
async fn two_litep2ps_work_tcp() {
Expand Down
Loading
Loading