Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 49 additions & 31 deletions crates/oxc_ast_macros/src/ast.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use proc_macro2::{TokenStream, TokenTree};
use quote::quote;
use quote::{quote, quote_spanned};
use syn::{
Attribute, Fields, FieldsNamed, Ident, Item, ItemEnum, ItemStruct, parse_quote,
punctuated::Punctuated, token::Comma,
punctuated::Punctuated, spanned::Spanned, token::Comma,
};

use crate::generated::{derived_traits::get_trait_crate_and_generics, structs::STRUCTS};
Expand Down Expand Up @@ -47,41 +47,48 @@ pub struct StructDetails {
fn modify_struct(item: &mut ItemStruct, args: TokenStream) -> TokenStream {
let assertions = assert_generated_derives(&item.attrs);

reorder_struct_fields(item, args);
let reorder_result = reorder_struct_fields(item, args);
let error = reorder_result.err().map(|message| compile_error(&item.ident, message));

quote! {
#[repr(C)]
#[derive(::oxc_ast_macros::Ast)]
#item
#error
#assertions
}
}

/// Re-order struct fields, depending on instructions in `STRUCTS` (which is codegen-ed).
///
/// Mutates `item` in place, re-ordering its fields.
fn reorder_struct_fields(item: &mut ItemStruct, args: TokenStream) {
fn reorder_struct_fields(item: &mut ItemStruct, args: TokenStream) -> Result<(), &'static str> {
// Skip foreign types
if let Some(TokenTree::Ident(ident)) = args.into_iter().next() {
if ident == "foreign" {
return;
return Ok(());
}
}

// Get struct data. Exit if no fields need re-ordering.
// Get struct data
let struct_name = item.ident.to_string();
let Some(field_order) = STRUCTS[&struct_name].field_order else {
return;
let Some(struct_details) = STRUCTS.get(&struct_name) else {
return Err("Struct is unknown. Run `just ast` to re-run the codegen.");
};

// Exit if fields don't need re-ordering
let Some(field_order) = struct_details.field_order else {
return Ok(());
};

// Re-order fields.
// `field_order` contains indexes of fields in the order they should be.
let Fields::Named(FieldsNamed { named, .. }) = &mut item.fields else { unreachable!() };

assert!(
named.len() == field_order.len(),
"Wrong number of fields for `{struct_name}` in `STRUCTS`"
);
let named = match &mut item.fields {
Fields::Named(FieldsNamed { named, .. }) if named.len() == field_order.len() => named,
_ => {
return Err("Struct has been altered. Run `just ast` to re-run the codegen.");
}
};

// Create 2 sets of fields.
// 1st set are the fields in original order, each prefixed with `#[cfg(doc)]`.
Expand All @@ -98,6 +105,8 @@ fn reorder_struct_fields(item: &mut ItemStruct, args: TokenStream) {
pair.value_mut().attrs.insert(0, parse_quote!( #[cfg(not(doc))]));
pair
}));

Ok(())
}

/// Generate assertions that traits used in `#[generate_derive]` are in scope.
Expand All @@ -115,31 +124,40 @@ fn reorder_struct_fields(item: &mut ItemStruct, args: TokenStream) {
///
/// If `GetSpan` is not in scope, or it is not the correct `oxc_span::GetSpan`,
/// this will raise a compilation error.
///
/// If any errors e.g. cannot parse `#[generate_derive]`, or unknown traits, just skip them.
/// It is the responsibility of `oxc_ast_tools` to raise errors for those.
fn assert_generated_derives(attrs: &[Attribute]) -> TokenStream {
// We don't care here if a trait is derived multiple times.
// It is the responsibility of `oxc_ast_tools` to raise errors for those.
let assertions = attrs
.iter()
.filter(|attr| attr.path().is_ident("generate_derive"))
.flat_map(parse_attr)
.map(|trait_ident| {
let mut assertions = quote!();
for attr in attrs {
if !attr.path().is_ident("generate_derive") {
continue;
}

let Ok(parsed) = attr.parse_args_with(Punctuated::<Ident, Comma>::parse_terminated) else {
continue;
};

for trait_ident in parsed {
let trait_name = trait_ident.to_string();
let Some((trait_path, generics)) = get_trait_crate_and_generics(&trait_name) else {
panic!("Invalid derive trait(generate_derive): {trait_name}");
continue;
};

// These are wrapped in a scope to avoid the need for unique identifiers
quote! {{
assertions.extend(quote! {{
trait AssertionTrait: #trait_path #generics {}
impl<T: #trait_ident #generics> AssertionTrait for T {}
}}
});
quote!( const _: () = { #(#assertions)* }; )
}});
}
}

quote! {
const _: () = { #assertions };
}
}

#[inline]
fn parse_attr(attr: &Attribute) -> impl Iterator<Item = Ident> + use<> {
attr.parse_args_with(Punctuated::<Ident, Comma>::parse_terminated)
.expect("`#[generate_derive]` only accepts traits as single segment paths. Found an invalid argument.")
.into_iter()
/// Generate a `compile_error!` macro invocation with the given message, and the span of `spanned`.
fn compile_error<S: Spanned>(spanned: &S, message: &str) -> TokenStream {
quote_spanned! { spanned.span() => compile_error!(#message); }
}
Loading