From c0c1ded58d0351dcd380c2c8f8f6b5cdab28f4ac Mon Sep 17 00:00:00 2001 From: Julien Cretin Date: Tue, 14 May 2024 15:56:28 +0200 Subject: [PATCH] Add #[postcard(bound = "...")] attribute for derive(Schema) --- source/postcard-derive/src/lib.rs | 9 +++- source/postcard-derive/src/schema.rs | 62 +++++++++++++++++++++++----- source/postcard/tests/schema.rs | 34 +++++++++++++++ 3 files changed, 92 insertions(+), 13 deletions(-) diff --git a/source/postcard-derive/src/lib.rs b/source/postcard-derive/src/lib.rs index 82b4645..57084e3 100644 --- a/source/postcard-derive/src/lib.rs +++ b/source/postcard-derive/src/lib.rs @@ -1,3 +1,5 @@ +use syn::{parse_macro_input, DeriveInput}; + mod max_size; mod schema; @@ -8,7 +10,10 @@ pub fn derive_max_size(item: proc_macro::TokenStream) -> proc_macro::TokenStream } /// Derive the `postcard::Schema` trait for a struct or enum. -#[proc_macro_derive(Schema)] +#[proc_macro_derive(Schema, attributes(postcard))] pub fn derive_schema(item: proc_macro::TokenStream) -> proc_macro::TokenStream { - schema::do_derive_schema(item) + let input = parse_macro_input!(item as DeriveInput); + schema::do_derive_schema(input) + .unwrap_or_else(syn::Error::into_compile_error) + .into() } diff --git a/source/postcard-derive/src/schema.rs b/source/postcard-derive/src/schema.rs index 19e8a6f..db038d3 100644 --- a/source/postcard-derive/src/schema.rs +++ b/source/postcard-derive/src/schema.rs @@ -1,22 +1,22 @@ use proc_macro2::{Span, TokenStream}; use quote::{quote, quote_spanned}; use syn::{ - parse_macro_input, parse_quote, spanned::Spanned, Data, DeriveInput, Fields, GenericParam, - Generics, + parse_quote, punctuated::Punctuated, spanned::Spanned, Attribute, Data, DeriveInput, Error, + Fields, GenericParam, Generics, Token, }; -pub fn do_derive_schema(item: proc_macro::TokenStream) -> proc_macro::TokenStream { - let input = parse_macro_input!(item as DeriveInput); - +pub fn do_derive_schema(input: DeriveInput) -> Result { let span = input.span(); let name = input.ident; // Add a bound `T: Schema` to every type parameter T. - let generics = add_trait_bounds(input.generics); + let generics = match find_bounds(&input.attrs, &input.generics)? { + Some(x) => x, + None => add_trait_bounds(input.generics), + }; let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); - let ty = generate_type(&input.data, span, name.to_string()) - .unwrap_or_else(syn::Error::into_compile_error); + let ty = generate_type(&input.data, span, name.to_string())?; let expanded = quote! { impl #impl_generics ::postcard::experimental::schema::Schema for #name #ty_generics #where_clause { @@ -24,10 +24,10 @@ pub fn do_derive_schema(item: proc_macro::TokenStream) -> proc_macro::TokenStrea } }; - expanded.into() + Ok(expanded) } -fn generate_type(data: &Data, span: Span, name: String) -> Result { +fn generate_type(data: &Data, span: Span, name: String) -> Result { let ty = match data { Data::Struct(data) => generate_struct(&data.fields), Data::Enum(data) => { @@ -41,7 +41,7 @@ fn generate_type(data: &Data, span: Span, name: String) -> Result { - return Err(syn::Error::new( + return Err(Error::new( span, "unions are not supported by `postcard::experimental::schema`", )) @@ -121,3 +121,43 @@ fn add_trait_bounds(mut generics: Generics) -> Generics { } generics } + +fn find_bounds(attrs: &[Attribute], generics: &Generics) -> Result, Error> { + let mut result = None; + for attr in attrs { + if !attr.path.is_ident("postcard") { + continue; + } + let span = attr.span(); + let meta = match attr.parse_meta()? { + syn::Meta::List(x) => x, + _ => return Err(Error::new(span, "expected #[postcard(...)]")), + }; + for meta in meta.nested { + let meta = match meta { + syn::NestedMeta::Meta(x) => x, + _ => return Err(Error::new(span, "expected #[postcard($meta)]")), + }; + let meta = match meta { + syn::Meta::NameValue(x) => x, + _ => return Err(Error::new(span, "expected #[postcard($path = $lit)]")), + }; + if !meta.path.is_ident("bound") { + return Err(Error::new(span, "expected #[postcard(bound = $lit)]")); + } + if result.is_some() { + return Err(Error::new(span, "duplicate #[postcard(bound = \"...\")]")); + } + let bound = match meta.lit { + syn::Lit::Str(x) => x, + _ => return Err(Error::new(span, "expected #[postcard(bound = \"...\")]")), + }; + let bound = + bound.parse_with(Punctuated::::parse_terminated)?; + let mut generics = generics.clone(); + generics.make_where_clause().predicates.extend(bound); + result = Some(generics); + } + } + Ok(result) +} diff --git a/source/postcard/tests/schema.rs b/source/postcard/tests/schema.rs index 244560a..24cd30f 100644 --- a/source/postcard/tests/schema.rs +++ b/source/postcard/tests/schema.rs @@ -51,6 +51,26 @@ struct Slice<'a> { x: &'a [u8], } +#[allow(unused)] +#[derive(Schema)] +#[postcard(bound = "")] // doesn't compile without this +struct Bound { + x: F::Out, +} + +mod bound { + use super::*; + + pub trait Fun { + type Out: Schema; + } + + pub enum Id {} + impl Fun for Id { + type Out = In; + } +} + #[test] fn test_enum_serialize() { assert_eq!( @@ -161,3 +181,17 @@ fn test_slice_serialize() { Slice::SCHEMA ); } + +#[test] +fn test_bound_serialize() { + assert_eq!( + &NamedType { + name: "Bound", + ty: &SdmTy::Struct(&[&NamedValue { + name: "x", + ty: &U8_SCHEMA + }]), + }, + Bound::::SCHEMA, + ); +}