Skip to content

Commit

Permalink
Type-specific handle validity checking (#1648)
Browse files Browse the repository at this point in the history
  • Loading branch information
kennykerr authored Mar 31, 2022
1 parent 3329856 commit ab6994f
Show file tree
Hide file tree
Showing 109 changed files with 1,005 additions and 2,664 deletions.
2 changes: 1 addition & 1 deletion .github/readme.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ use windows_sys::{

fn main() {
unsafe {
let event = CreateEventW(std::ptr::null(), 1, 0, std::ptr::null());
let event = CreateEventW(std::ptr::null(), 1, 0, std::ptr::null())?;
SetEvent(event);
WaitForSingleObject(event, 0);
CloseHandle(event);
Expand Down
2 changes: 1 addition & 1 deletion crates/libs/bindgen/src/async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ fn gen_async_kind(kind: AsyncKind, name: &TypeDef, self_name: &TypeDef, cfg: &Cf
impl<#(#constraints)*> #name {
pub fn get(&self) -> ::windows::core::Result<#return_type> {
if self.Status()? == #namespace AsyncStatus::Started {
let (_waiter, signaler) = ::windows::core::Waiter::new();
let (_waiter, signaler) = ::windows::core::Waiter::new()?;
self.SetCompleted(#namespace #handler::new(move |_sender, _args| {
// Safe because the waiter will only be dropped after being signaled.
unsafe { signaler.signal(); }
Expand Down
73 changes: 58 additions & 15 deletions crates/libs/bindgen/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -203,24 +203,49 @@ fn gen_win_function(def: &MethodDef, gen: &Gen) -> TokenStream {
}
}
SignatureKind::ReturnStruct | SignatureKind::PreserveSig => {
let args = gen_win32_args(&signature.params);
let params = gen_win32_params(&signature.params, gen);
if handle_last_error(def, &signature) {
let args = gen_win32_args(&signature.params);
let params = gen_win32_params(&signature.params, gen);
let return_type = gen_element_name(&signature.return_type.unwrap(), gen);

quote! {
#doc
#features
#[inline]
pub unsafe fn #name<#constraints>(#params) #abi_return_type {
#[cfg(windows)]
{
#link_attr
extern "system" {
fn #name(#(#abi_params),*) #abi_return_type;
quote! {
#doc
#features
#[inline]
pub unsafe fn #name<#constraints>(#params) -> ::windows::core::Result<#return_type> {
#[cfg(windows)]
{
#link_attr
extern "system" {
fn #name(#(#abi_params),*) -> #return_type;
}
let result__ = #name(#args);
(!result__.is_invalid()).then(||result__).ok_or_else(::windows::core::Error::from_win32)
}
::core::mem::transmute(#name(#args))
#[cfg(not(windows))]
unimplemented!("Unsupported target OS");
}
}
} else {
let args = gen_win32_args(&signature.params);
let params = gen_win32_params(&signature.params, gen);

quote! {
#doc
#features
#[inline]
pub unsafe fn #name<#constraints>(#params) #abi_return_type {
#[cfg(windows)]
{
#link_attr
extern "system" {
fn #name(#(#abi_params),*) #abi_return_type;
}
::core::mem::transmute(#name(#args))
}
#[cfg(not(windows))]
unimplemented!("Unsupported target OS");
}
#[cfg(not(windows))]
unimplemented!("Unsupported target OS");
}
}
}
Expand Down Expand Up @@ -257,3 +282,21 @@ fn does_not_return(def: &MethodDef) -> TokenStream {
quote! {}
}
}

fn handle_last_error(def: &MethodDef, signature: &Signature) -> bool {
if let Some(map) = def.impl_map() {
if map.flags().last_error() {
if let Some(Type::TypeDef(return_type)) = &signature.return_type {
if return_type.is_handle() {
if return_type.underlying_type().is_pointer() {
return true;
}
if !return_type.invalid_values().is_empty() {
return true;
}
}
}
}
}
false
}
51 changes: 31 additions & 20 deletions crates/libs/bindgen/src/handles.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ pub fn gen(def: &TypeDef, gen: &Gen) -> TokenStream {

pub fn gen_sys_handle(def: &TypeDef, gen: &Gen) -> TokenStream {
let ident = gen_ident(def.name());
let signature = gen_signature(def, gen);
let signature = gen_default_type(&def.underlying_type(), gen);

quote! {
pub type #ident = #signature;
Expand All @@ -20,26 +20,42 @@ pub fn gen_sys_handle(def: &TypeDef, gen: &Gen) -> TokenStream {
pub fn gen_win_handle(def: &TypeDef, gen: &Gen) -> TokenStream {
let name = def.name();
let ident = gen_ident(def.name());
let signature = gen_signature(def, gen);
let underlying_type = def.underlying_type();
let signature = gen_default_type(&underlying_type, gen);
let check = if underlying_type.is_pointer() {
quote! {
impl #ident {
pub fn is_invalid(&self) -> bool {
self.0.is_null()
}
}
}
} else {
let invalid = def.invalid_values();

if !invalid.is_empty() {
let invalid = invalid.iter().map(|value| {
let value = Literal::i64_unsuffixed(*value);
quote! { self.0 == #value }
});
quote! {
impl #ident {
pub fn is_invalid(&self) -> bool {
#(#invalid)||*
}
}
}
} else {
quote! {}
}
};

let mut tokens = quote! {
#[repr(transparent)]
// Unfortunately, Rust requires these to be derived to allow constant patterns.
#[derive(::core::cmp::PartialEq, ::core::cmp::Eq)]
pub struct #ident(pub #signature);
impl #ident {
pub fn is_invalid(&self) -> bool {
*self == unsafe { ::core::mem::zeroed() }
}

pub fn ok(self) -> ::windows::core::Result<Self> {
if !self.is_invalid() {
Ok(self)
} else {
Err(::windows::core::Error::from_win32())
}
}
}
#check
impl ::core::default::Default for #ident {
fn default() -> Self {
unsafe { ::core::mem::zeroed() }
Expand Down Expand Up @@ -77,8 +93,3 @@ pub fn gen_win_handle(def: &TypeDef, gen: &Gen) -> TokenStream {

tokens
}

fn gen_signature(def: &TypeDef, gen: &Gen) -> TokenStream {
let def = def.fields().next().map(|field| field.get_type(Some(def))).unwrap();
gen_default_type(&def, gen)
}
46 changes: 0 additions & 46 deletions crates/libs/bindgen/src/replacements/handle.rs

This file was deleted.

2 changes: 0 additions & 2 deletions crates/libs/bindgen/src/replacements/mod.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
use super::*;
mod bool32;
mod bstr;
mod handle;
mod ntstatus;

pub fn gen(def: &TypeDef) -> Option<TokenStream> {
match def.type_name() {
TypeName::BOOL => Some(bool32::gen()),
TypeName::BSTR => Some(bstr::gen()),
TypeName::NTSTATUS => Some(ntstatus::gen()),
TypeName::HANDLE => Some(handle::gen()),
_ => None,
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,12 @@ impl ParamFlags {
self.0 & 0x0010 != 0
}
}

#[derive(Default)]
pub struct PInvokeAttributes(pub u32);

impl PInvokeAttributes {
pub fn last_error(&self) -> bool {
self.0 & 0x0040 != 0
}
}
4 changes: 2 additions & 2 deletions crates/libs/metadata/src/reader/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@ mod cfg;
mod codes;
mod constant_value;
mod file;
mod flags;
mod guid;
mod interface_kind;
mod param_flags;
mod row;
mod signature;
mod signature_kind;
Expand All @@ -27,9 +27,9 @@ pub use cfg::*;
pub use codes::*;
pub use constant_value::*;
pub use file::*;
pub use flags::*;
pub use guid::*;
pub use interface_kind::*;
pub use param_flags::*;
pub use r#type::*;
pub use row::*;
pub use signature::*;
Expand Down
4 changes: 4 additions & 0 deletions crates/libs/metadata/src/reader/tables/impl_map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@ use super::*;
pub struct ImplMap(pub Row);

impl ImplMap {
pub fn flags(&self) -> PInvokeAttributes {
PInvokeAttributes(self.0.u32(0))
}

pub fn scope(&self) -> ModuleRef {
ModuleRef(Row::new(self.0.u32(3) - 1, TableIndex::ModuleRef, self.0.file))
}
Expand Down
13 changes: 13 additions & 0 deletions crates/libs/metadata/src/reader/tables/type_def.rs
Original file line number Diff line number Diff line change
Expand Up @@ -450,6 +450,19 @@ impl TypeDef {
})
}

pub fn invalid_values(&self) -> Vec<i64> {
self.attributes()
.filter_map(|attribute| {
if attribute.name() == "InvalidHandleValueAttribute" {
if let Some((_, ConstantValue::I64(value))) = attribute.args().get(0) {
return Some(*value);
}
}
None
})
.collect()
}

pub fn is_convertible_to(&self) -> Option<&Type> {
self.attributes().find_map(|attribute| {
if attribute.name() == "AlsoUsableForAttribute" {
Expand Down
1 change: 0 additions & 1 deletion crates/libs/metadata/src/reader/type_name.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,6 @@ impl TypeName {
pub const PWSTR: Self = Self::from_const("Windows.Win32.Foundation", "PWSTR");
pub const PSTR: Self = Self::from_const("Windows.Win32.Foundation", "PSTR");
pub const BSTR: Self = Self::from_const("Windows.Win32.Foundation", "BSTR");
pub const HANDLE: Self = Self::from_const("Windows.Win32.Foundation", "HANDLE");
pub const HRESULT: Self = Self::from_const("Windows.Win32.Foundation", "HRESULT");
pub const D2D_MATRIX_3X2_F: Self = Self::from_const("Windows.Win32.Graphics.Direct2D.Common", "D2D_MATRIX_3X2_F");
pub const IUnknown: Self = Self::from_const("Windows.Win32.System.Com", "IUnknown");
Expand Down
1 change: 1 addition & 0 deletions crates/libs/tokens/src/token_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ macro_rules! unsuffixed {
}

impl Literal {
unsuffixed!(i64 => i64_unsuffixed);
unsuffixed!(usize => usize_unsuffixed);
unsuffixed!(u32 => u32_unsuffixed);
unsuffixed!(u16 => u16_unsuffixed);
Expand Down
12 changes: 6 additions & 6 deletions crates/libs/windows/src/Windows/Devices/Sms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ impl ::windows::core::RuntimeName for DeleteSmsMessageOperation {
impl DeleteSmsMessageOperation {
pub fn get(&self) -> ::windows::core::Result<()> {
if self.Status()? == super::super::Foundation::AsyncStatus::Started {
let (_waiter, signaler) = ::windows::core::Waiter::new();
let (_waiter, signaler) = ::windows::core::Waiter::new()?;
self.SetCompleted(super::super::Foundation::AsyncActionCompletedHandler::new(move |_sender, _args| {
unsafe {
signaler.signal();
Expand Down Expand Up @@ -377,7 +377,7 @@ impl ::windows::core::RuntimeName for DeleteSmsMessagesOperation {
impl DeleteSmsMessagesOperation {
pub fn get(&self) -> ::windows::core::Result<()> {
if self.Status()? == super::super::Foundation::AsyncStatus::Started {
let (_waiter, signaler) = ::windows::core::Waiter::new();
let (_waiter, signaler) = ::windows::core::Waiter::new()?;
self.SetCompleted(super::super::Foundation::AsyncActionCompletedHandler::new(move |_sender, _args| {
unsafe {
signaler.signal();
Expand Down Expand Up @@ -615,7 +615,7 @@ impl ::windows::core::RuntimeName for GetSmsDeviceOperation {
impl GetSmsDeviceOperation {
pub fn get(&self) -> ::windows::core::Result<SmsDevice> {
if self.Status()? == super::super::Foundation::AsyncStatus::Started {
let (_waiter, signaler) = ::windows::core::Waiter::new();
let (_waiter, signaler) = ::windows::core::Waiter::new()?;
self.SetCompleted(super::super::Foundation::AsyncOperationCompletedHandler::new(move |_sender, _args| {
unsafe {
signaler.signal();
Expand Down Expand Up @@ -853,7 +853,7 @@ impl ::windows::core::RuntimeName for GetSmsMessageOperation {
impl GetSmsMessageOperation {
pub fn get(&self) -> ::windows::core::Result<ISmsMessage> {
if self.Status()? == super::super::Foundation::AsyncStatus::Started {
let (_waiter, signaler) = ::windows::core::Waiter::new();
let (_waiter, signaler) = ::windows::core::Waiter::new()?;
self.SetCompleted(super::super::Foundation::AsyncOperationCompletedHandler::new(move |_sender, _args| {
unsafe {
signaler.signal();
Expand Down Expand Up @@ -1106,7 +1106,7 @@ impl ::windows::core::RuntimeName for GetSmsMessagesOperation {
impl GetSmsMessagesOperation {
pub fn get(&self) -> ::windows::core::Result<super::super::Foundation::Collections::IVectorView<ISmsMessage>> {
if self.Status()? == super::super::Foundation::AsyncStatus::Started {
let (_waiter, signaler) = ::windows::core::Waiter::new();
let (_waiter, signaler) = ::windows::core::Waiter::new()?;
self.SetCompleted(super::super::Foundation::AsyncOperationWithProgressCompletedHandler::new(move |_sender, _args| {
unsafe {
signaler.signal();
Expand Down Expand Up @@ -2830,7 +2830,7 @@ impl ::windows::core::RuntimeName for SendSmsMessageOperation {
impl SendSmsMessageOperation {
pub fn get(&self) -> ::windows::core::Result<()> {
if self.Status()? == super::super::Foundation::AsyncStatus::Started {
let (_waiter, signaler) = ::windows::core::Waiter::new();
let (_waiter, signaler) = ::windows::core::Waiter::new()?;
self.SetCompleted(super::super::Foundation::AsyncActionCompletedHandler::new(move |_sender, _args| {
unsafe {
signaler.signal();
Expand Down
Loading

0 comments on commit ab6994f

Please sign in to comment.