Skip to content

Commit

Permalink
poll_ready pending after WouldBlock errors
Browse files Browse the repository at this point in the history
  • Loading branch information
alexheretic committed Jun 17, 2023
1 parent 36b9d94 commit f61d93d
Showing 1 changed file with 21 additions and 5 deletions.
26 changes: 21 additions & 5 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,11 @@ pub struct WebSocketStream<S> {
inner: WebSocket<AllowStd<S>>,
closing: bool,
ended: bool,
/// Tungstenite is probably ready to receive more data.
///
/// `false` once start_send hits `WouldBlock` errors.
/// `true` initially and after `flush`ing.
ready: bool,
}

impl<S> WebSocketStream<S> {
Expand Down Expand Up @@ -226,7 +231,7 @@ impl<S> WebSocketStream<S> {
}

pub(crate) fn new(ws: WebSocket<AllowStd<S>>) -> Self {
WebSocketStream { inner: ws, closing: false, ended: false }
Self { inner: ws, closing: false, ended: false, ready: true }
}

fn with_context<F, R>(&mut self, ctx: Option<(ContextWaker, &mut Context<'_>)>, f: F) -> R
Expand Down Expand Up @@ -322,25 +327,35 @@ where
type Error = WsError;

fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
Poll::Ready(Ok(()))
if self.ready {
Poll::Ready(Ok(()))
} else {
Poll::Pending
}
}

fn start_send(mut self: Pin<&mut Self>, item: Message) -> Result<(), Self::Error> {
match (*self).with_context(None, |s| s.write(item)) {
Ok(()) => Ok(()),
Ok(()) => {
self.ready = true;
Ok(())
}
Err(WsError::Io(err)) if err.kind() == std::io::ErrorKind::WouldBlock => {
// the message was accepted and queued
// isn't an error.
// the message was accepted and queued so not an error
// but `poll_ready` will start returning pending now.
self.ready = false;
Ok(())
}
Err(e) => {
self.ready = true;
debug!("websocket start_send error: {}", e);
Err(e)
}
}
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.ready = true;
(*self).with_context(Some((ContextWaker::Write, cx)), |s| cvt(s.flush())).map(|r| {
// WebSocket connection has just been closed. Flushing completed, not an error.
match r {
Expand All @@ -351,6 +366,7 @@ where
}

fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.ready = true;
let res = if self.closing {
// After queueing it, we call `flush` to drive the close handshake to completion.
(*self).with_context(Some((ContextWaker::Write, cx)), |s| s.flush())
Expand Down

0 comments on commit f61d93d

Please sign in to comment.