Skip to content

Commit b4f43ce

Browse files
committed
Add support for deriving FromSql and ToSql for structs with references
1 parent 89ea051 commit b4f43ce

File tree

9 files changed

+123
-17
lines changed

9 files changed

+123
-17
lines changed

postgres-derive-test/src/compile-fail/invalid-types.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,9 @@ enum FromSqlEnum {
2222
Foo(i32),
2323
}
2424

25+
#[derive(FromSql)]
26+
struct FromSqlTypeParameter<T> {
27+
foo: T,
28+
}
29+
2530
fn main() {}

postgres-derive-test/src/compile-fail/invalid-types.stderr

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,9 @@ error: non-C-like enums are not supported
3333
|
3434
22 | Foo(i32),
3535
| ^^^^^^^^
36+
37+
error: #[derive(FromSql)] does not support type parameters.
38+
--> $DIR/invalid-types.rs:26:28
39+
|
40+
26 | struct FromSqlTypeParameter<T> {
41+
| ^^^

postgres-derive-test/src/composites.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,3 +215,32 @@ fn wrong_type() {
215215
.unwrap_err();
216216
assert!(err.source().unwrap().is::<WrongType>());
217217
}
218+
219+
#[test]
220+
fn struct_with_references() {
221+
#[derive(FromSql, ToSql, Debug, PartialEq)]
222+
#[postgres(name = "item")]
223+
struct Item<'a, 'b: 'a> {
224+
name: &'a str,
225+
data: &'b [u8],
226+
}
227+
228+
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
229+
conn.batch_execute(
230+
"CREATE TYPE pg_temp.item AS (
231+
name TEXT,
232+
data BYTEA
233+
);",
234+
)
235+
.unwrap();
236+
237+
let item = Item {
238+
name: "foobar",
239+
data: b"12345",
240+
};
241+
242+
let row = conn.query_one("SELECT $1::item", &[&item]).unwrap();
243+
let result: Item<'_, '_> = row.get(0);
244+
assert_eq!(item.name, result.name);
245+
assert_eq!(item.data, result.data);
246+
}

postgres-derive-test/src/domains.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,23 @@ fn domain_in_composite() {
119119
)],
120120
);
121121
}
122+
123+
#[test]
124+
fn struct_with_reference() {
125+
#[derive(FromSql, ToSql, Debug, PartialEq)]
126+
struct SessionId<'b>(&'b [u8]);
127+
128+
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
129+
conn.execute(
130+
"CREATE DOMAIN pg_temp.\"SessionId\" AS bytea CHECK(octet_length(VALUE) = 16);",
131+
&[],
132+
)
133+
.unwrap();
134+
135+
let session_id = b"0123456789abcdef";
136+
let row = conn
137+
.query_one("SELECT $1::\"SessionId\"", &[&SessionId(session_id)])
138+
.unwrap();
139+
let result: SessionId<'_> = row.get(0);
140+
assert_eq!(session_id, result.0);
141+
}

postgres-derive/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,6 @@ proc-macro = true
1212
test = false
1313

1414
[dependencies]
15-
syn = "1.0"
15+
syn = {version = "1.0", features = ["visit-mut"]}
1616
proc-macro2 = "1.0"
1717
quote = "1.0"

postgres-derive/src/accepts.rs

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,21 @@
11
use proc_macro2::{Span, TokenStream};
22
use quote::quote;
33
use std::iter;
4+
use syn::visit_mut::VisitMut;
45
use syn::Ident;
56

67
use crate::composites::Field;
8+
use crate::composites::RenameLifetimes;
79
use crate::enums::Variant;
810

