Skip to content

Commit 04c2a56

Browse files
committed
Implement FromSql for tuples up to size 4
- sfackler#626 with feature gate additions
1 parent 29a24c7 commit 04c2a56

File tree

4 files changed

+110
-0
lines changed

4 files changed

+110
-0
lines changed

postgres-types/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ categories = ["database"]
1313
[features]
1414
derive = ["postgres-derive"]
1515
array-impls = ["array-init"]
16+
tuple-impls = []
1617
with-bit-vec-0_6 = ["bit-vec-06"]
1718
with-cidr-0_2 = ["cidr-02"]
1819
with-chrono-0_4 = ["chrono-04"]

postgres-types/src/lib.rs

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,62 @@ impl<'a, T: FromSql<'a>, const N: usize> FromSql<'a> for [T; N] {
642642
}
643643
}
644644

645+
macro_rules! impl_from_sql_tuple {
646+
($n:expr; $($ty_ident:ident),*; $($var_ident:ident),*) => {
647+
impl<'a, $($ty_ident),*> FromSql<'a> for ($($ty_ident,)*)
648+
where
649+
$($ty_ident: FromSql<'a>),*
650+
{
651+
fn from_sql(
652+
_: &Type,
653+
mut raw: &'a [u8],
654+
) -> Result<($($ty_ident,)*), Box<dyn Error + Sync + Send>> {
655+
let num_fields = private::read_be_i32(&mut raw)?;
656+
if num_fields as usize != $n {
657+
return Err(format!(
658+
"Postgres record field count does not match Rust tuple length: {} vs {}",
659+
num_fields,
660+
$n,
661+
).into());
662+
}
663+
664+
$(
665+
let oid = private::read_be_i32(&mut raw)? as u32;
666+
let ty = match Type::from_oid(oid) {
667+
None => {
668+
return Err(format!(
669+
"cannot decode OID {} inside of anonymous record",
670+
oid,
671+
).into());
672+
}
673+
Some(ty) if !$ty_ident::accepts(&ty) => {
674+
return Err(Box::new(WrongType::new::<$ty_ident>(ty.clone())));
675+
}
676+
Some(ty) => ty,
677+
};
678+
let $var_ident = private::read_value(&ty, &mut raw)?;
679+
)*
680+
681+
Ok(($($var_ident,)*))
682+
}
683+
684+
fn accepts(ty: &Type) -> bool {
685+
match ty.kind() {
686+
Kind::Pseudo => *ty == Type::RECORD,
687+
Kind::Composite(fields) => fields.len() == $n,
688+
_ => false,
689+
}
690+
}
691+
}
692+
};
693+
}
694+
695+
impl_from_sql_tuple!(0; ; );
696+
impl_from_sql_tuple!(1; T0; v0);
697+
impl_from_sql_tuple!(2; T0, T1; v0, v1);
698+
impl_from_sql_tuple!(3; T0, T1, T2; v0, v1, v2);
699+
impl_from_sql_tuple!(4; T0, T1, T2, T3; v0, v1, v2, v3);
700+
645701
impl<'a, T: FromSql<'a>> FromSql<'a> for Box<[T]> {
646702
fn from_sql(ty: &Type, raw: &'a [u8]) -> Result<Self, Box<dyn Error + Sync + Send>> {
647703
Vec::<T>::from_sql(ty, raw).map(Vec::into_boxed_slice)

tokio-postgres/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ default = ["runtime"]
2828
runtime = ["tokio/net", "tokio/time"]
2929

3030
array-impls = ["postgres-types/array-impls"]
31+
tuple-impls = ["postgres-types/tuple-impls"]
3132
with-bit-vec-0_6 = ["postgres-types/with-bit-vec-0_6"]
3233
with-chrono-0_4 = ["postgres-types/with-chrono-0_4"]
3334
with-eui48-0_4 = ["postgres-types/with-eui48-0_4"]

tokio-postgres/tests/test/types/mod.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -379,6 +379,58 @@ async fn test_array_array_params() {
379379
.await;
380380
}
381381

382+
#[tokio::test]
383+
async fn tuples() {
384+
let client = connect("user=postgres").await;
385+
386+
let row = client.query_one("SELECT ROW()", &[]).await.unwrap();
387+
row.get::<_, ()>(0);
388+
389+
let row = client.query_one("SELECT ROW(1)", &[]).await.unwrap();
390+
let val: (i32,) = row.get(0);
391+
assert_eq!(val, (1,));
392+
393+
let row = client.query_one("SELECT (1, 'a')", &[]).await.unwrap();
394+
let val: (i32, String) = row.get(0);
395+
assert_eq!(val, (1, "a".into()));
396+
397+
let row = client.query_one("SELECT (1, (2, 3))", &[]).await.unwrap();
398+
let val: (i32, (i32, i32)) = row.get(0);
399+
assert_eq!(val, (1, (2, 3)));
400+
401+
let row = client.query_one("SELECT (1, 2)", &[]).await.unwrap();
402+
let err = row.try_get::<_, (i32, String)>(0).unwrap_err();
403+
match err.source() {
404+
Some(e) if e.is::<WrongType>() => {}
405+
_ => panic!("Unexpected error {:?}", err),
406+
};
407+
408+
let row = client.query_one("SELECT (1, 2, 3)", &[]).await.unwrap();
409+
let err = row.try_get::<_, (i32, i32)>(0).unwrap_err();
410+
assert_eq!(
411+
err.to_string(),
412+
"error deserializing column 0: \
413+
Postgres record field count does not match Rust tuple length: 3 vs 2"
414+
);
415+
416+
client
417+
.batch_execute(
418+
"CREATE TYPE pg_temp.simple AS (
419+
a int,
420+
b text
421+
)",
422+
)
423+
.await
424+
.unwrap();
425+
426+
let row = client
427+
.query_one("SELECT (1, 'a')::simple", &[])
428+
.await
429+
.unwrap();
430+
let val: (i32, String) = row.get(0);
431+
assert_eq!(val, (1, "a".into()));
432+
}
433+
382434
#[allow(clippy::eq_op)]
383435
async fn test_nan_param<T>(sql_type: &str)
384436
where

0 commit comments

Comments
 (0)