Skip to content

Commit

Permalink
Make cast method produced by implement macro unsafe (#1753)
Browse files Browse the repository at this point in the history
* Remove broken cast method on types using implement macro

* Simplify From impls from implement macro

* Free implementation in alloc if querying fails

* Add back cast but make it unsafe

* Remove alloc
  • Loading branch information
rylev authored May 11, 2022
1 parent caecac5 commit c178210
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 27 deletions.
40 changes: 20 additions & 20 deletions crates/libs/implement/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,10 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro:
impl <#constraints> ::core::convert::From<#original_ident::<#(#generics,)*>> for #interface_ident {
fn from(this: #original_ident::<#(#generics,)*>) -> Self {
let this = #impl_ident::<#(#generics,)*>::new(this);
let mut this = ::std::boxed::Box::new(this);
let vtable_ptr = &mut this.vtables.#offset as *mut *const <#interface_ident as ::windows::core::Interface>::Vtable;
let _ = ::std::boxed::Box::leak(this);
unsafe { ::core::mem::transmute_copy(&vtable_ptr) }
let mut this = ::core::mem::ManuallyDrop::new(::std::boxed::Box::new(this));
let vtable_ptr = &this.vtables.#offset;
// SAFETY: interfaces are in-memory equivalent to pointers to their vtables.
unsafe { ::core::mem::transmute(vtable_ptr) }
}
}
impl <#constraints> ::windows::core::AsImpl<#original_ident::<#(#generics,)*>> for #interface_ident {
Expand Down Expand Up @@ -145,12 +145,16 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro:
}
}
impl <#constraints> #original_ident::<#(#generics,)*> {
fn cast<ResultType: ::windows::core::Interface>(&self) -> ::windows::core::Result<ResultType> {
unsafe {
let boxed = (self as *const #original_ident::<#(#generics,)*> as *mut #original_ident::<#(#generics,)*> as *mut ::windows::core::RawPtr).sub(2 + #interfaces_len) as *mut #impl_ident::<#(#generics,)*>;
let mut result = None;
<#impl_ident::<#(#generics,)*> as ::windows::core::IUnknownImpl>::QueryInterface(&*boxed, &ResultType::IID, &mut result as *mut _ as _).and_some(result)
}
/// Try casting as the provided interface
///
/// # Safety
///
/// This function can only be safely called if `self` has been heap allocated and pinned using
/// the mechanisms provided by `implement` macro.
unsafe fn cast<I: ::windows::core::Interface>(&self) -> ::windows::core::Result<I> {
let boxed = (self as *const _ as *const ::windows::core::RawPtr).sub(2 + #interfaces_len) as *mut #impl_ident::<#(#generics,)*>;
let mut result = None;
<#impl_ident::<#(#generics,)*> as ::windows::core::IUnknownImpl>::QueryInterface(&*boxed, &I::IID, &mut result as *mut _ as _).and_some(result)
}
}
impl <#constraints> ::windows::core::Compose for #original_ident::<#(#generics,)*> {
Expand All @@ -163,23 +167,19 @@ pub fn implement(attributes: proc_macro::TokenStream, original_type: proc_macro:
}
impl <#constraints> ::core::convert::From<#original_ident::<#(#generics,)*>> for ::windows::core::IUnknown {
fn from(this: #original_ident::<#(#generics,)*>) -> Self {
let this = #impl_ident::<#(#generics,)*>::new(this);
let boxed = ::core::mem::ManuallyDrop::new(::std::boxed::Box::new(this));
unsafe {
let this = #impl_ident::<#(#generics,)*>::new(this);
let ptr = ::std::boxed::Box::into_raw(::std::boxed::Box::new(this));
::core::mem::transmute_copy(&::core::ptr::NonNull::new_unchecked(
&mut (*ptr).identity as *mut _ as _
))
::core::mem::transmute(&boxed.identity)
}
}
}
impl <#constraints> ::core::convert::From<#original_ident::<#(#generics,)*>> for ::windows::core::IInspectable {
fn from(this: #original_ident::<#(#generics,)*>) -> Self {
let this = #impl_ident::<#(#generics,)*>::new(this);
let boxed = ::core::mem::ManuallyDrop::new(::std::boxed::Box::new(this));
unsafe {
let this = #impl_ident::<#(#generics,)*>::new(this);
let ptr = ::std::boxed::Box::into_raw(::std::boxed::Box::new(this));
::core::mem::transmute_copy(&::core::ptr::NonNull::new_unchecked(
&mut (*ptr).identity as *mut _ as _
))
::core::mem::transmute(&boxed.identity)
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions crates/tests/nightly_implement/tests/cast_self.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,12 @@ use windows::UI::Xaml::*;
// TODO: This is a compile-only test for now until #81 is further along and can provide composable test classes.

#[implement(IApplicationOverrides)]
struct App();
struct App;

#[allow(non_snake_case)]
impl IApplicationOverrides_Impl for App {
fn OnLaunched(&self, _: &Option<LaunchActivatedEventArgs>) -> Result<()> {
let app: Application = self.cast()?;
let app: Application = unsafe { self.cast()? };
assert!(app.FocusVisualKind()? == FocusVisualKind::DottedLine);
Ok(())
}
Expand Down
6 changes: 3 additions & 3 deletions crates/tests/nightly_implement/tests/com.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use windows::Win32::System::WinRT::Composition::*;
use windows::Win32::System::WinRT::Display::*;

#[implement(windows::Foundation::IStringable, windows::Win32::System::WinRT::Composition::ISwapChainInterop, windows::Win32::System::WinRT::Display::IDisplayPathInterop)]
struct Mix();
struct Mix;

impl IStringable_Impl for Mix {
fn ToString(&self) -> Result<HSTRING> {
Expand All @@ -32,13 +32,13 @@ impl IDisplayPathInterop_Impl for Mix {

#[test]
fn mix() -> Result<()> {
let a: ISwapChainInterop = Mix().into();
let a: ISwapChainInterop = Mix.into();
unsafe { a.SetSwapChain(None)? };

let b: IStringable = a.cast()?;
assert!(b.ToString()? == "Mix");

let c: IStringable = Mix().into();
let c: IStringable = Mix.into();
assert!(c.ToString()? == "Mix");

let d: ISwapChainInterop = c.cast()?;
Expand Down
2 changes: 1 addition & 1 deletion crates/tests/nightly_implement/tests/into_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ where
#[allow(non_snake_case)]
impl<T: RuntimeType + 'static> IIterable_Impl<T> for Iterable<T> {
fn First(&self) -> Result<IIterator<T>> {
Ok(Iterator::<T>((self.cast()?, 0).into()).into())
Ok(Iterator::<T>((unsafe { self.cast()? }, 0).into()).into())
}
}

Expand Down
2 changes: 1 addition & 1 deletion crates/tests/nightly_vector/tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ impl<T: ::windows::core::RuntimeType + 'static> IVector_Impl<T> for Vector<T> {
self.Size()
}
fn GetView(&self) -> Result<windows::Foundation::Collections::IVectorView<T>> {
self.cast()
unsafe { self.cast() }
}
fn IndexOf(&self, value: &T::DefaultType, result: &mut u32) -> Result<bool> {
self.IndexOf(value, result)
Expand Down

0 comments on commit c178210

Please sign in to comment.