diff --git a/tokio-stream/src/wrappers/broadcast.rs b/tokio-stream/src/wrappers/broadcast.rs index fbfa9a4c78d..536e179bd6f 100644 --- a/tokio-stream/src/wrappers/broadcast.rs +++ b/tokio-stream/src/wrappers/broadcast.rs @@ -1,37 +1,61 @@ use crate::Stream; -use async_stream::try_stream; +use async_stream::stream; use std::pin::Pin; use tokio::sync::broadcast::error::RecvError; use tokio::sync::broadcast::Receiver; +use std::fmt; use std::task::{Context, Poll}; /// A wrapper around [`Receiver`] that implements [`Stream`]. /// /// [`Receiver`]: struct@tokio::sync::broadcast::Receiver /// [`Stream`]: trait@crate::Stream -#[derive(Debug)] pub struct BroadcastStream { - inner: Pin> + Send + Sync >>, + inner: Pin> + Send + Sync>>, +} + +/// An error returned from the inner stream of a [`BroadcastStream`]. +#[derive(Debug, PartialEq)] +pub enum BroadcastStreamRecvError { + /// The receiver lagged too far behind. Attempting to receive again will + /// return the oldest message still retained by the channel. + /// + /// Includes the number of skipped messages. + Lagged(u64), } impl BroadcastStream { /// Create a new `BroadcastStream`. pub fn new(mut rx: Receiver) -> Self { - let stream = try_stream! { + let stream = stream! { loop { - let item = rx.recv().await?; - yield item; + match rx.recv().await { + Ok(item) => yield Ok(item), + Err(err) => + match err { + RecvError::Closed => break, + RecvError::Lagged(n) => yield Err(BroadcastStreamRecvError::Lagged(n)) + } + } } }; - Self { inner: Box::pin(stream) } + Self { + inner: Box::pin(stream), + } } } impl Stream for BroadcastStream { - type Item = Result; + type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut self.inner).poll_next(cx) } } + +impl fmt::Debug for BroadcastStream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("BroadcastStream").finish() + } +}