Skip to content

Commit

Permalink
net: use Message Read Mode for named pipes (#5350)
Browse files Browse the repository at this point in the history
  • Loading branch information
mhils authored Feb 18, 2023
1 parent d7abdbb commit a8fda87
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 12 deletions.
44 changes: 39 additions & 5 deletions tokio/src/net/windows/named_pipe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1701,14 +1701,18 @@ impl ServerOptions {
/// The default pipe mode is [`PipeMode::Byte`]. See [`PipeMode`] for
/// documentation of what each mode means.
///
/// This corresponding to specifying [`dwPipeMode`].
/// This corresponds to specifying `PIPE_TYPE_` and `PIPE_READMODE_` in [`dwPipeMode`].
///
/// [`dwPipeMode`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea
pub fn pipe_mode(&mut self, pipe_mode: PipeMode) -> &mut Self {
let is_msg = matches!(pipe_mode, PipeMode::Message);
// Pipe mode is implemented as a bit flag 0x4. Set is message and unset
// is byte.
bool_flag!(self.pipe_mode, is_msg, windows_sys::PIPE_TYPE_MESSAGE);
bool_flag!(
self.pipe_mode,
is_msg,
windows_sys::PIPE_TYPE_MESSAGE | windows_sys::PIPE_READMODE_MESSAGE
);
self
}

Expand Down Expand Up @@ -2268,6 +2272,7 @@ impl ServerOptions {
pub struct ClientOptions {
desired_access: u32,
security_qos_flags: u32,
pipe_mode: PipeMode,
}

impl ClientOptions {
Expand All @@ -2289,6 +2294,7 @@ impl ClientOptions {
desired_access: windows_sys::GENERIC_READ | windows_sys::GENERIC_WRITE,
security_qos_flags: windows_sys::SECURITY_IDENTIFICATION
| windows_sys::SECURITY_SQOS_PRESENT,
pipe_mode: PipeMode::Byte,
}
}

Expand Down Expand Up @@ -2341,6 +2347,15 @@ impl ClientOptions {
self
}

/// The pipe mode.
///
/// The default pipe mode is [`PipeMode::Byte`]. See [`PipeMode`] for
/// documentation of what each mode means.
pub fn pipe_mode(&mut self, pipe_mode: PipeMode) -> &mut Self {
self.pipe_mode = pipe_mode;
self
}

/// Opens the named pipe identified by `addr`.
///
/// This opens the client using [`CreateFile`] with the
Expand Down Expand Up @@ -2437,6 +2452,20 @@ impl ClientOptions {
return Err(io::Error::last_os_error());
}

if matches!(self.pipe_mode, PipeMode::Message) {
let mut mode = windows_sys::PIPE_READMODE_MESSAGE;
let result = windows_sys::SetNamedPipeHandleState(
h,
&mut mode,
ptr::null_mut(),
ptr::null_mut(),
);

if result == 0 {
return Err(io::Error::last_os_error());
}
}

NamedPipeClient::from_raw_handle(h as _)
}

Expand Down Expand Up @@ -2556,7 +2585,9 @@ unsafe fn named_pipe_info(handle: RawHandle) -> io::Result<PipeInfo> {

#[cfg(test)]
mod test {
use self::windows_sys::{PIPE_REJECT_REMOTE_CLIENTS, PIPE_TYPE_BYTE, PIPE_TYPE_MESSAGE};
use self::windows_sys::{
PIPE_READMODE_MESSAGE, PIPE_REJECT_REMOTE_CLIENTS, PIPE_TYPE_BYTE, PIPE_TYPE_MESSAGE,
};
use super::*;

#[test]
Expand Down Expand Up @@ -2588,13 +2619,16 @@ mod test {

opts.reject_remote_clients(false);
opts.pipe_mode(PipeMode::Message);
assert_eq!(opts.pipe_mode, PIPE_TYPE_MESSAGE);
assert_eq!(opts.pipe_mode, PIPE_TYPE_MESSAGE | PIPE_READMODE_MESSAGE);

opts.reject_remote_clients(true);
opts.pipe_mode(PipeMode::Message);
assert_eq!(
opts.pipe_mode,
PIPE_TYPE_MESSAGE | PIPE_REJECT_REMOTE_CLIENTS
PIPE_TYPE_MESSAGE | PIPE_READMODE_MESSAGE | PIPE_REJECT_REMOTE_CLIENTS
);

opts.pipe_mode(PipeMode::Byte);
assert_eq!(opts.pipe_mode, PIPE_TYPE_BYTE | PIPE_REJECT_REMOTE_CLIENTS);
}
}
48 changes: 41 additions & 7 deletions tokio/tests/net_named_pipe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::io;
use std::mem;
use std::os::windows::io::AsRawHandle;
use std::time::Duration;
use tokio::io::AsyncWriteExt;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::windows::named_pipe::{ClientOptions, PipeMode, ServerOptions};
use tokio::time;
use windows_sys::Win32::Foundation::{ERROR_NO_DATA, ERROR_PIPE_BUSY, NO_ERROR, UNICODE_STRING};
Expand Down Expand Up @@ -327,17 +327,51 @@ async fn test_named_pipe_multi_client_ready() -> io::Result<()> {
Ok(())
}

// This tests what happens when a client tries to disconnect.
// This tests that message mode works as expected.
#[tokio::test]
async fn test_named_pipe_mode_message() -> io::Result<()> {
const PIPE_NAME: &str = r"\\.\pipe\test-named-pipe-mode-message";
// it's easy to accidentally get a seemingly working test here because byte pipes
// often return contents at write boundaries. to make sure we're doing the right thing we
// explicitly test that it doesn't work in byte mode.
_named_pipe_mode_message(PipeMode::Message).await?;
_named_pipe_mode_message(PipeMode::Byte).await
}

async fn _named_pipe_mode_message(mode: PipeMode) -> io::Result<()> {
let pipe_name = format!(
r"\\.\pipe\test-named-pipe-mode-message-{}",
matches!(mode, PipeMode::Message)
);
let mut buf = [0u8; 32];

let server = ServerOptions::new()
.pipe_mode(PipeMode::Message)
.create(PIPE_NAME)?;
let mut server = ServerOptions::new()
.first_pipe_instance(true)
.pipe_mode(mode)
.create(&pipe_name)?;

let mut client = ClientOptions::new().pipe_mode(mode).open(&pipe_name)?;

let _ = ClientOptions::new().open(PIPE_NAME)?;
server.connect().await?;

// this needs a few iterations, presumably Windows waits for a few calls before merging buffers
for _ in 0..10 {
client.write_all(b"hello").await?;
server.write_all(b"world").await?;
}
for _ in 0..10 {
let n = server.read(&mut buf).await?;
if buf[..n] != b"hello"[..] {
assert!(matches!(mode, PipeMode::Byte));
return Ok(());
}
let n = client.read(&mut buf).await?;
if buf[..n] != b"world"[..] {
assert!(matches!(mode, PipeMode::Byte));
return Ok(());
}
}
// byte mode should have errored before.
assert!(matches!(mode, PipeMode::Message));
Ok(())
}

Expand Down

0 comments on commit a8fda87

Please sign in to comment.