From a8fda870582a1049bdc31284bc3bb82969014895 Mon Sep 17 00:00:00 2001 From: Maximilian Hils Date: Sat, 18 Feb 2023 20:03:16 +0100 Subject: [PATCH] net: use Message Read Mode for named pipes (#5350) --- tokio/src/net/windows/named_pipe.rs | 44 +++++++++++++++++++++++--- tokio/tests/net_named_pipe.rs | 48 ++++++++++++++++++++++++----- 2 files changed, 80 insertions(+), 12 deletions(-) diff --git a/tokio/src/net/windows/named_pipe.rs b/tokio/src/net/windows/named_pipe.rs index 9ede94ea6a0..2107c1cdfce 100644 --- a/tokio/src/net/windows/named_pipe.rs +++ b/tokio/src/net/windows/named_pipe.rs @@ -1701,14 +1701,18 @@ impl ServerOptions { /// The default pipe mode is [`PipeMode::Byte`]. See [`PipeMode`] for /// documentation of what each mode means. /// - /// This corresponding to specifying [`dwPipeMode`]. + /// This corresponds to specifying `PIPE_TYPE_` and `PIPE_READMODE_` in [`dwPipeMode`]. /// /// [`dwPipeMode`]: https://docs.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-createnamedpipea pub fn pipe_mode(&mut self, pipe_mode: PipeMode) -> &mut Self { let is_msg = matches!(pipe_mode, PipeMode::Message); // Pipe mode is implemented as a bit flag 0x4. Set is message and unset // is byte. - bool_flag!(self.pipe_mode, is_msg, windows_sys::PIPE_TYPE_MESSAGE); + bool_flag!( + self.pipe_mode, + is_msg, + windows_sys::PIPE_TYPE_MESSAGE | windows_sys::PIPE_READMODE_MESSAGE + ); self } @@ -2268,6 +2272,7 @@ impl ServerOptions { pub struct ClientOptions { desired_access: u32, security_qos_flags: u32, + pipe_mode: PipeMode, } impl ClientOptions { @@ -2289,6 +2294,7 @@ impl ClientOptions { desired_access: windows_sys::GENERIC_READ | windows_sys::GENERIC_WRITE, security_qos_flags: windows_sys::SECURITY_IDENTIFICATION | windows_sys::SECURITY_SQOS_PRESENT, + pipe_mode: PipeMode::Byte, } } @@ -2341,6 +2347,15 @@ impl ClientOptions { self } + /// The pipe mode. + /// + /// The default pipe mode is [`PipeMode::Byte`]. See [`PipeMode`] for + /// documentation of what each mode means. + pub fn pipe_mode(&mut self, pipe_mode: PipeMode) -> &mut Self { + self.pipe_mode = pipe_mode; + self + } + /// Opens the named pipe identified by `addr`. /// /// This opens the client using [`CreateFile`] with the @@ -2437,6 +2452,20 @@ impl ClientOptions { return Err(io::Error::last_os_error()); } + if matches!(self.pipe_mode, PipeMode::Message) { + let mut mode = windows_sys::PIPE_READMODE_MESSAGE; + let result = windows_sys::SetNamedPipeHandleState( + h, + &mut mode, + ptr::null_mut(), + ptr::null_mut(), + ); + + if result == 0 { + return Err(io::Error::last_os_error()); + } + } + NamedPipeClient::from_raw_handle(h as _) } @@ -2556,7 +2585,9 @@ unsafe fn named_pipe_info(handle: RawHandle) -> io::Result { #[cfg(test)] mod test { - use self::windows_sys::{PIPE_REJECT_REMOTE_CLIENTS, PIPE_TYPE_BYTE, PIPE_TYPE_MESSAGE}; + use self::windows_sys::{ + PIPE_READMODE_MESSAGE, PIPE_REJECT_REMOTE_CLIENTS, PIPE_TYPE_BYTE, PIPE_TYPE_MESSAGE, + }; use super::*; #[test] @@ -2588,13 +2619,16 @@ mod test { opts.reject_remote_clients(false); opts.pipe_mode(PipeMode::Message); - assert_eq!(opts.pipe_mode, PIPE_TYPE_MESSAGE); + assert_eq!(opts.pipe_mode, PIPE_TYPE_MESSAGE | PIPE_READMODE_MESSAGE); opts.reject_remote_clients(true); opts.pipe_mode(PipeMode::Message); assert_eq!( opts.pipe_mode, - PIPE_TYPE_MESSAGE | PIPE_REJECT_REMOTE_CLIENTS + PIPE_TYPE_MESSAGE | PIPE_READMODE_MESSAGE | PIPE_REJECT_REMOTE_CLIENTS ); + + opts.pipe_mode(PipeMode::Byte); + assert_eq!(opts.pipe_mode, PIPE_TYPE_BYTE | PIPE_REJECT_REMOTE_CLIENTS); } } diff --git a/tokio/tests/net_named_pipe.rs b/tokio/tests/net_named_pipe.rs index c42122465c0..3ddc4c8a9bf 100644 --- a/tokio/tests/net_named_pipe.rs +++ b/tokio/tests/net_named_pipe.rs @@ -5,7 +5,7 @@ use std::io; use std::mem; use std::os::windows::io::AsRawHandle; use std::time::Duration; -use tokio::io::AsyncWriteExt; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; use tokio::net::windows::named_pipe::{ClientOptions, PipeMode, ServerOptions}; use tokio::time; use windows_sys::Win32::Foundation::{ERROR_NO_DATA, ERROR_PIPE_BUSY, NO_ERROR, UNICODE_STRING}; @@ -327,17 +327,51 @@ async fn test_named_pipe_multi_client_ready() -> io::Result<()> { Ok(()) } -// This tests what happens when a client tries to disconnect. +// This tests that message mode works as expected. #[tokio::test] async fn test_named_pipe_mode_message() -> io::Result<()> { - const PIPE_NAME: &str = r"\\.\pipe\test-named-pipe-mode-message"; + // it's easy to accidentally get a seemingly working test here because byte pipes + // often return contents at write boundaries. to make sure we're doing the right thing we + // explicitly test that it doesn't work in byte mode. + _named_pipe_mode_message(PipeMode::Message).await?; + _named_pipe_mode_message(PipeMode::Byte).await +} + +async fn _named_pipe_mode_message(mode: PipeMode) -> io::Result<()> { + let pipe_name = format!( + r"\\.\pipe\test-named-pipe-mode-message-{}", + matches!(mode, PipeMode::Message) + ); + let mut buf = [0u8; 32]; - let server = ServerOptions::new() - .pipe_mode(PipeMode::Message) - .create(PIPE_NAME)?; + let mut server = ServerOptions::new() + .first_pipe_instance(true) + .pipe_mode(mode) + .create(&pipe_name)?; + + let mut client = ClientOptions::new().pipe_mode(mode).open(&pipe_name)?; - let _ = ClientOptions::new().open(PIPE_NAME)?; server.connect().await?; + + // this needs a few iterations, presumably Windows waits for a few calls before merging buffers + for _ in 0..10 { + client.write_all(b"hello").await?; + server.write_all(b"world").await?; + } + for _ in 0..10 { + let n = server.read(&mut buf).await?; + if buf[..n] != b"hello"[..] { + assert!(matches!(mode, PipeMode::Byte)); + return Ok(()); + } + let n = client.read(&mut buf).await?; + if buf[..n] != b"world"[..] { + assert!(matches!(mode, PipeMode::Byte)); + return Ok(()); + } + } + // byte mode should have errored before. + assert!(matches!(mode, PipeMode::Message)); Ok(()) }