9-
pub fn domain_body(name: &str, field: &syn::Field) -> TokenStream {
10-
let ty = &field.ty;
11-
11+
pub fn domain_body(name: &str, trait_: &str, field: &syn::Field) -> TokenStream {
12+
let ty = if trait_ == "FromSql" {
13+
let mut ty = field.ty.clone();
14+
RenameLifetimes.visit_type_mut(&mut ty);
15+
ty
16+
} else {
17+
field.ty.clone()
18+
};
1219
quote! {
1320
if type_.name() != #name {
1421
return false;
@@ -57,7 +64,15 @@ pub fn composite_body(name: &str, trait_: &str, fields: &[Field]) -> TokenStream
5764
let trait_ = Ident::new(trait_, Span::call_site());
5865
let traits = iter::repeat(&trait_);
5966
let field_names = fields.iter().map(|f| &f.name);
60-
let field_types = fields.iter().map(|f| &f.type_);
67+
let field_types = fields.iter().map(|f| {
68+
if trait_ == "FromSql" {
69+
let mut type_ = f.type_.clone();
70+
RenameLifetimes.visit_type_mut(&mut type_);
71+
type_
72+
} else {
73+
f.type_.clone()
74+
}
75+
});
6176

6277
quote! {
6378
if type_.name() != #name {

postgres-derive/src/composites.rs

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
use syn::{Error, Ident, Type};
1+
use proc_macro2::Span;
2+
use syn::visit_mut::VisitMut;
3+
use syn::{Error, Ident, Lifetime, Type};
24

35
use crate::overrides::Overrides;
46

@@ -20,3 +22,13 @@ impl Field {
2022
})
2123
}
2224
}
25+
26+
pub struct RenameLifetimes;
27+
28+
pub const DEFAULT_LIFETIME: &str = "a";
29+
30+
impl VisitMut for RenameLifetimes {
31+
fn visit_lifetime_mut(&mut self, node: &mut Lifetime) {
32+
node.ident = Ident::new(DEFAULT_LIFETIME, Span::call_site())
33+
}
34+
}

postgres-derive/src/fromsql.rs

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
use proc_macro2::{Span, TokenStream};
22
use quote::quote;
33
use std::iter;
4-
use syn::{Data, DataStruct, DeriveInput, Error, Fields, Ident};
4+
use syn::visit_mut::VisitMut;
5+
use syn::{Data, DataStruct, DeriveInput, Error, Fields, Ident, Lifetime};
56

67
use crate::accepts;
7-
use crate::composites::Field;
8+
use crate::composites::{Field, RenameLifetimes, DEFAULT_LIFETIME};
89
use crate::enums::Variant;
910
use crate::overrides::Overrides;
1011

@@ -58,10 +59,25 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
5859
};
5960

6061
let ident = &input.ident;
62+
let mut generics = input.generics;
63+
if generics.type_params().count() > 0 {
64+
return Err(Error::new_spanned(
65+
&generics,
66+
"#[derive(FromSql)] does not support type parameters.",
67+
));
68+
}
69+
let type_generics = if generics.lifetimes().count() > 0 {
70+
RenameLifetimes.visit_generics_mut(&mut generics);
71+
let (_impl, typ, _where) = generics.split_for_impl();
72+
Some(typ)
73+
} else {
74+
None
75+
};
76+
let lifetime = Lifetime::new(&format!("'{}", DEFAULT_LIFETIME), Span::call_site());
6177
let out = quote! {
62-
impl<'a> postgres_types::FromSql<'a> for #ident {
63-
fn from_sql(_type: &postgres_types::Type, buf: &'a [u8])
64-
-> std::result::Result<#ident,
78+
impl <#lifetime> postgres_types::FromSql<#lifetime> for #ident #type_generics {
79+
fn from_sql(_type: &postgres_types::Type, buf: & #lifetime [u8])
80+
-> std::result::Result<#ident #type_generics,
6581
std::boxed::Box<dyn std::error::Error +
6682
std::marker::Sync +
6783
std::marker::Send>> {
@@ -97,9 +113,9 @@ fn enum_body(ident: &Ident, variants: &[Variant]) -> TokenStream {
97113

98114
// Domains are sometimes but not always just represented by the bare type (!?)
99115
fn domain_accepts_body(name: &str, field: &syn::Field) -> TokenStream {
100-
let ty = &field.ty;
101-
let normal_body = accepts::domain_body(name, field);
102-
116+
let mut ty = field.ty.clone();
117+
RenameLifetimes.visit_type_mut(&mut ty);
118+
let normal_body = accepts::domain_body(name, "FromSql", field);
103119
quote! {
104120
if <#ty as postgres_types::FromSql>::accepts(type_) {
105121
return true;
@@ -110,7 +126,8 @@ fn domain_accepts_body(name: &str, field: &syn::Field) -> TokenStream {
110126
}
111127

112128
fn domain_body(ident: &Ident, field: &syn::Field) -> TokenStream {
113-
let ty = &field.ty;
129+
let mut ty = field.ty.clone();
130+
RenameLifetimes.visit_type_mut(&mut ty);
114131
quote! {
115132
<#ty as postgres_types::FromSql>::from_sql(_type, buf).map(#ident)
116133
}

postgres-derive/src/tosql.rs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result<TokenStream, Error> {
3030
..
3131
}) if fields.unnamed.len() == 1 => {
3232
let field = fields.unnamed.first().unwrap();
33-
(accepts::domain_body(&name, &field), domain_body())
33+
(accepts::domain_body(&name, "ToSql", &field), domain_body())
3434
}
3535
Data::Struct(DataStruct {
3636
fields: Fields::Named(ref fields),
@@ -55,8 +55,10 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result<TokenStream, Error> {
5555
};
5656

5757
let ident = &input.ident;
58+
let generics = &input.generics;
59+
let (impl_generics, type_generics, where_clause) = generics.split_for_impl();
5860
let out = quote! {
59-
impl postgres_types::ToSql for #ident {
61+
impl #impl_generics postgres_types::ToSql for #ident #type_generics #where_clause {
6062
fn to_sql(&self,
6163
_type: &postgres_types::Type,
6264
buf: &mut postgres_types::private::BytesMut)

0 commit comments

Comments
 (0)