Skip to content

Commit

Permalink
Merge pull request #46 from spastorino/non-future-impl-traits
Browse files Browse the repository at this point in the history
Properly support non future impl traits
  • Loading branch information
tmandry authored Feb 6, 2025
2 parents 8d8d17f + b9fb5ae commit d1bde14
Show file tree
Hide file tree
Showing 7 changed files with 283 additions and 112 deletions.
17 changes: 8 additions & 9 deletions dynosaur/tests/pass/basic-rpitit.stdout
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,15 @@ mod _dynosaur_macro_dynmytrait {
use super::*;
trait ErasedMyTrait {
fn foo<'life0, 'dynosaur>(&'life0 self)
-> ::core::pin::Pin<Box<dyn Send + 'dynosaur>>
-> Box<dyn Send + 'dynosaur>
where
'life0: 'dynosaur,
Self: 'dynosaur;
}
impl<DYNOSAUR: MyTrait> ErasedMyTrait for DYNOSAUR {
fn foo<'life0, 'dynosaur>(&'life0 self)
-> ::core::pin::Pin<Box<dyn Send + 'dynosaur>> where
'life0: 'dynosaur, Self: 'dynosaur {
Box::pin(<Self as MyTrait>::foo(self))
fn foo<'life0, 'dynosaur>(&'life0 self) -> Box<dyn Send + 'dynosaur>
where 'life0: 'dynosaur, Self: 'dynosaur {
Box::new(<Self as MyTrait>::foo(self))
}
}
#[repr(transparent)]
Expand All @@ -31,10 +30,10 @@ mod _dynosaur_macro_dynmytrait {
}
impl<'dynosaur_struct> MyTrait for DynMyTrait<'dynosaur_struct> {
fn foo(&self) -> impl Send {
let fut: ::core::pin::Pin<Box<dyn Send + '_>> = self.ptr.foo();
let fut: ::core::pin::Pin<Box<dyn Send + 'static>> =
unsafe { ::core::mem::transmute(fut) };
fut
let ret: Box<dyn Send + '_> = self.ptr.foo();
let ret: Box<dyn Send + '_> =
unsafe { ::core::mem::transmute(ret) };
ret
}
}
impl<'dynosaur_struct> DynMyTrait<'dynosaur_struct> {
Expand Down
16 changes: 8 additions & 8 deletions dynosaur/tests/pass/default-method-trait.stdout
Original file line number Diff line number Diff line change
Expand Up @@ -12,16 +12,16 @@ mod _dynosaur_macro_dynmytrait {
use super::*;
trait ErasedMyTrait {
fn foo<'life0, 'dynosaur>(&'life0 mut self)
-> ::core::pin::Pin<Box<dyn Send + 'dynosaur>>
-> Box<dyn Send + 'dynosaur>
where
'life0: 'dynosaur,
Self: 'dynosaur;
}
impl<DYNOSAUR: MyTrait> ErasedMyTrait for DYNOSAUR {
fn foo<'life0, 'dynosaur>(&'life0 mut self)
-> ::core::pin::Pin<Box<dyn Send + 'dynosaur>> where
'life0: 'dynosaur, Self: 'dynosaur {
Box::pin(<Self as MyTrait>::foo(self))
-> Box<dyn Send + 'dynosaur> where 'life0: 'dynosaur,
Self: 'dynosaur {
Box::new(<Self as MyTrait>::foo(self))
}
}
#[repr(transparent)]
Expand All @@ -30,10 +30,10 @@ mod _dynosaur_macro_dynmytrait {
}
impl<'dynosaur_struct> MyTrait for DynMyTrait<'dynosaur_struct> {
fn foo(&mut self) -> impl Send {
let fut: ::core::pin::Pin<Box<dyn Send + '_>> = self.ptr.foo();
let fut: ::core::pin::Pin<Box<dyn Send + 'static>> =
unsafe { ::core::mem::transmute(fut) };
fut
let ret: Box<dyn Send + '_> = self.ptr.foo();
let ret: Box<dyn Send + '_> =
unsafe { ::core::mem::transmute(ret) };
ret
}
}
impl<'dynosaur_struct> DynMyTrait<'dynosaur_struct> {
Expand Down
19 changes: 19 additions & 0 deletions dynosaur/tests/pass/non-future-impl-traits.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
#[dynosaur::dynosaur(DynSomeTrait)]
trait SomeTrait {
fn get_iter(&mut self) -> impl Iterator<Item = u8> + '_;
}

struct MyImpl([u8; 4]);

impl SomeTrait for MyImpl {
fn get_iter(&mut self) -> impl Iterator<Item = u8> + '_ {
return self.0.into_iter();
}
}

fn main() {
let mut st = DynSomeTrait::boxed(MyImpl([3,2,4,1]));
for x in st.get_iter() {
println!("{}", x);
}
}
73 changes: 73 additions & 0 deletions dynosaur/tests/pass/non-future-impl-traits.stdout
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#![feature(prelude_import)]
#[prelude_import]
use std::prelude::rust_2021::*;
#[macro_use]
extern crate std;
trait SomeTrait {
fn get_iter(&mut self)
-> impl Iterator<Item = u8> + '_;
}
mod _dynosaur_macro_dynsometrait {
use super::*;
trait ErasedSomeTrait {
fn get_iter<'life0, 'dynosaur>(&'life0 mut self)
-> Box<dyn Iterator<Item = u8> + 'dynosaur>
where
'life0: 'dynosaur,
Self: 'dynosaur;
}
impl<DYNOSAUR: SomeTrait> ErasedSomeTrait for DYNOSAUR {
fn get_iter<'life0, 'dynosaur>(&'life0 mut self)
-> Box<dyn Iterator<Item = u8> + 'dynosaur> where
'life0: 'dynosaur, Self: 'dynosaur {
Box::new(<Self as SomeTrait>::get_iter(self))
}
}
#[repr(transparent)]
pub struct DynSomeTrait<'dynosaur_struct> {
ptr: dyn ErasedSomeTrait + 'dynosaur_struct,
}
impl<'dynosaur_struct> SomeTrait for DynSomeTrait<'dynosaur_struct> {
fn get_iter(&mut self) -> impl Iterator<Item = u8> + '_ {
let ret: Box<dyn Iterator<Item = u8> + '_> = self.ptr.get_iter();
let ret: Box<dyn Iterator<Item = u8> + '_> =
unsafe { ::core::mem::transmute(ret) };
ret
}
}
impl<'dynosaur_struct> DynSomeTrait<'dynosaur_struct> {
pub fn boxed(value: impl SomeTrait + 'dynosaur_struct)
-> Box<DynSomeTrait<'dynosaur_struct>> {
let value = Box::new(value);
let value: Box<dyn ErasedSomeTrait + 'dynosaur_struct> = value;
unsafe { ::core::mem::transmute(value) }
}
pub fn from_ref(value: &(impl SomeTrait + 'dynosaur_struct))
-> &DynSomeTrait<'dynosaur_struct> {
let value: &(dyn ErasedSomeTrait + 'dynosaur_struct) = &*value;
unsafe { ::core::mem::transmute(value) }
}
pub fn from_mut(value: &mut (impl SomeTrait + 'dynosaur_struct))
-> &mut DynSomeTrait<'dynosaur_struct> {
let value: &mut (dyn ErasedSomeTrait + 'dynosaur_struct) =
&mut *value;
unsafe { ::core::mem::transmute(value) }
}
}
}
use _dynosaur_macro_dynsometrait::DynSomeTrait;

struct MyImpl([u8; 4]);

impl SomeTrait for MyImpl {
fn get_iter(&mut self) -> impl Iterator<Item = u8> + '_ {
return self.0.into_iter();
}
}

fn main() {
let mut st = DynSomeTrait::boxed(MyImpl([3, 2, 4, 1]));
for x in st.get_iter() {
{ ::std::io::_print(format_args!("{0}\n", x)); };
}
}
124 changes: 103 additions & 21 deletions dynosaur_derive/src/expand.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,16 @@
use crate::lifetime::{used_lifetimes, AddLifetimeToImplTrait, CollectLifetimes};
use crate::receiver::has_self_in_sig;
use crate::where_clauses::where_clause_or_default;
use crate::sig::{is_async, is_rpit};
use crate::where_clauses::{has_where_self_sized, where_clause_or_default};
use proc_macro2::{Span, TokenStream};
use quote::quote;
use std::mem;
use syn::punctuated::Punctuated;
use syn::token::RArrow;
use syn::visit_mut::VisitMut;
use syn::{
parse_quote, parse_quote_spanned, Error, FnArg, GenericParam, Generics, Pat, PatType,
ReturnType, Signature, Token, TraitItemFn, Type, TypeImplTrait,
parse_quote, parse_quote_spanned, Error, FnArg, GenericParam, Generics, ItemTrait, Pat,
PatType, ReturnType, Signature, Token, TraitItemFn, Type, TypeImplTrait, TypeParamBound,
};

/// Expands the signature of each function on the trait, converting async fn into fn with return
Expand Down Expand Up @@ -39,31 +40,18 @@ use syn::{
pub(crate) fn expand_fn_sig(item_trait_generics: &Generics, trait_item_fn: &mut TraitItemFn) {
let sig = &mut trait_item_fn.sig;

if is_async_or_rpit(sig) {
if is_async(sig) {
expand_fn_input(item_trait_generics, sig);
expand_sig_ret_ty_to_pin_box(sig);
} else if is_rpit(sig) {
expand_fn_input(item_trait_generics, sig);
expand_sig_ret_ty_to_box(sig);
}

// Remove default method if any for the erased trait
trait_item_fn.default = None;
}

pub(crate) fn is_async_or_rpit(sig: &Signature) -> bool {
match sig {
Signature {
asyncness: Some(_), ..
} => true,
Signature {
asyncness: None,
output: ReturnType::Type(_, ret),
..
} => {
matches!(**ret, Type::ImplTrait(_))
}
_ => false,
}
}

fn expand_fn_input(item_trait_generics: &Generics, sig: &mut Signature) {
let mut lifetimes = CollectLifetimes::new();
for arg in &mut sig.inputs {
Expand Down Expand Up @@ -148,6 +136,14 @@ pub(crate) fn expand_sig_ret_ty_to_pin_box(sig: &mut Signature) {
sig.output = parse_quote! { #arrow ::core::pin::Pin<Box<dyn #ret + 'dynosaur>> };
}

pub(crate) fn expand_sig_ret_ty_to_box(sig: &mut Signature) {
let (arrow, ret) = expand_arrow_ret_ty(sig);
if let Some(asyncness) = sig.asyncness.take() {
sig.fn_token.span = asyncness.span;
}
sig.output = parse_quote! { #arrow Box<dyn #ret + 'dynosaur> };
}

pub(crate) fn expand_sig_ret_ty_to_rpit(sig: &mut Signature) {
let (arrow, ret) = expand_arrow_ret_ty(sig);
if let Some(asyncness) = sig.asyncness.take() {
Expand Down Expand Up @@ -188,6 +184,84 @@ pub(crate) fn expand_invoke_args(sig: &Signature, ufc: bool) -> Vec<TokenStream>
args
}

pub(crate) fn expand_blanket_impl_fn(
item_trait: &ItemTrait,
trait_item_fn: &mut TraitItemFn,
) -> TokenStream {
let is_async = is_async(&trait_item_fn.sig);
let is_rpit = is_rpit(&trait_item_fn.sig);

expand_fn_sig(&item_trait.generics, trait_item_fn);
let sig = &trait_item_fn.sig;

let trait_ident = &item_trait.ident;
let (_, trait_generics, _) = &item_trait.generics.split_for_impl();
let ident = &sig.ident;
let args = expand_invoke_args(sig, false);
let value = quote! { <Self as #trait_ident #trait_generics>::#ident(#(#args),*) };

let value = if is_async {
quote! {
Box::pin(#value)
}
} else if is_rpit {
quote! {
Box::new(#value)
}
} else {
value
};

quote! {
#sig {
#value
}
}
}

pub(crate) fn expand_dyn_struct_fn(sig: &Signature) -> TokenStream {
if has_where_self_sized(&sig) {
quote! {
#sig {
unreachable!()
}
}
} else {
let ident = &sig.ident;
let args = expand_invoke_args(&sig, true);

if is_async(&sig) {
let ret = expand_ret_ty(&sig);
let mut sig = sig.clone();
expand_sig_ret_ty_to_rpit(&mut sig);

quote! {
#sig {
let fut: ::core::pin::Pin<Box<dyn #ret + '_>> = self.ptr.#ident(#(#args),*);
let fut: ::core::pin::Pin<Box<dyn #ret + 'static>> = unsafe { ::core::mem::transmute(fut) };
fut
}
}
} else if is_rpit(&sig) {
let ret = expand_ret_ty(&sig);

quote! {
#sig {
let ret: Box<dyn #ret + '_> = self.ptr.#ident(#(#args),*);
let ret: Box<dyn #ret + '_> = unsafe { ::core::mem::transmute(ret) };
ret
}
}
} else {
quote! {
#sig {
self.ptr.#ident(#(#args),*)
}
}
}
}
}

fn expand_arrow_ret_ty(sig: &Signature) -> (RArrow, TokenStream) {
match (sig.asyncness.is_some(), &sig.output) {
(true, ReturnType::Default) => {
Expand All @@ -201,7 +275,15 @@ fn expand_arrow_ret_ty(sig: &Signature) -> (RArrow, TokenStream) {
}
(false, ReturnType::Type(arrow, ret)) => {
if let Type::ImplTrait(TypeImplTrait { bounds, .. }) = &**ret {
return (*arrow, quote!(#bounds));
let mut ret_bounds: Punctuated<&TypeParamBound, Token![+]> = Punctuated::new();

for bound in bounds {
if !matches!(bound, TypeParamBound::Lifetime(_)) {
ret_bounds.push(bound);
}
}

return (*arrow, quote! { #ret_bounds });
}
}
_ => {}
Expand Down
Loading

0 comments on commit d1bde14

Please sign in to comment.