Skip to content

Commit

Permalink
Merge pull request #188 from sivizius/skip-variants
Browse files Browse the repository at this point in the history
feat(derive): add variant-attribute `#[arbitrary(skip)]`
  • Loading branch information
fitzgen authored Aug 22, 2024
2 parents 84e6920 + 38be8f8 commit 617ec10
Show file tree
Hide file tree
Showing 5 changed files with 156 additions and 71 deletions.
2 changes: 1 addition & 1 deletion derive/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ rust-version = "1.63.0"
[dependencies]
proc-macro2 = "1.0"
quote = "1.0"
syn = { version = "2", features = ['derive', 'parsing'] }
syn = { version = "2", features = ['derive', 'parsing', 'extra-traits'] }

[lib]
proc_macro = true
172 changes: 103 additions & 69 deletions derive/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,14 @@ use syn::*;

mod container_attributes;
mod field_attributes;
mod variant_attributes;

use container_attributes::ContainerAttributes;
use field_attributes::{determine_field_constructor, FieldConstructor};
use variant_attributes::not_skipped;

static ARBITRARY_ATTRIBUTE_NAME: &str = "arbitrary";
static ARBITRARY_LIFETIME_NAME: &str = "'arbitrary";
const ARBITRARY_ATTRIBUTE_NAME: &str = "arbitrary";
const ARBITRARY_LIFETIME_NAME: &str = "'arbitrary";

#[proc_macro_derive(Arbitrary, attributes(arbitrary))]
pub fn derive_arbitrary(tokens: proc_macro::TokenStream) -> proc_macro::TokenStream {
Expand Down Expand Up @@ -201,81 +204,107 @@ fn gen_arbitrary_method(
})
}

