Skip to content

Commit 9f500ff

Browse files
committed
Add support for deriving FromSql and ToSql for structs with references
1 parent 89ea051 commit 9f500ff

File tree

6 files changed

+91
-4
lines changed

6 files changed

+91
-4
lines changed

postgres-derive-test/src/compile-fail/invalid-types.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,9 @@ enum FromSqlEnum {
2222
Foo(i32),
2323
}
2424

25+
#[derive(FromSql)]
26+
struct FromSqlTypeParameter<T> {
27+
foo: T,
28+
}
29+
2530
fn main() {}

postgres-derive-test/src/compile-fail/invalid-types.stderr

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,3 +33,9 @@ error: non-C-like enums are not supported
3333
|
3434
22 | Foo(i32),
3535
| ^^^^^^^^
36+
37+
error: #[derive(FromSql)] does not support type parameters.
38+
--> $DIR/invalid-types.rs:26:28
39+
|
40+
26 | struct FromSqlTypeParameter<T> {
41+
| ^^^

postgres-derive-test/src/composites.rs

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,3 +215,32 @@ fn wrong_type() {
215215
.unwrap_err();
216216
assert!(err.source().unwrap().is::<WrongType>());
217217
}
218+
219+
#[test]
220+
fn struct_with_references() {
221+
#[derive(FromSql, ToSql, Debug, PartialEq)]
222+
#[postgres(name = "item")]
223+
struct Item<'a, 'b: 'a> {
224+
name: &'a str,
225+
data: &'b [u8],
226+
}
227+
228+
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
229+
conn.batch_execute(
230+
"CREATE TYPE pg_temp.item AS (
231+
name TEXT,
232+
data BYTEA
233+
);",
234+
)
235+
.unwrap();
236+
237+
let item = Item {
238+
name: "foobar",
239+
data: b"12345",
240+
};
241+
242+
let row = conn.query_one("SELECT $1::item", &[&item]).unwrap();
243+
let result: Item<'_, '_> = row.get(0);
244+
assert_eq!(item.name, result.name);
245+
assert_eq!(item.data, result.data);
246+
}

postgres-derive-test/src/domains.rs

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,3 +119,23 @@ fn domain_in_composite() {
119119
)],
120120
);
121121
}
122+
123+
#[test]
124+
fn struct_with_reference() {
125+
#[derive(FromSql, ToSql, Debug, PartialEq)]
126+
struct SessionId<'b>(&'b [u8]);
127+
128+
let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap();
129+
conn.execute(
130+
"CREATE DOMAIN pg_temp.\"SessionId\" AS bytea CHECK(octet_length(VALUE) = 16);",
131+
&[],
132+
)
133+
.unwrap();
134+
135+
let session_id = b"0123456789abcdef";
136+
let row = conn
137+
.query_one("SELECT $1::\"SessionId\"", &[&SessionId(session_id)])
138+
.unwrap();
139+
let result: SessionId<'_> = row.get(0);
140+
assert_eq!(session_id, result.0);
141+
}

postgres-derive/src/fromsql.rs

Lines changed: 28 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ use crate::composites::Field;
88
use crate::enums::Variant;
99
use crate::overrides::Overrides;
1010

11+
const DEFAULT_LIFETIME: &str = "de";
12+
1113
pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
1214
let overrides = Overrides::extract(&input.attrs)?;
1315

@@ -58,10 +60,33 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result<TokenStream, Error> {
5860
};
5961

6062
let ident = &input.ident;
63+
let mut generics = input.generics;
64+
if generics.type_params().count() > 0 {
65+
return Err(Error::new_spanned(
66+
&generics,
67+
"#[derive(FromSql)] does not support type parameters.",
68+
));
69+
}
70+
71+
let generics_clone = &generics.clone();
72+
let (_, type_generics, _) = generics_clone.split_for_impl();
73+
74+
let lifetime = syn::Lifetime::new(&format!("'{}", DEFAULT_LIFETIME), Span::call_site());
75+
let mut lifetime_def = syn::LifetimeDef::new(lifetime.clone());
76+
let lifetimes: Vec<syn::Lifetime> = generics.lifetimes().map(|l| l.lifetime.clone()).collect();
77+
lifetime_def.bounds = syn::punctuated::Punctuated::new();
78+
for l in lifetimes {
79+
lifetime_def.bounds.push(l);
80+
}
81+
generics
82+
.params
83+
.push(syn::GenericParam::Lifetime(lifetime_def));
84+
let (impl_generics, _, _) = generics.split_for_impl();
85+
6186
let out = quote! {
62-
impl<'a> postgres_types::FromSql<'a> for #ident {
63-
fn from_sql(_type: &postgres_types::Type, buf: &'a [u8])
64-
-> std::result::Result<#ident,
87+
impl #impl_generics postgres_types::FromSql<#lifetime> for #ident #type_generics {
88+
fn from_sql(_type: &postgres_types::Type, buf: & #lifetime [u8])
89+
-> std::result::Result<#ident #type_generics,
6590
std::boxed::Box<dyn std::error::Error +
6691
std::marker::Sync +
6792
std::marker::Send>> {

postgres-derive/src/tosql.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,10 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result<TokenStream, Error> {
5555
};
5656

5757
let ident = &input.ident;
58+
let generics = &input.generics;
59+
let (impl_generics, type_generics, where_clause) = generics.split_for_impl();
5860
let out = quote! {
59-
impl postgres_types::ToSql for #ident {
61+
impl #impl_generics postgres_types::ToSql for #ident #type_generics #where_clause {
6062
fn to_sql(&self,
6163
_type: &postgres_types::Type,
6264
buf: &mut postgres_types::private::BytesMut)

0 commit comments

Comments
 (0)