11use  std:: { 
2-     collections:: HashSet , 
2+     collections:: { HashMap ,   HashSet } , 
33    io, 
44    mem:: ManuallyDrop , 
5-     os:: windows:: prelude:: { 
6-         AsRawHandle ,  AsRawSocket ,  FromRawHandle ,  FromRawSocket ,  IntoRawHandle ,  IntoRawSocket , 
7-         RawHandle , 
5+     os:: { 
6+         raw:: c_void, 
7+         windows:: prelude:: { 
8+             AsRawHandle ,  AsRawSocket ,  FromRawHandle ,  FromRawSocket ,  IntoRawHandle ,  IntoRawSocket , 
9+             RawHandle , 
10+         } , 
811    } , 
912    pin:: Pin , 
10-     ptr:: NonNull , 
13+     ptr:: { null ,   NonNull } , 
1114    sync:: Arc , 
1215    task:: Poll , 
1316    time:: Duration , 
@@ -17,9 +20,15 @@ use compio_buf::BufResult;
1720use  compio_log:: { instrument,  trace} ; 
1821use  slab:: Slab ; 
1922use  windows_sys:: Win32 :: { 
20-     Foundation :: { ERROR_BUSY ,  ERROR_OPERATION_ABORTED } , 
23+     Foundation :: { ERROR_BUSY ,  ERROR_OPERATION_ABORTED ,   ERROR_TIMEOUT ,   WAIT_OBJECT_0 ,   WAIT_TIMEOUT } , 
2124    Networking :: WinSock :: { WSACleanup ,  WSAStartup ,  WSADATA } , 
22-     System :: IO :: OVERLAPPED , 
25+     System :: { 
26+         Threading :: { 
27+             CloseThreadpoolWait ,  CreateThreadpoolWait ,  SetThreadpoolWait , 
28+             WaitForThreadpoolWaitCallbacks ,  PTP_CALLBACK_INSTANCE ,  PTP_WAIT , 
29+         } , 
30+         IO :: OVERLAPPED , 
31+     } , 
2332} ; 
2433
2534use  crate :: { syscall,  AsyncifyPool ,  Entry ,  OutEntries ,  ProactorBuilder } ; 
@@ -98,12 +107,25 @@ impl IntoRawFd for socket2::Socket {
98107    } 
99108} 
100109
110+ /// Operation type. 
111+ pub  enum  OpType  { 
112+     /// An overlapped operation. 
113+      Overlapped , 
114+     /// A blocking operation, needs a thread to spawn. The `operate` method 
115+      /// should be thread safe. 
116+      Blocking , 
117+     /// A Win32 event object to be waited. The user should ensure that the 
118+      /// handle is valid till operation completes. The `operate` method should be 
119+      /// thread safe. 
120+      Event ( RawFd ) , 
121+ } 
122+ 
101123/// Abstraction of IOCP operations. 
102124pub  trait  OpCode  { 
103125    /// Determines that the operation is really overlapped defined by Windows 
104126     /// API. If not, the driver will try to operate it in another thread. 
105-      fn  is_overlapped ( & self )  -> bool  { 
106-         true 
127+      fn  op_type ( & self )  -> OpType  { 
128+         OpType :: Overlapped 
107129    } 
108130
109131    /// Perform Windows API call with given pointer to overlapped struct. 
@@ -133,6 +155,7 @@ pub trait OpCode {
133155/// Low-level driver of IOCP. 
134156pub ( crate )  struct  Driver  { 
135157    port :  cp:: Port , 
158+     waits :  HashMap < usize ,  WinThreadpollWait > , 
136159    cancelled :  HashSet < usize > , 
137160    pool :  AsyncifyPool , 
138161    notify_overlapped :  Arc < Overlapped < ( ) > > , 
@@ -150,6 +173,7 @@ impl Driver {
150173        let  driver = port. as_raw_handle ( )  as  _ ; 
151174        Ok ( Self  { 
152175            port, 
176+             waits :  HashMap :: default ( ) , 
153177            cancelled :  HashSet :: default ( ) , 
154178            pool :  builder. create_or_get_thread_pool ( ) , 
155179            notify_overlapped :  Arc :: new ( Overlapped :: new ( driver,  Self :: NOTIFY ,  ( ) ) ) , 
@@ -188,12 +212,22 @@ impl Driver {
188212            trace ! ( "push RawOp" ) ; 
189213            let  optr = op. as_mut_ptr ( ) ; 
190214            let  op_pin = op. as_op_pin ( ) ; 
191-             if  op_pin. is_overlapped ( )  { 
192-                 unsafe  {  op_pin. operate ( optr. cast ( ) )  } 
193-             }  else  if  self . push_blocking ( op) ? { 
194-                 Poll :: Pending 
195-             }  else  { 
196-                 Poll :: Ready ( Err ( io:: Error :: from_raw_os_error ( ERROR_BUSY  as  _ ) ) ) 
215+             match  op_pin. op_type ( )  { 
216+                 OpType :: Overlapped  => unsafe  {  op_pin. operate ( optr. cast ( ) )  } , 
217+                 OpType :: Blocking  => { 
218+                     if  self . push_blocking ( op) ? { 
219+                         Poll :: Pending 
220+                     }  else  { 
221+                         Poll :: Ready ( Err ( io:: Error :: from_raw_os_error ( ERROR_BUSY  as  _ ) ) ) 
222+                     } 
223+                 } 
224+                 OpType :: Event ( e)  => { 
225+                     self . waits . insert ( 
226+                         user_data, 
227+                         WinThreadpollWait :: new ( self . port . handle ( ) ,  e,  op) ?, 
228+                     ) ; 
229+                     Poll :: Pending 
230+                 } 
197231            } 
198232        } 
199233    } 
@@ -213,20 +247,20 @@ impl Driver {
213247                // Safety: the pointer is created from a reference. 
214248                let  op = unsafe  {  optr. 0 . as_mut ( )  } ; 
215249                let  optr = op. as_mut_ptr ( ) ; 
216-                 let  op = op. as_op_pin ( ) ; 
217-                 let  res = unsafe  {  op. operate ( optr. cast ( ) )  } ; 
218-                 let  res = match  res { 
219-                     Poll :: Pending  => unreachable ! ( "this operation is not overlapped" ) , 
220-                     Poll :: Ready ( res)  => res, 
221-                 } ; 
250+                 let  res = op. operate_blocking ( ) ; 
222251                port. post ( res,  optr) . ok ( ) ; 
223252            } ) 
224253            . is_ok ( ) ) 
225254    } 
226255
227-     fn  create_entry ( cancelled :  & mut  HashSet < usize > ,  entry :  Entry )  -> Option < Entry >  { 
256+     fn  create_entry ( 
257+         cancelled :  & mut  HashSet < usize > , 
258+         waits :  & mut  HashMap < usize ,  WinThreadpollWait > , 
259+         entry :  Entry , 
260+     )  -> Option < Entry >  { 
228261        let  user_data = entry. user_data ( ) ; 
229262        if  user_data != Self :: NOTIFY  { 
263+             waits. remove ( & user_data) ; 
230264            let  result = if  cancelled. remove ( & user_data)  { 
231265                Err ( io:: Error :: from_raw_os_error ( ERROR_OPERATION_ABORTED  as  _ ) ) 
232266            }  else  { 
@@ -248,7 +282,7 @@ impl Driver {
248282        entries. extend ( 
249283            self . port 
250284                . poll ( timeout) ?
251-                 . filter_map ( |e| Self :: create_entry ( & mut  self . cancelled ,  e) ) , 
285+                 . filter_map ( |e| Self :: create_entry ( & mut  self . cancelled ,  & mut   self . waits ,   e) ) , 
252286        ) ; 
253287
254288        Ok ( ( ) ) 
@@ -291,6 +325,67 @@ impl NotifyHandle {
291325    } 
292326} 
293327
328+ struct  WinThreadpollWait  { 
329+     wait :  PTP_WAIT , 
330+     // For memory safety. 
331+     #[ allow( dead_code) ]  
332+     context :  Box < WinThreadpollWaitContext > , 
333+ } 
334+ 
335+ impl  WinThreadpollWait  { 
336+     pub  fn  new ( port :  cp:: PortHandle ,  event :  RawFd ,  op :  & mut  RawOp )  -> io:: Result < Self >  { 
337+         let  mut  context = Box :: new ( WinThreadpollWaitContext  {  port,  op } ) ; 
338+         let  wait = syscall ! ( 
339+             BOOL , 
340+             CreateThreadpoolWait ( 
341+                 Some ( Self :: wait_callback) , 
342+                 ( & mut  * context)  as  * mut  WinThreadpollWaitContext  as  _, 
343+                 null( ) 
344+             ) 
345+         ) ?; 
346+         unsafe  { 
347+             SetThreadpoolWait ( wait,  event as  _ ,  null ( ) ) ; 
348+         } 
349+         Ok ( Self  {  wait,  context } ) 
350+     } 
351+ 
352+     unsafe  extern  "system"  fn  wait_callback ( 
353+         _instance :  PTP_CALLBACK_INSTANCE , 
354+         context :  * mut  c_void , 
355+         _wait :  PTP_WAIT , 
356+         result :  u32 , 
357+     )  { 
358+         let  context = & * ( context as  * mut  WinThreadpollWaitContext ) ; 
359+         let  res = match  result { 
360+             WAIT_OBJECT_0  => Ok ( 0 ) , 
361+             WAIT_TIMEOUT  => Err ( io:: Error :: from_raw_os_error ( ERROR_TIMEOUT  as  _ ) ) , 
362+             _ => Err ( io:: Error :: from_raw_os_error ( result as  _ ) ) , 
363+         } ; 
364+         let  res = if  res. is_err ( )  { 
365+             res
366+         }  else  { 
367+             let  op = unsafe  {  & mut  * context. op  } ; 
368+             op. operate_blocking ( ) 
369+         } ; 
370+         context. port . post ( res,  ( * context. op ) . as_mut_ptr ( ) ) . ok ( ) ; 
371+     } 
372+ } 
373+ 
374+ impl  Drop  for  WinThreadpollWait  { 
375+     fn  drop ( & mut  self )  { 
376+         unsafe  { 
377+             SetThreadpoolWait ( self . wait ,  0 ,  null ( ) ) ; 
378+             WaitForThreadpoolWaitCallbacks ( self . wait ,  1 ) ; 
379+             CloseThreadpoolWait ( self . wait ) ; 
380+         } 
381+     } 
382+ } 
383+ 
384+ struct  WinThreadpollWaitContext  { 
385+     port :  cp:: PortHandle , 
386+     op :  * mut  RawOp , 
387+ } 
388+ 
294389/// The overlapped struct we actually used for IOCP. 
295390#[ repr( C ) ]  
296391pub  struct  Overlapped < T :  ?Sized >  { 
@@ -371,6 +466,16 @@ impl RawOp {
371466        let  overlapped:  Box < Overlapped < T > >  = Box :: from_raw ( this. op . cast ( ) . as_ptr ( ) ) ; 
372467        BufResult ( this. result . take ( ) . unwrap ( ) ,  overlapped. op ) 
373468    } 
469+ 
470+     fn  operate_blocking ( & mut  self )  -> io:: Result < usize >  { 
471+         let  optr = self . as_mut_ptr ( ) ; 
472+         let  op = self . as_op_pin ( ) ; 
473+         let  res = unsafe  {  op. operate ( optr. cast ( ) )  } ; 
474+         match  res { 
475+             Poll :: Pending  => unreachable ! ( "this operation is not overlapped" ) , 
476+             Poll :: Ready ( res)  => res, 
477+         } 
478+     } 
374479} 
375480
376481impl  Drop  for  RawOp  { 
0 commit comments