44use std:: {
55 io,
66 pin:: Pin ,
7- task:: { ready , Context , Poll } ,
7+ task:: { Context , Poll } ,
88} ;
99
1010use futures_core:: stream:: Stream ;
@@ -16,22 +16,50 @@ use crate::{tungstenite::Bytes, Message, WsError};
1616/// Every write sends a binary message. If you want to group writes together, consider wrapping
1717/// this with a `BufWriter`.
1818#[ derive( Debug ) ]
19- pub struct ByteWriter < S > ( S ) ;
19+ pub struct ByteWriter < S > {
20+ sender : S ,
21+ state : State ,
22+ }
2023
2124impl < S > ByteWriter < S > {
2225 /// Create a new `ByteWriter` from a [sender](Sender) that accepts a websocket [`Message`].
2326 #[ inline( always) ]
24- pub fn new ( s : S ) -> Self
27+ pub fn new ( sender : S ) -> Self
2528 where
2629 S : Sender ,
2730 {
28- Self ( s)
31+ Self {
32+ sender,
33+ state : State :: Open ,
34+ }
2935 }
3036
3137 /// Get the underlying [sender](Sender) back.
3238 #[ inline( always) ]
3339 pub fn into_inner ( self ) -> S {
34- self . 0
40+ self . sender
41+ }
42+ }
43+
44+ #[ derive( Debug ) ]
45+ enum State {
46+ Open ,
47+ Closing ( Option < Message > ) ,
48+ }
49+
50+ impl State {
51+ fn close ( & mut self ) -> & mut Option < Message > {
52+ match self {
53+ State :: Open => {
54+ * self = State :: Closing ( Some ( Message :: Close ( None ) ) ) ;
55+ if let State :: Closing ( msg) = self {
56+ msg
57+ } else {
58+ unreachable ! ( )
59+ }
60+ }
61+ State :: Closing ( msg) => msg,
62+ }
3563 }
3664}
3765
@@ -55,7 +83,12 @@ pub(crate) mod private {
5583 ) -> Poll < Result < usize , WsError > > ;
5684
5785 fn poll_flush ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , WsError > > ;
58- fn poll_close ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , WsError > > ;
86+
87+ fn poll_close (
88+ self : Pin < & mut Self > ,
89+ cx : & mut Context < ' _ > ,
90+ msg : & mut Option < Message > ,
91+ ) -> Poll < Result < ( ) , WsError > > ;
5992 }
6093
6194 impl < S > Sender for S where S : SealedSender { }
71104 cx : & mut Context < ' _ > ,
72105 buf : & [ u8 ] ,
73106 ) -> Poll < Result < usize , WsError > > {
107+ use std:: task:: ready;
108+
74109 ready ! ( self . as_mut( ) . poll_ready( cx) ) ?;
75110 let len = buf. len ( ) ;
76111 self . start_send ( Message :: binary ( buf. to_owned ( ) ) ) ?;
@@ -81,7 +116,11 @@ where
81116 <S as futures_util:: Sink < _ > >:: poll_flush ( self , cx)
82117 }
83118
84- fn poll_close ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < Result < ( ) , WsError > > {
119+ fn poll_close (
120+ self : Pin < & mut Self > ,
121+ cx : & mut Context < ' _ > ,
122+ _: & mut Option < Message > ,
123+ ) -> Poll < Result < ( ) , WsError > > {
85124 <S as futures_util:: Sink < _ > >:: poll_close ( self , cx)
86125 }
87126}
@@ -95,16 +134,20 @@ where
95134 cx : & mut Context < ' _ > ,
96135 buf : & [ u8 ] ,
97136 ) -> Poll < io:: Result < usize > > {
98- <S as private:: SealedSender >:: poll_write ( Pin :: new ( & mut self . 0 ) , cx, buf)
137+ <S as private:: SealedSender >:: poll_write ( Pin :: new ( & mut self . sender ) , cx, buf)
99138 . map_err ( convert_err)
100139 }
101140
102141 fn poll_flush ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
103- <S as private:: SealedSender >:: poll_flush ( Pin :: new ( & mut self . 0 ) , cx) . map_err ( convert_err)
142+ <S as private:: SealedSender >:: poll_flush ( Pin :: new ( & mut self . sender ) , cx)
143+ . map_err ( convert_err)
104144 }
105145
106- fn poll_close ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
107- <S as private:: SealedSender >:: poll_close ( Pin :: new ( & mut self . 0 ) , cx) . map_err ( convert_err)
146+ fn poll_close ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
147+ let me = self . get_mut ( ) ;
148+ let msg = me. state . close ( ) ;
149+ <S as private:: SealedSender >:: poll_close ( Pin :: new ( & mut me. sender ) , cx, msg)
150+ . map_err ( convert_err)
108151 }
109152}
110153
@@ -118,16 +161,20 @@ where
118161 cx : & mut Context < ' _ > ,
119162 buf : & [ u8 ] ,
120163 ) -> Poll < io:: Result < usize > > {
121- <S as private:: SealedSender >:: poll_write ( Pin :: new ( & mut self . 0 ) , cx, buf)
164+ <S as private:: SealedSender >:: poll_write ( Pin :: new ( & mut self . sender ) , cx, buf)
122165 . map_err ( convert_err)
123166 }
124167
125168 fn poll_flush ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
126- <S as private:: SealedSender >:: poll_flush ( Pin :: new ( & mut self . 0 ) , cx) . map_err ( convert_err)
169+ <S as private:: SealedSender >:: poll_flush ( Pin :: new ( & mut self . sender ) , cx)
170+ . map_err ( convert_err)
127171 }
128172
129- fn poll_shutdown ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
130- <S as private:: SealedSender >:: poll_close ( Pin :: new ( & mut self . 0 ) , cx) . map_err ( convert_err)
173+ fn poll_shutdown ( self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
174+ let me = self . get_mut ( ) ;
175+ let msg = me. state . close ( ) ;
176+ <S as private:: SealedSender >:: poll_close ( Pin :: new ( & mut me. sender ) , cx, msg)
177+ . map_err ( convert_err)
131178 }
132179}
133180
0 commit comments