Skip to content

Commit 5241a19

Browse files
nanoqshsdroege
authored andcommitted
Add closing state for ByteWriter
1 parent e05133a commit 5241a19

File tree

2 files changed

+77
-20
lines changed

2 files changed

+77
-20
lines changed

src/bytes.rs

Lines changed: 62 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
use std::{
55
io,
66
pin::Pin,
7-
task::{ready, Context, Poll},
7+
task::{Context, Poll},
88
};
99

1010
use futures_core::stream::Stream;
@@ -16,22 +16,50 @@ use crate::{tungstenite::Bytes, Message, WsError};
1616
/// Every write sends a binary message. If you want to group writes together, consider wrapping
1717
/// this with a `BufWriter`.
1818
#[derive(Debug)]
19-
pub struct ByteWriter<S>(S);
19+
pub struct ByteWriter<S> {
20+
sender: S,
21+
state: State,
22+
}
2023

2124
impl<S> ByteWriter<S> {
2225
/// Create a new `ByteWriter` from a [sender](Sender) that accepts a websocket [`Message`].
2326
#[inline(always)]
24-
pub fn new(s: S) -> Self
27+
pub fn new(sender: S) -> Self
2528
where
2629
S: Sender,
2730
{
28-
Self(s)
31+
Self {
32+
sender,
33+
state: State::Open,
34+
}
2935
}
3036

3137
/// Get the underlying [sender](Sender) back.
3238
#[inline(always)]
3339
pub fn into_inner(self) -> S {
34-
self.0
40+
self.sender
41+
}
42+
}
43+
44+
#[derive(Debug)]
45+
enum State {
46+
Open,
47+
Closing(Option<Message>),
48+
}
49+
50+
impl State {
51+
fn close(&mut self) -> &mut Option<Message> {
52+
match self {
53+
State::Open => {
54+
*self = State::Closing(Some(Message::Close(None)));
55+
if let State::Closing(msg) = self {
56+
msg
57+
} else {
58+
unreachable!()
59+
}
60+
}
61+
State::Closing(msg) => msg,
62+
}
3563
}
3664
}
3765

@@ -55,7 +83,12 @@ pub(crate) mod private {
5583
) -> Poll<Result<usize, WsError>>;
5684

5785
fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>>;
58-
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>>;
86+
87+
fn poll_close(
88+
self: Pin<&mut Self>,
89+
cx: &mut Context<'_>,
90+
msg: &mut Option<Message>,
91+
) -> Poll<Result<(), WsError>>;
5992
}
6093

6194
impl<S> Sender for S where S: SealedSender {}
@@ -71,6 +104,8 @@ where
71104
cx: &mut Context<'_>,
72105
buf: &[u8],
73106
) -> Poll<Result<usize, WsError>> {
107+
use std::task::ready;
108+
74109
ready!(self.as_mut().poll_ready(cx))?;
75110
let len = buf.len();
76111
self.start_send(Message::binary(buf.to_owned()))?;
@@ -81,7 +116,11 @@ where
81116
<S as futures_util::Sink<_>>::poll_flush(self, cx)
82117
}
83118

84-
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
119+
fn poll_close(
120+
self: Pin<&mut Self>,
121+
cx: &mut Context<'_>,
122+
_: &mut Option<Message>,
123+
) -> Poll<Result<(), WsError>> {
85124
<S as futures_util::Sink<_>>::poll_close(self, cx)
86125
}
87126
}
@@ -95,16 +134,20 @@ where
95134
cx: &mut Context<'_>,
96135
buf: &[u8],
97136
) -> Poll<io::Result<usize>> {
98-
<S as private::SealedSender>::poll_write(Pin::new(&mut self.0), cx, buf)
137+
<S as private::SealedSender>::poll_write(Pin::new(&mut self.sender), cx, buf)
99138
.map_err(convert_err)
100139
}
101140

102141
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
103-
<S as private::SealedSender>::poll_flush(Pin::new(&mut self.0), cx).map_err(convert_err)
142+
<S as private::SealedSender>::poll_flush(Pin::new(&mut self.sender), cx)
143+
.map_err(convert_err)
104144
}
105145

106-
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
107-
<S as private::SealedSender>::poll_close(Pin::new(&mut self.0), cx).map_err(convert_err)
146+
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
147+
let me = self.get_mut();
148+
let msg = me.state.close();
149+
<S as private::SealedSender>::poll_close(Pin::new(&mut me.sender), cx, msg)
150+
.map_err(convert_err)
108151
}
109152
}
110153

@@ -118,16 +161,20 @@ where
118161
cx: &mut Context<'_>,
119162
buf: &[u8],
120163
) -> Poll<io::Result<usize>> {
121-
<S as private::SealedSender>::poll_write(Pin::new(&mut self.0), cx, buf)
164+
<S as private::SealedSender>::poll_write(Pin::new(&mut self.sender), cx, buf)
122165
.map_err(convert_err)
123166
}
124167

125168
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
126-
<S as private::SealedSender>::poll_flush(Pin::new(&mut self.0), cx).map_err(convert_err)
169+
<S as private::SealedSender>::poll_flush(Pin::new(&mut self.sender), cx)
170+
.map_err(convert_err)
127171
}
128172

129-
fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
130-
<S as private::SealedSender>::poll_close(Pin::new(&mut self.0), cx).map_err(convert_err)
173+
fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
174+
let me = self.get_mut();
175+
let msg = me.state.close();
176+
<S as private::SealedSender>::poll_close(Pin::new(&mut me.sender), cx, msg)
177+
.map_err(convert_err)
131178
}
132179
}
133180

src/lib.rs

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -95,7 +95,6 @@ pub mod tokio;
9595

9696
pub mod bytes;
9797
pub use bytes::ByteReader;
98-
#[cfg(feature = "futures-03-sink")]
9998
pub use bytes::ByteWriter;
10099

101100
use tungstenite::protocol::CloseFrame;
@@ -530,8 +529,13 @@ where
530529
self.get_mut().poll_flush(cx)
531530
}
532531

533-
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
534-
self.get_mut().poll_close(cx)
532+
fn poll_close(
533+
self: Pin<&mut Self>,
534+
cx: &mut Context<'_>,
535+
msg: &mut Option<Message>,
536+
) -> Poll<Result<(), WsError>> {
537+
let me = self.get_mut();
538+
send_helper(me, msg, cx)
535539
}
536540
}
537541

@@ -677,8 +681,14 @@ where
677681
self.shared.lock().poll_flush(cx)
678682
}
679683

680-
fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), WsError>> {
681-
self.shared.lock().poll_close(cx)
684+
fn poll_close(
685+
self: Pin<&mut Self>,
686+
cx: &mut Context<'_>,
687+
msg: &mut Option<Message>,
688+
) -> Poll<Result<(), WsError>> {
689+
let me = self.get_mut();
690+
let mut ws = me.shared.lock();
691+
send_helper(&mut ws, msg, cx)
682692
}
683693
}
684694

0 commit comments

Comments
 (0)