1
+ use crate :: task:: { Context , Poll } ;
2
+ use futures:: { ready, AsyncWrite , Future , Stream } ;
3
+ use std:: io:: { self , IntoInnerError } ;
4
+ use std:: pin:: Pin ;
5
+ use std:: fmt;
6
+ use crate :: io:: Write ;
7
+
8
+ const DEFAULT_CAPACITY : usize = 8 * 1024 ;
9
+
10
+
11
+ pub struct BufWriter < W : AsyncWrite > {
12
+ inner : Option < W > ,
13
+ buf : Vec < u8 > ,
14
+ panicked : bool ,
15
+ }
16
+
17
+ impl < W : AsyncWrite + Unpin > BufWriter < W > {
18
+ pin_utils:: unsafe_pinned!( inner: Option <W >) ;
19
+ pin_utils:: unsafe_unpinned!( panicked: bool ) ;
20
+
21
+ pub fn new ( inner : W ) -> BufWriter < W > {
22
+ BufWriter :: with_capacity ( DEFAULT_CAPACITY , inner)
23
+ }
24
+
25
+ pub fn with_capacity ( capacity : usize , inner : W ) -> BufWriter < W > {
26
+ BufWriter {
27
+ inner : Some ( inner) ,
28
+ buf : Vec :: with_capacity ( capacity) ,
29
+ panicked : false ,
30
+ }
31
+ }
32
+
33
+ pub fn get_ref ( & self ) -> & W {
34
+ self . inner . as_ref ( ) . unwrap ( )
35
+ }
36
+
37
+ pub fn get_mut ( & mut self ) -> & mut W {
38
+ self . inner . as_mut ( ) . unwrap ( )
39
+ }
40
+
41
+ pub fn buffer ( & self ) -> & [ u8 ] {
42
+ & self . buf
43
+ }
44
+
45
+ pub fn poll_flush_buf ( mut self : Pin < & mut Self > , cx : & mut Context < ' _ > ) -> Poll < io:: Result < ( ) > > {
46
+ let Self {
47
+ inner,
48
+ buf,
49
+ panicked
50
+ } = Pin :: get_mut ( self ) ;
51
+ let mut panicked = Pin :: new ( panicked) ;
52
+ let mut written = 0 ;
53
+ let len = buf. len ( ) ;
54
+ let mut ret = Ok ( ( ) ) ;
55
+ while written < len {
56
+ * panicked = true ;
57
+ let r = Pin :: new ( inner. as_mut ( ) . unwrap ( ) ) ;
58
+ * panicked = false ;
59
+ match r. poll_write ( cx, & buf[ written..] ) {
60
+ Poll :: Ready ( Ok ( 0 ) ) => {
61
+ ret = Err ( io:: Error :: new (
62
+ io:: ErrorKind :: WriteZero ,
63
+ "Failed to write buffered data" ,
64
+ ) ) ;
65
+ break ;
66
+ }
67
+ Poll :: Ready ( Ok ( n) ) => written += n,
68
+ Poll :: Ready ( Err ( ref e) ) if e. kind ( ) == io:: ErrorKind :: Interrupted => { }
69
+ Poll :: Ready ( Err ( e) ) => {
70
+ ret = Err ( e) ;
71
+ break ;
72
+ }
73
+ Poll :: Pending => return Poll :: Pending ,
74
+ }
75
+ }
76
+ if written > 0 {
77
+ buf. drain ( ..written) ;
78
+ }
79
+ Poll :: Ready ( ret)
80
+ }
81
+
82
+ pub fn poll_into_inner (
83
+ mut self : Pin < & mut Self > ,
84
+ cx : & mut Context < ' _ > ,
85
+ //TODO: Fix 'expected function, found struct `IntoInnerError`' compiler error
86
+ ) -> Poll < io:: Result < W > > {
87
+ match ready ! ( self . as_mut( ) . poll_flush_buf( cx) ) {
88
+ Ok ( ( ) ) => Poll :: Ready ( Ok ( self . inner ( ) . take ( ) . unwrap ( ) ) ) ,
89
+ Err ( e) => Poll :: Ready ( Err ( io:: Error :: new ( io:: ErrorKind :: Other , "" ) ) )
90
+ }
91
+ }
92
+ }
93
+
94
+ impl < W : AsyncWrite + Unpin > AsyncWrite for BufWriter < W > {
95
+ fn poll_write (
96
+ mut self : Pin < & mut Self > ,
97
+ cx : & mut Context ,
98
+ buf : & [ u8 ] ,
99
+ ) -> Poll < io:: Result < usize > > {
100
+ let panicked = self . as_mut ( ) . panicked ( ) ;
101
+ if self . as_ref ( ) . buf . len ( ) + buf. len ( ) > self . as_ref ( ) . buf . capacity ( ) {
102
+ match ready ! ( self . as_mut( ) . poll_flush_buf( cx) ) {
103
+ Ok ( ( ) ) => { } ,
104
+ Err ( e) => return Poll :: Ready ( Err ( e) )
105
+ }
106
+ }
107
+ if buf. len ( ) >= self . as_ref ( ) . buf . capacity ( ) {
108
+ * panicked = true ;
109
+ let r = ready ! ( self . as_mut( ) . poll_write( cx, buf) ) ;
110
+ * panicked = false ;
111
+ return Poll :: Ready ( r)
112
+ } else {
113
+ return Poll :: Ready ( ready ! ( self . as_ref( ) . buf. write( buf) . poll( ) ) )
114
+ }
115
+ }
116
+
117
+ fn poll_flush ( mut self : Pin < & mut Self > , cx : & mut Context ) -> Poll < io:: Result < ( ) > > {
118
+ unimplemented ! ( )
119
+ }
120
+
121
+ fn poll_close ( mut self : Pin < & mut Self > , cx : & mut Context ) -> Poll < io:: Result < ( ) > > {
122
+ unimplemented ! ( )
123
+ }
124
+ }
125
+
126
+ impl < W : AsyncWrite + fmt:: Debug > fmt:: Debug for BufWriter < W > {
127
+ fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
128
+ f. debug_struct ( "BufReader" )
129
+ . field ( "writer" , & self . inner )
130
+ . field (
131
+ "buf" ,
132
+ & self . buf
133
+ )
134
+ . finish ( )
135
+ }
136
+ }
137
+
138
+ mod tests {
139
+
140
+ }
0 commit comments