Skip to content

Commit

Permalink
allow using tuple structs in the world query derive
Browse files Browse the repository at this point in the history
  • Loading branch information
joseph-gio committed Mar 17, 2023
1 parent bca4b36 commit 543415c
Showing 1 changed file with 82 additions and 40 deletions.
122 changes: 82 additions & 40 deletions crates/bevy_ecs/macros/src/fetch.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use proc_macro::TokenStream;
use proc_macro2::{Ident, Span};
use quote::{quote, ToTokens};
use quote::{format_ident, quote, ToTokens};
use syn::{
parse::{Parse, ParseStream},
parse_quote,
punctuated::Punctuated,
Attribute, Data, DataStruct, DeriveInput, Field, Fields,
Attribute, Data, DataStruct, DeriveInput, Field, Index,
};

use crate::bevy_ecs_path;
Expand Down Expand Up @@ -112,37 +112,59 @@ pub fn derive_world_query_impl(ast: DeriveInput) -> TokenStream {

let state_struct_name = Ident::new(&format!("{struct_name}State"), Span::call_site());

let fields = match &ast.data {
Data::Struct(DataStruct {
fields: Fields::Named(fields),
..
}) => &fields.named,
_ => panic!("Expected a struct with named fields"),
let Data::Struct(DataStruct { fields, .. }) = &ast.data else {
return syn::Error::new(
Span::call_site(),
"#[derive(WorldQuery)]` only supports structs",
)
.into_compile_error()
.into()
};
if matches!(fields, syn::Fields::Unit) {
return syn::Error::new(
Span::call_site(),
"#[derive(WorldQuery)]` does not support unit structs",
)
.into_compile_error()
.into();
}

let mut ignored_field_attrs = Vec::new();
let mut ignored_field_visibilities = Vec::new();
let mut ignored_field_idents = Vec::new();
let mut ignored_named_field_idents = Vec::new();
let mut ignored_field_types = Vec::new();
let mut field_attrs = Vec::new();
let mut field_visibilities = Vec::new();
let mut field_idents = Vec::new();
let mut named_field_idents = Vec::new();
let mut field_types = Vec::new();
let mut read_only_field_types = Vec::new();

for field in fields {
for (i, field) in fields.iter().enumerate() {
let WorldQueryFieldInfo { is_ignored, attrs } = read_world_query_field_info(field);

let field_ident = field.ident.as_ref().unwrap().clone();
let named_field_ident = field
.ident
.as_ref()
.cloned()
.unwrap_or_else(|| format_ident!("f{i}"));
let i = Index::from(i);
let field_ident = field
.ident
.as_ref()
.map_or(quote! { #i }, |i| quote! { #i });
if is_ignored {
ignored_field_attrs.push(attrs);
ignored_field_visibilities.push(field.vis.clone());
ignored_field_idents.push(field_ident.clone());
ignored_field_idents.push(field_ident);
ignored_named_field_idents.push(named_field_ident);
ignored_field_types.push(field.ty.clone());
} else {
field_attrs.push(attrs);
field_visibilities.push(field.vis.clone());
field_idents.push(field_ident.clone());
field_idents.push(field_ident);
named_field_idents.push(named_field_ident);
let field_ty = field.ty.clone();
field_types.push(quote!(#field_ty));
read_only_field_types.push(quote!(<#field_ty as #path::query::WorldQuery>::ReadOnly));
Expand Down Expand Up @@ -176,16 +198,36 @@ pub fn derive_world_query_impl(ast: DeriveInput) -> TokenStream {
&field_types
};

let item_struct = quote! {
#derive_macro_call
#[doc = "Automatically generated [`WorldQuery`] item type for [`"]
#[doc = stringify!(#struct_name)]
#[doc = "`], returned when iterating over query results."]
#[automatically_derived]
#visibility struct #item_struct_name #user_impl_generics_with_world #user_where_clauses_with_world {
#(#(#field_attrs)* #field_visibilities #field_idents: <#field_types as #path::query::WorldQuery>::Item<'__w>,)*
#(#(#ignored_field_attrs)* #ignored_field_visibilities #ignored_field_idents: #ignored_field_types,)*
}
let item_struct = match fields {
syn::Fields::Named(_) => quote! {
#derive_macro_call
#[doc = "Automatically generated [`WorldQuery`] item type for [`"]
#[doc = stringify!(#struct_name)]
#[doc = "`], returned when iterating over query results."]
#[automatically_derived]
#visibility struct #item_struct_name #user_impl_generics_with_world #user_where_clauses_with_world {
#(#(#field_attrs)* #field_visibilities #field_idents: <#field_types as #path::query::WorldQuery>::Item<'__w>,)*
#(#(#ignored_field_attrs)* #ignored_field_visibilities #ignored_field_idents: #ignored_field_types,)*
}
},
syn::Fields::Unnamed(_) => quote! {
#derive_macro_call
#[doc = "Automatically generated [`WorldQuery`] item type for [`"]
#[doc = stringify!(#struct_name)]
#[doc = "`], returned when iterating over query results."]
#[automatically_derived]
#visibility struct #item_struct_name #user_impl_generics_with_world #user_where_clauses_with_world(
#( #field_visibilities <#field_types as #path::query::WorldQuery>::Item<'__w>, )*
);
},
syn::Fields::Unit => quote! {
#derive_macro_call
#[doc = "Automatically generated [`WorldQuery`] item type for [`"]
#[doc = stringify!(#struct_name)]
#[doc = "`], returned when iterating over query results."]
#[automatically_derived]
#visibility struct #item_struct_name #user_impl_generics_with_world #user_where_clauses_with_world;
},
};

let query_impl = quote! {
Expand All @@ -195,8 +237,8 @@ pub fn derive_world_query_impl(ast: DeriveInput) -> TokenStream {
#[doc = "`], used to define the world data accessed by this query."]
#[automatically_derived]
#visibility struct #fetch_struct_name #user_impl_generics_with_world #user_where_clauses_with_world {
#(#field_idents: <#field_types as #path::query::WorldQuery>::Fetch<'__w>,)*
#(#ignored_field_idents: #ignored_field_types,)*
#(#named_field_idents: <#field_types as #path::query::WorldQuery>::Fetch<'__w>,)*
#(#ignored_named_field_idents: #ignored_field_types,)*
}

// SAFETY: `update_component_access` and `update_archetype_component_access` are called on every field
Expand Down Expand Up @@ -228,15 +270,15 @@ pub fn derive_world_query_impl(ast: DeriveInput) -> TokenStream {
_this_run: #path::component::Tick,
) -> <Self as #path::query::WorldQuery>::Fetch<'__w> {
#fetch_struct_name {
#(#field_idents:
#(#named_field_idents:
<#field_types>::init_fetch(
_world,
&state.#field_idents,
&state.#named_field_idents,
_last_run,
_this_run,
),
)*
#(#ignored_field_idents: Default::default(),)*
#(#ignored_named_field_idents: Default::default(),)*
}
}

Expand All @@ -245,10 +287,10 @@ pub fn derive_world_query_impl(ast: DeriveInput) -> TokenStream {
) -> <Self as #path::query::WorldQuery>::Fetch<'__w> {
#fetch_struct_name {
#(
#field_idents: <#field_types>::clone_fetch(& _fetch. #field_idents),
#named_field_idents: <#field_types>::clone_fetch(& _fetch. #named_field_idents),
)*
#(
#ignored_field_idents: Default::default(),
#ignored_named_field_idents: Default::default(),
)*
}
}
Expand All @@ -265,7 +307,7 @@ pub fn derive_world_query_impl(ast: DeriveInput) -> TokenStream {
_archetype: &'__w #path::archetype::Archetype,
_table: &'__w #path::storage::Table
) {
#(<#field_types>::set_archetype(&mut _fetch.#field_idents, &_state.#field_idents, _archetype, _table);)*
#(<#field_types>::set_archetype(&mut _fetch.#named_field_idents, &_state.#named_field_idents, _archetype, _table);)*
}

/// SAFETY: we call `set_table` for each member that implements `Fetch`
Expand All @@ -275,7 +317,7 @@ pub fn derive_world_query_impl(ast: DeriveInput) -> TokenStream {
_state: &Self::State,
_table: &'__w #path::storage::Table
) {
#(<#field_types>::set_table(&mut _fetch.#field_idents, &_state.#field_idents, _table);)*
#(<#field_types>::set_table(&mut _fetch.#named_field_idents, &_state.#named_field_idents, _table);)*
}

/// SAFETY: we call `fetch` for each member that implements `Fetch`.
Expand All @@ -286,7 +328,7 @@ pub fn derive_world_query_impl(ast: DeriveInput) -> TokenStream {
_table_row: #path::storage::TableRow,
) -> <Self as #path::query::WorldQuery>::Item<'__w> {
Self::Item {
#(#field_idents: <#field_types>::fetch(&mut _fetch.#field_idents, _entity, _table_row),)*
#(#field_idents: <#field_types>::fetch(&mut _fetch.#named_field_idents, _entity, _table_row),)*
#(#ignored_field_idents: Default::default(),)*
}
}
Expand All @@ -298,11 +340,11 @@ pub fn derive_world_query_impl(ast: DeriveInput) -> TokenStream {
_entity: #path::entity::Entity,
_table_row: #path::storage::TableRow,
) -> bool {
true #(&& <#field_types>::filter_fetch(&mut _fetch.#field_idents, _entity, _table_row))*
true #(&& <#field_types>::filter_fetch(&mut _fetch.#named_field_idents, _entity, _table_row))*
}

fn update_component_access(state: &Self::State, _access: &mut #path::query::FilteredAccess<#path::component::ComponentId>) {
#( <#field_types>::update_component_access(&state.#field_idents, _access); )*
#( <#field_types>::update_component_access(&state.#named_field_idents, _access); )*
}

fn update_archetype_component_access(
Expand All @@ -311,19 +353,19 @@ pub fn derive_world_query_impl(ast: DeriveInput) -> TokenStream {
_access: &mut #path::query::Access<#path::archetype::ArchetypeComponentId>
) {
#(
<#field_types>::update_archetype_component_access(&state.#field_idents, _archetype, _access);
<#field_types>::update_archetype_component_access(&state.#named_field_idents, _archetype, _access);
)*
}

fn init_state(world: &mut #path::world::World) -> #state_struct_name #user_ty_generics {
#state_struct_name {
#(#field_idents: <#field_types>::init_state(world),)*
#(#named_field_idents: <#field_types>::init_state(world),)*
#(#ignored_field_idents: Default::default(),)*
}
}

