Skip to content

Commit 338a762

Browse files
committed
feat(driver,iocp): add win32 event support by thread pool API
1 parent 4631553 commit 338a762

File tree

4 files changed

+200
-45
lines changed

4 files changed

+200
-45
lines changed

compio-driver/src/iocp/mod.rs

Lines changed: 128 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
use std::{
2-
collections::HashSet,
2+
collections::{HashMap, HashSet},
33
io,
44
mem::ManuallyDrop,
5-
os::windows::prelude::{
6-
AsRawHandle, AsRawSocket, FromRawHandle, FromRawSocket, IntoRawHandle, IntoRawSocket,
7-
RawHandle,
5+
os::{
6+
raw::c_void,
7+
windows::prelude::{
8+
AsRawHandle, AsRawSocket, FromRawHandle, FromRawSocket, IntoRawHandle, IntoRawSocket,
9+
RawHandle,
10+
},
811
},
912
pin::Pin,
10-
ptr::NonNull,
13+
ptr::{null, NonNull},
1114
sync::Arc,
1215
task::Poll,
1316
time::Duration,
@@ -17,9 +20,15 @@ use compio_buf::BufResult;
1720
use compio_log::{instrument, trace};
1821
use slab::Slab;
1922
use windows_sys::Win32::{
20-
Foundation::{ERROR_BUSY, ERROR_OPERATION_ABORTED},
23+
Foundation::{ERROR_BUSY, ERROR_OPERATION_ABORTED, ERROR_TIMEOUT, WAIT_OBJECT_0, WAIT_TIMEOUT},
2124
Networking::WinSock::{WSACleanup, WSAStartup, WSADATA},
22-
System::IO::OVERLAPPED,
25+
System::{
26+
Threading::{
27+
CloseThreadpoolWait, CreateThreadpoolWait, SetThreadpoolWait,
28+
WaitForThreadpoolWaitCallbacks, PTP_CALLBACK_INSTANCE, PTP_WAIT,
29+
},
30+
IO::OVERLAPPED,
31+
},
2332
};
2433

2534
use crate::{syscall, AsyncifyPool, Entry, OutEntries, ProactorBuilder};
@@ -98,12 +107,25 @@ impl IntoRawFd for socket2::Socket {
98107
}
99108
}
100109

110+
/// Operation type.
111+
pub enum OpType {
112+
/// An overlapped operation.
113+
Overlapped,
114+
/// A blocking operation, needs a thread to spawn. The `operate` method
115+
/// should be thread safe.
116+
Blocking,
117+
/// A Win32 event object to be waited. The user should ensure that the
118+
/// handle is valid till operation completes. The `operate` method should be
119+
/// thread safe.
120+
Event(RawFd),
121+
}
122+
101123
/// Abstraction of IOCP operations.
102124
pub trait OpCode {
103125
/// Determines that the operation is really overlapped defined by Windows
104126
/// API. If not, the driver will try to operate it in another thread.
105-
fn is_overlapped(&self) -> bool {
106-
true
127+
fn op_type(&self) -> OpType {
128+
OpType::Overlapped
107129
}
108130

109131
/// Perform Windows API call with given pointer to overlapped struct.
@@ -133,6 +155,7 @@ pub trait OpCode {
133155
/// Low-level driver of IOCP.
134156
pub(crate) struct Driver {
135157
port: cp::Port,
158+
waits: HashMap<usize, WinThreadpollWait>,
136159
cancelled: HashSet<usize>,
137160
pool: AsyncifyPool,
138161
notify_overlapped: Arc<Overlapped<()>>,
@@ -150,6 +173,7 @@ impl Driver {
150173
let driver = port.as_raw_handle() as _;
151174
Ok(Self {
152175
port,
176+
waits: HashMap::default(),
153177
cancelled: HashSet::default(),
154178
pool: builder.create_or_get_thread_pool(),
155179
notify_overlapped: Arc::new(Overlapped::new(driver, Self::NOTIFY, ())),
@@ -188,12 +212,22 @@ impl Driver {
188212
trace!("push RawOp");
189213
let optr = op.as_mut_ptr();
190214
let op_pin = op.as_op_pin();
191-
if op_pin.is_overlapped() {
192-
unsafe { op_pin.operate(optr.cast()) }
193-
} else if self.push_blocking(op)? {
194-
Poll::Pending
195-
} else {
196-
Poll::Ready(Err(io::Error::from_raw_os_error(ERROR_BUSY as _)))
215+
match op_pin.op_type() {
216+
OpType::Overlapped => unsafe { op_pin.operate(optr.cast()) },
217+
OpType::Blocking => {
218+
if self.push_blocking(op)? {
219+
Poll::Pending
220+
} else {
221+
Poll::Ready(Err(io::Error::from_raw_os_error(ERROR_BUSY as _)))
222+
}
223+
}
224+
OpType::Event(e) => {
225+
self.waits.insert(
226+
user_data,
227+
WinThreadpollWait::new(self.port.handle(), e, op)?,
228+
);
229+
Poll::Pending
230+
}
197231
}
198232
}
199233
}
@@ -213,20 +247,20 @@ impl Driver {
213247
// Safety: the pointer is created from a reference.
214248
let op = unsafe { optr.0.as_mut() };
215249
let optr = op.as_mut_ptr();
216-
let op = op.as_op_pin();
217-
let res = unsafe { op.operate(optr.cast()) };
218-
let res = match res {
219-
Poll::Pending => unreachable!("this operation is not overlapped"),
220-
Poll::Ready(res) => res,
221-
};
250+
let res = op.operate_blocking();
222251
port.post(res, optr).ok();
223252
})
224253
.is_ok())
225254
}
226255

227-
fn create_entry(cancelled: &mut HashSet<usize>, entry: Entry) -> Option<Entry> {
256+
fn create_entry(
257+
cancelled: &mut HashSet<usize>,
258+
waits: &mut HashMap<usize, WinThreadpollWait>,
259+
entry: Entry,
260+
) -> Option<Entry> {
228261
let user_data = entry.user_data();
229262
if user_data != Self::NOTIFY {
263+
waits.remove(&user_data);
230264
let result = if cancelled.remove(&user_data) {
231265
Err(io::Error::from_raw_os_error(ERROR_OPERATION_ABORTED as _))
232266
} else {
@@ -248,7 +282,7 @@ impl Driver {
248282
entries.extend(
249283
self.port
250284
.poll(timeout)?
251-
.filter_map(|e| Self::create_entry(&mut self.cancelled, e)),
285+
.filter_map(|e| Self::create_entry(&mut self.cancelled, &mut self.waits, e)),
252286
);
253287

254288
Ok(())
@@ -291,6 +325,67 @@ impl NotifyHandle {
291325
}
292326
}
293327

328+
struct WinThreadpollWait {
329+
wait: PTP_WAIT,
330+
// For memory safety.
331+
#[allow(dead_code)]
332+
context: Box<WinThreadpollWaitContext>,
333+
}
334+
335+
impl WinThreadpollWait {
336+
pub fn new(port: cp::PortHandle, event: RawFd, op: &mut RawOp) -> io::Result<Self> {
337+
let mut context = Box::new(WinThreadpollWaitContext { port, op });
338+
let wait = syscall!(
339+
BOOL,
340+
CreateThreadpoolWait(
341+
Some(Self::wait_callback),
342+
(&mut *context) as *mut WinThreadpollWaitContext as _,
343+
null()
344+
)
345+
)?;
346+
unsafe {
347+
SetThreadpoolWait(wait, event as _, null());
348+
}
349+
Ok(Self { wait, context })
350+
}
351+
352+
unsafe extern "system" fn wait_callback(
353+
_instance: PTP_CALLBACK_INSTANCE,
354+
context: *mut c_void,
355+
_wait: PTP_WAIT,
356+
result: u32,
357+
) {
358+
let context = &*(context as *mut WinThreadpollWaitContext);
359+
let res = match result {
360+
WAIT_OBJECT_0 => Ok(0),
361+
WAIT_TIMEOUT => Err(io::Error::from_raw_os_error(ERROR_TIMEOUT as _)),
362+
_ => Err(io::Error::from_raw_os_error(result as _)),
363+
};
364+
let res = if res.is_err() {
365+
res
366+
} else {
367+
let op = unsafe { &mut *context.op };
368+
op.operate_blocking()
369+
};
370+
context.port.post(res, (*context.op).as_mut_ptr()).ok();
371+
}
372+
}
373+
374+
impl Drop for WinThreadpollWait {
375+
fn drop(&mut self) {
376+
unsafe {
377+
SetThreadpoolWait(self.wait, 0, null());
378+
WaitForThreadpoolWaitCallbacks(self.wait, 1);
379+
CloseThreadpoolWait(self.wait);
380+
}
381+
}
382+
}
383+
384+
struct WinThreadpollWaitContext {
385+
port: cp::PortHandle,
386+
op: *mut RawOp,
387+
}
388+
294389
/// The overlapped struct we actually used for IOCP.
295390
#[repr(C)]
296391
pub struct Overlapped<T: ?Sized> {
@@ -371,6 +466,16 @@ impl RawOp {
371466
let overlapped: Box<Overlapped<T>> = Box::from_raw(this.op.cast().as_ptr());
372467
BufResult(this.result.take().unwrap(), overlapped.op)
373468
}
469+
470+
fn operate_blocking(&mut self) -> io::Result<usize> {
471+
let optr = self.as_mut_ptr();
472+
let op = self.as_op_pin();
473+
let res = unsafe { op.operate(optr.cast()) };
474+
match res {
475+
Poll::Pending => unreachable!("this operation is not overlapped"),
476+
Poll::Ready(res) => res,
477+
}
478+
}
374479
}
375480

376481
impl Drop for RawOp {

compio-driver/src/iocp/op.rs

Lines changed: 17 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ use windows_sys::{
4747
},
4848
};
4949

50-
use crate::{op::*, syscall, OpCode, RawFd};
50+
use crate::{op::*, syscall, OpCode, OpType, RawFd};
5151

5252
#[inline]
5353
fn winapi_result(transferred: u32) -> Poll<io::Result<usize>> {
@@ -119,8 +119,8 @@ impl<
119119
F: (FnOnce() -> BufResult<usize, D>) + std::marker::Send + std::marker::Sync + 'static,
120120
> OpCode for Asyncify<F, D>
121121
{
122-
fn is_overlapped(&self) -> bool {
123-
false
122+
fn op_type(&self) -> OpType {
123+
OpType::Blocking
124124
}
125125

126126
unsafe fn operate(self: Pin<&mut Self>, _optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
@@ -176,8 +176,8 @@ impl OpenFile {
176176
}
177177

178178
impl OpCode for OpenFile {
179-
fn is_overlapped(&self) -> bool {
180-
false
179+
fn op_type(&self) -> OpType {
180+
OpType::Blocking
181181
}
182182

183183
unsafe fn operate(mut self: Pin<&mut Self>, _optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
@@ -200,8 +200,8 @@ impl OpCode for OpenFile {
200200
}
201201

202202
impl OpCode for CloseFile {
203-
fn is_overlapped(&self) -> bool {
204-
false
203+
fn op_type(&self) -> OpType {
204+
OpType::Blocking
205205
}
206206

207207
unsafe fn operate(self: Pin<&mut Self>, _optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
@@ -293,8 +293,8 @@ impl FileStat {
293293
}
294294

295295
impl OpCode for FileStat {
296-
fn is_overlapped(&self) -> bool {
297-
false
296+
fn op_type(&self) -> OpType {
297+
OpType::Blocking
298298
}
299299

300300
unsafe fn operate(mut self: Pin<&mut Self>, _optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
@@ -378,8 +378,8 @@ impl PathStat {
378378
}
379379

380380
impl OpCode for PathStat {
381-
fn is_overlapped(&self) -> bool {
382-
false
381+
fn op_type(&self) -> OpType {
382+
OpType::Blocking
383383
}
384384

385385
unsafe fn operate(mut self: Pin<&mut Self>, optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
@@ -473,8 +473,8 @@ impl<T: IoBuf> OpCode for WriteAt<T> {
473473
}
474474

475475
impl OpCode for Sync {
476-
fn is_overlapped(&self) -> bool {
477-
false
476+
fn op_type(&self) -> OpType {
477+
OpType::Blocking
478478
}
479479

480480
unsafe fn operate(self: Pin<&mut Self>, _optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
@@ -483,8 +483,8 @@ impl OpCode for Sync {
483483
}
484484

485485
impl OpCode for ShutdownSocket {
486-
fn is_overlapped(&self) -> bool {
487-
false
486+
fn op_type(&self) -> OpType {
487+
OpType::Blocking
488488
}
489489

490490
unsafe fn operate(self: Pin<&mut Self>, _optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
@@ -498,8 +498,8 @@ impl OpCode for ShutdownSocket {
498498
}
499499

500500
impl OpCode for CloseSocket {
501-
fn is_overlapped(&self) -> bool {
502-
false
501+
fn op_type(&self) -> OpType {
502+
OpType::Blocking
503503
}
504504

505505
unsafe fn operate(self: Pin<&mut Self>, _optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {

compio-fs/src/stdio/windows.rs

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ use std::{
99
use compio_buf::{BufResult, IntoInner, IoBuf, IoBufMut};
1010
use compio_driver::{
1111
op::{BufResultExt, Recv, Send},
12-
AsRawFd, OpCode, RawFd,
12+
AsRawFd, OpCode, OpType, RawFd,
1313
};
1414
use compio_io::{AsyncRead, AsyncWrite};
1515
use compio_runtime::Runtime;
@@ -30,8 +30,8 @@ impl<R: Read, B: IoBufMut> StdRead<R, B> {
3030
}
3131

3232
impl<R: Read, B: IoBufMut> OpCode for StdRead<R, B> {
33-
fn is_overlapped(&self) -> bool {
34-
false
33+
fn op_type(&self) -> OpType {
34+
OpType::Blocking
3535
}
3636

3737
unsafe fn operate(self: Pin<&mut Self>, _optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {
@@ -79,8 +79,8 @@ impl<W: Write, B: IoBuf> StdWrite<W, B> {
7979
}
8080

8181
impl<W: Write, B: IoBuf> OpCode for StdWrite<W, B> {
82-
fn is_overlapped(&self) -> bool {
83-
false
82+
fn op_type(&self) -> OpType {
83+
OpType::Blocking
8484
}
8585

8686
unsafe fn operate(self: Pin<&mut Self>, _optr: *mut OVERLAPPED) -> Poll<io::Result<usize>> {

0 commit comments

Comments
 (0)