let ident = &input.ident;
let output = match &input.data {
Data::Struct(data) => arbitrary_structlike(&data.fields, ident, lifetime, recursive_count)?,
Data::Union(data) => arbitrary_structlike(
&Fields::Named(data.fields.clone()),
ident,
lifetime,
fn arbitrary_variant(
index: u64,
enum_name: &Ident,
variant_name: &Ident,
ctor: TokenStream,
) -> TokenStream {
quote! { #index => #enum_name::#variant_name #ctor }
}

fn arbitrary_enum_method(
recursive_count: &syn::Ident,
unstructured: TokenStream,
variants: &[TokenStream],
) -> impl quote::ToTokens {
let count = variants.len() as u64;
with_recursive_count_guard(
recursive_count,
)?,
Data::Enum(data) => {
let variants: Vec<TokenStream> = data
.variants
.iter()
.enumerate()
.map(|(i, variant)| {
check_variant_attrs(variant)?;
let idx = i as u64;
let variant_name = &variant.ident;
construct(&variant.fields, |_, field| gen_constructor_for_field(field))
.map(|ctor| quote! { #idx => #ident::#variant_name #ctor })
quote! {
// Use a multiply + shift to generate a ranged random number
// with slight bias. For details, see:
// https://lemire.me/blog/2016/06/30/fast-random-shuffling
Ok(match (u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(#unstructured)?) * #count) >> 32 {
#(#variants,)*
_ => unreachable!()
})
.collect::<Result<_>>()?;
},
)
}

let variants_take_rest: Vec<TokenStream> = data
.variants
.iter()
.enumerate()
.map(|(i, variant)| {
let idx = i as u64;
let variant_name = &variant.ident;
construct_take_rest(&variant.fields)
.map(|ctor| quote! { #idx => #ident::#variant_name #ctor })
})
.collect::<Result<_>>()?;
fn arbitrary_enum(
DataEnum { variants, .. }: &DataEnum,
enum_name: &Ident,
lifetime: LifetimeParam,
recursive_count: &syn::Ident,
) -> Result<TokenStream> {
let filtered_variants = variants.iter().filter(not_skipped);

// Check attributes of all variants:
filtered_variants
.clone()
.try_for_each(check_variant_attrs)?;

// From here on, we can assume that the attributes of all variants were checked.
let enumerated_variants = filtered_variants
.enumerate()
.map(|(index, variant)| (index as u64, variant));

// Construct `match`-arms for the `arbitrary` method.
let variants = enumerated_variants
.clone()
.map(|(index, Variant { fields, ident, .. })| {
construct(fields, |_, field| gen_constructor_for_field(field))
.map(|ctor| arbitrary_variant(index, enum_name, ident, ctor))
})
.collect::<Result<Vec<TokenStream>>>()?;

let count = data.variants.len() as u64;
// Construct `match`-arms for the `arbitrary_take_rest` method.
let variants_take_rest = enumerated_variants
.map(|(index, Variant { fields, ident, .. })| {
construct_take_rest(fields)
.map(|ctor| arbitrary_variant(index, enum_name, ident, ctor))
})
.collect::<Result<Vec<TokenStream>>>()?;

// Most of the time, `variants` is not empty (the happy path),
// thus `variants_take_rest` will be used,
// so no need to move this check before constructing `variants_take_rest`.
// If `variants` is empty, this will emit a compiler-error.
(!variants.is_empty())
.then(|| {
// TODO: Improve dealing with `u` vs. `&mut u`.
let arbitrary = arbitrary_enum_method(recursive_count, quote! { u }, &variants);
let arbitrary_take_rest = arbitrary_enum_method(recursive_count, quote! { &mut u }, &variants_take_rest);

let arbitrary = with_recursive_count_guard(
recursive_count,
quote! {
// Use a multiply + shift to generate a ranged random number
// with slight bias. For details, see:
// https://lemire.me/blog/2016/06/30/fast-random-shuffling
Ok(match (u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(u)?) * #count) >> 32 {
#(#variants,)*
_ => unreachable!()
})
},
);

let arbitrary_take_rest = with_recursive_count_guard(
recursive_count,
quote! {
// Use a multiply + shift to generate a ranged random number
// with slight bias. For details, see:
// https://lemire.me/blog/2016/06/30/fast-random-shuffling
Ok(match (u64::from(<u32 as arbitrary::Arbitrary>::arbitrary(&mut u)?) * #count) >> 32 {
#(#variants_take_rest,)*
_ => unreachable!()
})
},
);
fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
#arbitrary
}

quote! {
fn arbitrary(u: &mut arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
#arbitrary
fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
#arbitrary_take_rest
}
}
})
.ok_or_else(|| Error::new_spanned(
enum_name,
"Enum must have at least one variant, that is not skipped"
))
}

fn arbitrary_take_rest(mut u: arbitrary::Unstructured<#lifetime>) -> arbitrary::Result<Self> {
#arbitrary_take_rest
}
}
}
};
Ok(output)
let ident = &input.ident;
match &input.data {
Data::Struct(data) => arbitrary_structlike(&data.fields, ident, lifetime, recursive_count),
Data::Union(data) => arbitrary_structlike(
&Fields::Named(data.fields.clone()),
ident,
lifetime,
recursive_count,
),
Data::Enum(data) => arbitrary_enum(data, ident, lifetime, recursive_count),
}
}

fn construct(
Expand Down Expand Up @@ -375,7 +404,12 @@ fn gen_size_hint_method(input: &DeriveInput) -> Result<TokenStream> {
Data::Enum(data) => data
.variants
.iter()
.map(|v| size_hint_fields(&v.fields))
.filter(not_skipped)
.map(|Variant { fields, .. }| {
// The attributes of all variants are checked in `gen_arbitrary_method` above
// and can therefore assume that they are valid.
size_hint_fields(fields)
})
.collect::<Result<Vec<TokenStream>>>()
.map(|variants| {
quote! {
Expand Down
21 changes: 21 additions & 0 deletions derive/src/variant_attributes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
use crate::ARBITRARY_ATTRIBUTE_NAME;
use syn::*;

pub fn not_skipped(variant: &&Variant) -> bool {
!should_skip(variant)
}

fn should_skip(Variant { attrs, .. }: &Variant) -> bool {
attrs
.iter()
.filter_map(|attr| {
attr.path()
.is_ident(ARBITRARY_ATTRIBUTE_NAME)
.then(|| attr.parse_args::<Meta>())
.and_then(Result::ok)
})
.any(|meta| match meta {
Meta::Path(path) => path.is_ident("skip"),
_ => false,
})
}
8 changes: 7 additions & 1 deletion examples/derive_enum.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@ use arbitrary::{Arbitrary, Unstructured};
enum MyEnum {
UnitVariant,
TupleVariant(bool, u32),
StructVariant { x: i8, y: (u8, i32) },
StructVariant {
x: i8,
y: (u8, i32),
},

#[arbitrary(skip)]
SkippedVariant(usize),
}

fn main() {
Expand Down
24 changes: 24 additions & 0 deletions tests/derive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,30 @@ fn derive_enum() {
assert_eq!((4, Some(17)), <MyEnum as Arbitrary>::size_hint(0));
}

// This should result in a compiler-error:
// #[derive(Arbitrary, Debug)]
// enum Never {
// #[arbitrary(skip)]
// Nope,
// }

#[derive(Arbitrary, Debug)]
enum SkipVariant {
Always,
#[arbitrary(skip)]
Never,
}

#[test]
fn test_skip_variant() {
(0..=u8::MAX).for_each(|byte| {
let buffer = [byte];
let unstructured = Unstructured::new(&buffer);
let skip_variant = SkipVariant::arbitrary_take_rest(unstructured).unwrap();
assert!(!matches!(skip_variant, SkipVariant::Never));
})
}

#[derive(Arbitrary, Debug)]
enum RecursiveTree {
Leaf,
Expand Down

0 comments on commit 617ec10

Please sign in to comment.