fn matches_component_set(state: &Self::State, _set_contains_id: &impl Fn(#path::component::ComponentId) -> bool) -> bool {
true #(&& <#field_types>::matches_component_set(&state.#field_idents, _set_contains_id))*
true #(&& <#field_types>::matches_component_set(&state.#named_field_idents, _set_contains_id))*
}
}
};
Expand All @@ -339,7 +381,7 @@ pub fn derive_world_query_impl(ast: DeriveInput) -> TokenStream {
#[doc = "`]."]
#[automatically_derived]
#visibility struct #read_only_struct_name #user_impl_generics #user_where_clauses {
#( #field_idents: #read_only_field_types, )*
#( #named_field_idents: #read_only_field_types, )*
#(#(#ignored_field_attrs)* #ignored_field_visibilities #ignored_field_idents: #ignored_field_types,)*
}

Expand Down Expand Up @@ -386,8 +428,8 @@ pub fn derive_world_query_impl(ast: DeriveInput) -> TokenStream {
#[doc = "`], used for caching."]
#[automatically_derived]
#visibility struct #state_struct_name #user_impl_generics #user_where_clauses {
#(#field_idents: <#field_types as #path::query::WorldQuery>::State,)*
#(#ignored_field_idents: #ignored_field_types,)*
#(#named_field_idents: <#field_types as #path::query::WorldQuery>::State,)*
#(#ignored_named_field_idents: #ignored_field_types,)*
}

#mutable_impl
Expand Down

0 comments on commit 543415c

Please sign in to comment.