diff --git a/derive/Cargo.toml b/derive/Cargo.toml index 4bae30b..9cbf384 100644 --- a/derive/Cargo.toml +++ b/derive/Cargo.toml @@ -21,7 +21,7 @@ rust-version = "1.63.0" [dependencies] proc-macro2 = "1.0" quote = "1.0" -syn = { version = "2", features = ['derive', 'parsing'] } +syn = { version = "2", features = ['derive', 'parsing', 'extra-traits'] } [lib] proc_macro = true diff --git a/derive/src/lib.rs b/derive/src/lib.rs index 2223070..33b6417 100644 --- a/derive/src/lib.rs +++ b/derive/src/lib.rs @@ -6,11 +6,14 @@ use syn::*; mod container_attributes; mod field_attributes; +mod variant_attributes; + use container_attributes::ContainerAttributes; use field_attributes::{determine_field_constructor, FieldConstructor}; +use variant_attributes::not_skipped; -static ARBITRARY_ATTRIBUTE_NAME: &str = "arbitrary"; -static ARBITRARY_LIFETIME_NAME: &str = "'arbitrary"; +const ARBITRARY_ATTRIBUTE_NAME: &str = "arbitrary"; +const ARBITRARY_LIFETIME_NAME: &str = "'arbitrary"; #[proc_macro_derive(Arbitrary, attributes(arbitrary))] pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream { @@ -201,81 +204,107 @@ fn gen_arbitrary_method( }) } - let ident = &input.ident; - let output = match &input.data { - Data::Struct(data) => arbitrary_structlike(&data.fields, ident, lifetime, recursive_count)?, - Data::Union(data) => arbitrary_structlike( - &Fields::Named(data.fields.clone()), - ident, - lifetime, + fn arbitrary_variant( + index: u64, + enum_name: &Ident, + variant_name: &Ident, + ctor: TokenStream, + ) -> TokenStream { + quote! { #index => #enum_name::#variant_name #ctor } + } + + fn arbitrary_enum_method( + recursive_count: &syn::Ident, + unstructured: TokenStream, + variants: &[TokenStream], + ) -> impl quote::ToTokens { + let count = variants.len() as u64; + with_recursive_count_guard( recursive_count, - )?, - Data::Enum(data) => { - let variants: Vec = data - .variants - .iter() - .enumerate() - .map(|(i, variant)| { - check_variant_attrs(variant)?; - let idx = i as u64; - let variant_name = &variant.ident; - construct(&variant.fields, |_, field| gen_constructor_for_field(field)) - .map(|ctor| quote! { #idx => #ident::#variant_name #ctor }) + quote! { + // Use a multiply + shift to generate a ranged random number + // with slight bias. For details, see: + // https://lemire.me/blog/2016/06/30/fast-random-shuffling + Ok(match (u64::from(::arbitrary(#unstructured)?) * #count) >> 32 { + #(#variants,)* + _ => unreachable!() }) - .collect::>()?; + }, + ) + } - let variants_take_rest: Vec = data - .variants - .iter() - .enumerate() - .map(|(i, variant)| { - let idx = i as u64; - let variant_name = &variant.ident; - construct_take_rest(&variant.fields) - .map(|ctor| quote! { #idx => #ident::#variant_name #ctor }) - }) - .collect::>()?; + fn arbitrary_enum( + DataEnum { variants, .. }: &DataEnum, + enum_name: &Ident, + lifetime: LifetimeParam, + recursive_count: &syn::Ident, + ) -> Result { + let filtered_variants = variants.iter().filter(not_skipped); + + // Check attributes of all variants: + filtered_variants + .clone() + .try_for_each(check_variant_attrs)?; + + // From here on, we can assume that the attributes of all variants were checked. + let enumerated_variants = filtered_variants + .enumerate() + .map(|(index, variant)| (index as u64, variant)); + + // Construct `match`-arms for the `arbitrary` method. + let variants = enumerated_variants + .clone() + .map(|(index, Variant { fields, ident, .. })| { + construct(fields, |_, field| gen_constructor_for_field(field)) + .map(|ctor| arbitrary_variant(index, enum_name, ident, ctor)) + }) + .collect::>>()?; - let count = data.variants.len() as u64; + // Construct `match`-arms for the `arbitrary_take_rest` method. + let variants_take_rest = enumerated_variants + .map(|(index, Variant { fields, ident, .. })| { + construct_take_rest(fields) + .map(|ctor| arbitrary_variant(index, enum_name, ident, ctor)) + }) + .collect::>>()?; + + // Most of the time, `variants` is not empty (the happy path), + // thus `variants_take_rest` will be used, + // so no need to move this check before constructing `variants_take_rest`. + // If `variants` is empty, this will emit a compiler-error. + (!variants.is_empty()) + .then(|| { + // TODO: Improve dealing with `u` vs. `&mut u`. + let arbitrary = arbitrary_enum_method(recursive_count, quote! { u }, &variants); + let arbitrary_take_rest = arbitrary_enum_method(recursive_count, quote! { &mut u }, &variants_take_rest); - let arbitrary = with_recursive_count_guard( - recursive_count, - quote! { - // Use a multiply + shift to generate a ranged random number - // with slight bias. For details, see: - // https://lemire.me/blog/2016/06/30/fast-random-shuffling - Ok(match (u64::from(::arbitrary(u)?) * #count) >> 32 { - #(#variants,)* - _ => unreachable!() - }) - }, - ); - - let arbitrary_take_rest = with_recursive_count_guard( - recursive_count, quote! { - // Use a multiply + shift to generate a ranged random number - // with slight bias. For details, see: - // https://lemire.me/blog/2016/06/30/fast-random-shuffling - Ok(match (u64::from(::arbitrary(&mut u)?) * #count) >> 32 { - #(#variants_take_rest,)* - _ => unreachable!() - }) - }, - ); + fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result { + #arbitrary + } - quote! { - fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result { - #arbitrary + fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result { + #arbitrary_take_rest + } } + }) + .ok_or_else(|| Error::new_spanned( + enum_name, + "Enum must have at least one variant, that is not skipped" + )) + } - fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result { - #arbitrary_take_rest - } - } - } - }; - Ok(output) + let ident = &input.ident; + match &input.data { + Data::Struct(data) => arbitrary_structlike(&data.fields, ident, lifetime, recursive_count), + Data::Union(data) => arbitrary_structlike( + &Fields::Named(data.fields.clone()), + ident, + lifetime, + recursive_count, + ), + Data::Enum(data) => arbitrary_enum(data, ident, lifetime, recursive_count), + } } fn construct( @@ -375,7 +404,12 @@ fn gen_size_hint_method(input: &DeriveInput) -> Result { Data::Enum(data) => data .variants .iter() - .map(|v| size_hint_fields(&v.fields)) + .filter(not_skipped) + .map(|Variant { fields, .. }| { + // The attributes of all variants are checked in `gen_arbitrary_method` above + // and can therefore assume that they are valid. + size_hint_fields(fields) + }) .collect::>>() .map(|variants| { quote! { diff --git a/derive/src/variant_attributes.rs b/derive/src/variant_attributes.rs new file mode 100644 index 0000000..39957df --- /dev/null +++ b/derive/src/variant_attributes.rs @@ -0,0 +1,21 @@ +use crate::ARBITRARY_ATTRIBUTE_NAME; +use syn::*; + +pub fn not_skipped(variant: &&Variant) -> bool { + !should_skip(variant) +} + +fn should_skip(Variant { attrs, .. }: &Variant) -> bool { + attrs + .iter() + .filter_map(|attr| { + attr.path() + .is_ident(ARBITRARY_ATTRIBUTE_NAME) + .then(|| attr.parse_args::()) + .and_then(Result::ok) + }) + .any(|meta| match meta { + Meta::Path(path) => path.is_ident("skip"), + _ => false, + }) +} diff --git a/examples/derive_enum.rs b/examples/derive_enum.rs index c4fc9c9..a44a97c 100644 --- a/examples/derive_enum.rs +++ b/examples/derive_enum.rs @@ -12,7 +12,13 @@ use arbitrary::{Arbitrary, Unstructured}; enum MyEnum { UnitVariant, TupleVariant(bool, u32), - StructVariant { x: i8, y: (u8, i32) }, + StructVariant { + x: i8, + y: (u8, i32), + }, + + #[arbitrary(skip)] + SkippedVariant(usize), } fn main() { diff --git a/tests/derive.rs b/tests/derive.rs index a84e4dc..b73da11 100644 --- a/tests/derive.rs +++ b/tests/derive.rs @@ -116,6 +116,30 @@ fn derive_enum() { assert_eq!((4, Some(17)), ::size_hint(0)); } +// This should result in a compiler-error: +// #[derive(Arbitrary, Debug)] +// enum Never { +// #[arbitrary(skip)] +// Nope, +// } + +#[derive(Arbitrary, Debug)] +enum SkipVariant { + Always, + #[arbitrary(skip)] + Never, +} + +#[test] +fn test_skip_variant() { + (0..=u8::MAX).for_each(|byte| { + let buffer = [byte]; + let unstructured = Unstructured::new(&buffer); + let skip_variant = SkipVariant::arbitrary_take_rest(unstructured).unwrap(); + assert!(!matches!(skip_variant, SkipVariant::Never)); + }) +} + #[derive(Arbitrary, Debug)] enum RecursiveTree { Leaf,