Skip to content

Commit

Permalink
Add #[postcard(bound = "...")] attribute for derive(Schema)
Browse files Browse the repository at this point in the history
  • Loading branch information
ia0 committed May 14, 2024
1 parent 93d4ba2 commit c0c1ded
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 13 deletions.
9 changes: 7 additions & 2 deletions source/postcard-derive/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
use syn::{parse_macro_input, DeriveInput};

mod max_size;
mod schema;

Expand All @@ -8,7 +10,10 @@ pub fn derive_max_size(item: proc_macro::TokenStream) -> proc_macro::TokenStream
}

/// Derive the `postcard::Schema` trait for a struct or enum.
#[proc_macro_derive(Schema)]
#[proc_macro_derive(Schema, attributes(postcard))]
pub fn derive_schema(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
schema::do_derive_schema(item)
let input = parse_macro_input!(item as DeriveInput);
schema::do_derive_schema(input)
.unwrap_or_else(syn::Error::into_compile_error)
.into()
}
62 changes: 51 additions & 11 deletions source/postcard-derive/src/schema.rs
Original file line number Diff line number Diff line change
@@ -1,33 +1,33 @@
use proc_macro2::{Span, TokenStream};
use quote::{quote, quote_spanned};
use syn::{
parse_macro_input, parse_quote, spanned::Spanned, Data, DeriveInput, Fields, GenericParam,
Generics,
parse_quote, punctuated::Punctuated, spanned::Spanned, Attribute, Data, DeriveInput, Error,
Fields, GenericParam, Generics, Token,
};

pub fn do_derive_schema(item: proc_macro::TokenStream) -> proc_macro::TokenStream {
let input = parse_macro_input!(item as DeriveInput);

pub fn do_derive_schema(input: DeriveInput) -> Result<TokenStream, Error> {
let span = input.span();
let name = input.ident;

// Add a bound `T: Schema` to every type parameter T.
let generics = add_trait_bounds(input.generics);
let generics = match find_bounds(&input.attrs, &input.generics)? {
Some(x) => x,
None => add_trait_bounds(input.generics),
};
let (impl_generics, ty_generics, where_clause) = generics.split_for_impl();

let ty = generate_type(&input.data, span, name.to_string())
.unwrap_or_else(syn::Error::into_compile_error);
let ty = generate_type(&input.data, span, name.to_string())?;

let expanded = quote! {
impl #impl_generics ::postcard::experimental::schema::Schema for #name #ty_generics #where_clause {
const SCHEMA: &'static ::postcard::experimental::schema::NamedType = #ty;
}
};

expanded.into()
Ok(expanded)
}

fn generate_type(data: &Data, span: Span, name: String) -> Result<TokenStream, syn::Error> {
fn generate_type(data: &Data, span: Span, name: String) -> Result<TokenStream, Error> {
let ty = match data {
Data::Struct(data) => generate_struct(&data.fields),
Data::Enum(data) => {
Expand All @@ -41,7 +41,7 @@ fn generate_type(data: &Data, span: Span, name: String) -> Result<TokenStream, s
}
}
Data::Union(_) => {
return Err(syn::Error::new(
return Err(Error::new(
span,
"unions are not supported by `postcard::experimental::schema`",
))
Expand Down Expand Up @@ -121,3 +121,43 @@ fn add_trait_bounds(mut generics: Generics) -> Generics {
}
generics
}

fn find_bounds(attrs: &[Attribute], generics: &Generics) -> Result<Option<Generics>, Error> {
let mut result = None;
for attr in attrs {
if !attr.path.is_ident("postcard") {
continue;
}
let span = attr.span();
let meta = match attr.parse_meta()? {
syn::Meta::List(x) => x,
_ => return Err(Error::new(span, "expected #[postcard(...)]")),
};
for meta in meta.nested {
let meta = match meta {
syn::NestedMeta::Meta(x) => x,
_ => return Err(Error::new(span, "expected #[postcard($meta)]")),
};
let meta = match meta {
syn::Meta::NameValue(x) => x,
_ => return Err(Error::new(span, "expected #[postcard($path = $lit)]")),
};
if !meta.path.is_ident("bound") {
return Err(Error::new(span, "expected #[postcard(bound = $lit)]"));
}
if result.is_some() {
return Err(Error::new(span, "duplicate #[postcard(bound = \"...\")]"));
}
let bound = match meta.lit {
syn::Lit::Str(x) => x,
_ => return Err(Error::new(span, "expected #[postcard(bound = \"...\")]")),
};
let bound =
bound.parse_with(Punctuated::<syn::WherePredicate, Token![,]>::parse_terminated)?;
let mut generics = generics.clone();
generics.make_where_clause().predicates.extend(bound);
result = Some(generics);
}
}
Ok(result)
}
34 changes: 34 additions & 0 deletions source/postcard/tests/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,26 @@ struct Slice<'a> {
x: &'a [u8],
}

#[allow(unused)]
#[derive(Schema)]
#[postcard(bound = "")] // doesn't compile without this
struct Bound<F: bound::Fun> {
x: F::Out<u8>,
}

mod bound {
use super::*;

pub trait Fun {
type Out<In: Schema>: Schema;
}

pub enum Id {}
impl Fun for Id {
type Out<In: Schema> = In;
}
}

#[test]
fn test_enum_serialize() {
assert_eq!(
Expand Down Expand Up @@ -161,3 +181,17 @@ fn test_slice_serialize() {
Slice::SCHEMA
);
}

#[test]
fn test_bound_serialize() {
assert_eq!(
&NamedType {
name: "Bound",
ty: &SdmTy::Struct(&[&NamedValue {
name: "x",
ty: &U8_SCHEMA
}]),
},
Bound::<bound::Id>::SCHEMA,
);
}

0 comments on commit c0c1ded

Please sign in to comment.