Skip to content

Rename attr #5

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jan 11, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
improvements, restructuring, comments and rename attr
  • Loading branch information
remkop22 committed Jan 11, 2023
commit 463c0b28c14c65ecfbed403f0fc6ff1961a09d2e
182 changes: 131 additions & 51 deletions postgres-from-row-derive/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use darling::ast::{self, Style};
use darling::{FromDeriveInput, FromField, ToTokens};
use darling::{ast::Data, Error, FromDeriveInput, FromField, ToTokens};
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::{parse_macro_input, DeriveInput, Ident};
use syn::{parse_macro_input, DeriveInput, Ident, Result};

#[proc_macro_derive(FromRowTokioPostgres, attributes(from_row))]
pub fn derive_from_row_tokio_postgres(input: TokenStream) -> TokenStream {
Expand All @@ -14,6 +14,7 @@ pub fn derive_from_row_postgres(input: TokenStream) -> TokenStream {
derive_from_row(input, quote::format_ident!("postgres"))
}

/// Calls the fallible entry point and writes any errors to the tokenstream.
fn derive_from_row(input: TokenStream, module: Ident) -> TokenStream {
let derive_input = parse_macro_input!(input as DeriveInput);
match try_derive_from_row(&derive_input, module) {
Expand All @@ -22,11 +23,16 @@ fn derive_from_row(input: TokenStream, module: Ident) -> TokenStream {
}
}

fn try_derive_from_row(input: &DeriveInput, module: Ident) -> Result<TokenStream, darling::Error> {
/// Fallible entry point for generating a `FromRow` implementation
fn try_derive_from_row(
input: &DeriveInput,
module: Ident,
) -> std::result::Result<TokenStream, Error> {
let from_row_derive = DeriveFromRow::from_derive_input(input)?;
from_row_derive.generate(module)
Ok(from_row_derive.generate(module)?)
}

/// Main struct for deriving `FromRow` for a struct.
#[derive(Debug, FromDeriveInput)]
#[darling(
attributes(from_row),
Expand All @@ -36,59 +42,60 @@ fn try_derive_from_row(input: &DeriveInput, module: Ident) -> Result<TokenStream
struct DeriveFromRow {
ident: syn::Ident,
generics: syn::Generics,
data: ast::Data<(), FromRowField>,
data: Data<(), FromRowField>,
}

impl DeriveFromRow {
fn generate(self, module: Ident) -> Result<TokenStream, darling::Error> {
let ident = &self.ident;
/// Validates all fields
fn validate(&self) -> Result<()> {
for field in self.fields() {
field.validate()?;
}

let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
Ok(())
}

let fields = self
.data
.take_struct()
.ok_or_else(|| darling::Error::unsupported_shape("enum").with_span(&self.ident))?;
/// Generates any additional where clause predicates needed for the fields in this struct.
fn predicates(&self, module: &Ident) -> Result<Vec<TokenStream2>> {
let mut predicates = Vec::new();

let fields = match fields.style {
Style::Unit => {
return Err(darling::Error::unsupported_shape("unit struct").with_span(&self.ident))
}
Style::Tuple => {
return Err(darling::Error::unsupported_shape("tuple struct").with_span(&self.ident))
}
Style::Struct => fields.fields,
};
for field in self.fields() {
field.add_predicates(&module, &mut predicates)?;
}

let from_row_fields = fields
Ok(predicates)
}

/// Provides a slice of this struct's fields.
fn fields(&self) -> &[FromRowField] {
match &self.data {
Data::Struct(fields) => &fields.fields,
_ => panic!("invalid shape"),
}
}

/// Generate the `FromRow` implementation.
fn generate(self, module: Ident) -> Result<TokenStream> {
self.validate()?;

let ident = &self.ident;

let (impl_generics, ty_generics, where_clause) = self.generics.split_for_impl();
let original_predicates = where_clause.clone().map(|w| &w.predicates).into_iter();
let predicates = self.predicates(&module)?;

let from_row_fields = self
.fields()
.iter()
.map(|f| f.generate_from_row(&module))
.collect::<syn::Result<Vec<_>>>()?;

let try_from_row_fields = fields
let try_from_row_fields = self
.fields()
.iter()
.map(|f| f.generate_try_from_row(&module))
.collect::<syn::Result<Vec<_>>>()?;

let original_predicates = where_clause.clone().map(|w| &w.predicates).into_iter();
let mut predicates = Vec::new();

for field in fields.iter() {
let target_ty = &field.target_ty()?;
let ty = &field.ty;
predicates.push(if field.flatten {
quote! (#target_ty: postgres_from_row::FromRow)
} else {
quote! (#target_ty: for<'a> #module::types::FromSql<'a>)
});

if field.from.is_some() {
predicates.push(quote!(#ty: std::convert::From<#target_ty>))
} else if field.try_from.is_some() {
predicates.push(quote!(#ty: std::convert::From<#target_ty>))
}
}

Ok(quote! {
impl #impl_generics postgres_from_row::FromRow for #ident #ty_generics where #(#original_predicates),* #(#predicates),* {

Expand All @@ -109,19 +116,52 @@ impl DeriveFromRow {
}
}

/// A single field inside of a struct that derives `FromRow`
#[derive(Debug, FromField)]
#[darling(attributes(from_row), forward_attrs(allow, doc, cfg))]
struct FromRowField {
/// The identifier of this field.
ident: Option<syn::Ident>,
/// The type specified in this field.
ty: syn::Type,
/// Wether to flatten this field. Flattening means calling the `FromRow` implementation
/// of `self.ty` instead of extracting it directly from the row.
#[darling(default)]
flatten: bool,
/// Optionaly use this type as the target for `FromRow` or `FromSql`, and then
/// call `TryFrom::try_from` to convert it the `self.ty`.
try_from: Option<String>,
/// Optionaly use this type as the target for `FromRow` or `FromSql`, and then
/// call `From::from` to convert it the `self.ty`.
from: Option<String>,
/// Override the name of the actual sql column instead of using `self.ident`.
/// Is not compatible with `flatten` since no column is needed there.
rename: Option<String>,
}

impl FromRowField {
fn target_ty(&self) -> syn::Result<proc_macro2::TokenStream> {
/// Checks wether this field has a valid combination of attributes
fn validate(&self) -> Result<()> {
if self.from.is_some() && self.try_from.is_some() {
return Err(Error::custom(
r#"can't combine `#[from_row(from = "..")]` with `#[from_row(try_from = "..")]`"#,
)
.into());
}

if self.rename.is_some() && self.flatten {
return Err(Error::custom(
r#"can't combine `#[from_row(flatten)]` with `#[from_row(rename = "..")]`"#,
)
.into());
}

Ok(())
}

/// Returns a tokenstream of the type that should be returned from either
/// `FromRow` (when using `flatten`) or `FromSql`.
fn target_ty(&self) -> Result<TokenStream2> {
if let Some(from) = &self.from {
Ok(from.parse()?)
} else if let Some(try_from) = &self.try_from {
Expand All @@ -131,17 +171,56 @@ impl FromRowField {
}
}

fn generate_from_row(&self, module: &Ident) -> syn::Result<proc_macro2::TokenStream> {
/// Returns the name that maps to the actuall sql column
/// By default this is the same as the rust field name but can be overwritten by `#[from_row(rename = "..")]`.
fn column_name(&self) -> String {
self.rename
.as_ref()
.map(Clone::clone)
.unwrap_or_else(|| self.ident.as_ref().unwrap().to_string())
}

/// Pushes the needed where clause predicates for this field.
///
/// By default this is `T: for<'a> postgres::types::FromSql<'a>`,
/// when using `flatten` it's: `T: postgres_from_row::FromRow`
/// and when using either `from` or `try_from` attributes it additionally pushes this bound:
/// `T: std::convert::From<R>`, where `T` is the type specified in the struct and `R` is the
/// type specified in the `[try]_from` attribute.
fn add_predicates(&self, module: &Ident, predicates: &mut Vec<TokenStream2>) -> Result<()> {
let target_ty = &self.target_ty()?;
let ty = &self.ty;

predicates.push(if self.flatten {
quote! (#target_ty: postgres_from_row::FromRow)
} else {
quote! (#target_ty: for<'a> #module::types::FromSql<'a>)
});

if self.from.is_some() {
predicates.push(quote!(#ty: std::convert::From<#target_ty>))
} else if self.try_from.is_some() {
let try_from = quote!(std::convert::TryFrom<#target_ty>);

predicates.push(quote!(#ty: #try_from));
predicates.push(quote!(#module::Error: std::convert::From<<#ty as #try_from>::Error>));
predicates.push(quote!(<#ty as #try_from>::Error: std::fmt::Debug));
}

Ok(())
}

/// Generate the line needed to retrievee this field from a row when calling `from_row`.
fn generate_from_row(&self, module: &Ident) -> Result<TokenStream2> {
let ident = self.ident.as_ref().unwrap();
let str_ident = ident.to_string();
let column_name = self.column_name();
let field_ty = &self.ty;

let target_ty = self.target_ty()?;

let mut base = if self.flatten {
quote!(<#target_ty as postgres_from_row::FromRow>::from_row(row))
} else {
quote!(#module::Row::get::<&str, #target_ty>(row, #str_ident))
quote!(#module::Row::get::<&str, #target_ty>(row, #column_name))
};

if self.from.is_some() {
Expand All @@ -153,16 +232,17 @@ impl FromRowField {
Ok(quote!(#ident: #base))
}

fn generate_try_from_row(&self, module: &Ident) -> syn::Result<proc_macro2::TokenStream> {
/// Generate the line needed to retrieve this field from a row when calling `try_from_row`.
fn generate_try_from_row(&self, module: &Ident) -> Result<TokenStream2> {
let ident = self.ident.as_ref().unwrap();
let str_ident = ident.to_string();
let column_name = self.column_name();
let field_ty = &self.ty;
let target_ty = self.target_ty()?;

let mut base = if self.flatten {
quote!(<#target_ty as postgres_from_row::FromRow>::try_from_row(row)?)
} else {
quote!(#module::Row::try_get::<&str, #target_ty>(row, #str_ident)?)
quote!(#module::Row::try_get::<&str, #target_ty>(row, #column_name)?)
};

if self.from.is_some() {
Expand Down