Skip to content

Commit c0fa7f4

Browse files
committed
Implement FromSql for tuples up to length 4
This makes it very ergonomic to decode the results of a query like SELECT (1, 'a') where (1, 'a') is returned as an anonymous record type. The big downside to this approach is that only built-in OIDs are supported, as there is no way to know ahead of time what OIDs will be returned, and so we'll only have metadata for the built-in OIDs lying around.
1 parent 598fc0f commit c0fa7f4

File tree

2 files changed

+108
-0
lines changed
  • postgres-types/src
  • tokio-postgres/tests/test/types

2 files changed

+108
-0
lines changed

postgres-types/src/lib.rs

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -617,6 +617,62 @@ impl<'a> FromSql<'a> for IpAddr {
617617
accepts!(INET);
618618
}
619619

620+
macro_rules! impl_from_sql_tuple {
621+
($n:expr; $($ty_ident:ident),*; $($var_ident:ident),*) => {
622+
impl<'a, $($ty_ident),*> FromSql<'a> for ($($ty_ident,)*)
623+
where
624+
$($ty_ident: FromSql<'a>),*
625+
{
626+
fn from_sql(
627+
_: &Type,
628+
mut raw: &'a [u8],
629+
) -> Result<($($ty_ident,)*), Box<dyn Error + Sync + Send>> {
630+
let num_fields = private::read_be_i32(&mut raw)?;
631+
if num_fields as usize != $n {
632+
return Err(format!(
633+
"Postgres record field count does not match Rust tuple length: {} vs {}",
634+
num_fields,
635+
$n,
636+
).into());
637+
}
638+
639+
$(
640+
let oid = private::read_be_i32(&mut raw)? as u32;
641+
let ty = match Type::from_oid(oid) {
642+
None => {
643+
return Err(format!(
644+
"cannot decode OID {} inside of anonymous record",
645+
oid,
646+
).into());
647+
}
648+
Some(ty) if !$ty_ident::accepts(&ty) => {
649+
return Err(Box::new(WrongType::new::<$ty_ident>(ty.clone())));
650+
}
651+
Some(ty) => ty,
652+
};
653+
let $var_ident = private::read_value(&ty, &mut raw)?;
654+
)*
655+
656+
Ok(($($var_ident,)*))
657+
}
658+
659+
fn accepts(ty: &Type) -> bool {
660+
match ty.kind() {
661+
Kind::Pseudo => *ty == Type::RECORD,
662+
Kind::Composite(fields) => fields.len() == $n,
663+
_ => false,
664+
}
665+
}
666+
}
667+
};
668+
}
669+
670+
impl_from_sql_tuple!(0; ; );
671+
impl_from_sql_tuple!(1; T0; v0);
672+
impl_from_sql_tuple!(2; T0, T1; v0, v1);
673+
impl_from_sql_tuple!(3; T0, T1, T2; v0, v1, v2);
674+
impl_from_sql_tuple!(4; T0, T1, T2, T3; v0, v1, v2, v3);
675+
620676
/// An enum representing the nullability of a Postgres value.
621677
pub enum IsNull {
622678
/// The value is NULL.

tokio-postgres/tests/test/types/mod.rs

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -547,6 +547,58 @@ async fn composite() {
547547
}
548548
}
549549

550+
#[tokio::test]
551+
async fn tuples() {
552+
let client = connect("user=postgres").await;
553+
554+
let row = client.query_one("SELECT ROW()", &[]).await.unwrap();
555+
let _: () = row.get(0);
556+
557+
let row = client.query_one("SELECT ROW(1)", &[]).await.unwrap();
558+
let val: (i32,) = row.get(0);
559+
assert_eq!(val, (1,));
560+
561+
let row = client.query_one("SELECT (1, 'a')", &[]).await.unwrap();
562+
let val: (i32, String) = row.get(0);
563+
assert_eq!(val, (1, "a".into()));
564+
565+
let row = client.query_one("SELECT (1, (2, 3))", &[]).await.unwrap();
566+
let val: (i32, (i32, i32)) = row.get(0);
567+
assert_eq!(val, (1, (2, 3)));
568+
569+
let row = client.query_one("SELECT (1, 2)", &[]).await.unwrap();
570+
let err = row.try_get::<_, (i32, String)>(0).unwrap_err();
571+
match err.source() {
572+
Some(e) if e.is::<WrongType>() => {}
573+
_ => panic!("Unexpected error {:?}", err),
574+
};
575+
576+
let row = client.query_one("SELECT (1, 2, 3)", &[]).await.unwrap();
577+
let err = row.try_get::<_, (i32, i32)>(0).unwrap_err();
578+
assert_eq!(
579+
err.to_string(),
580+
"error deserializing column 0: \
581+
Postgres record field count does not match Rust tuple length: 3 vs 2"
582+
);
583+
584+
client
585+
.batch_execute(
586+
"CREATE TYPE pg_temp.simple AS (
587+
a int,
588+
b text
589+
)",
590+
)
591+
.await
592+
.unwrap();
593+
594+
let row = client
595+
.query_one("SELECT (1, 'a')::simple", &[])
596+
.await
597+
.unwrap();
598+
let val: (i32, String) = row.get(0);
599+
assert_eq!(val, (1, "a".into()));
600+
}
601+
550602
#[tokio::test]
551603
async fn enum_() {
552604
let client = connect("user=postgres").await;

0 commit comments

Comments
 (0)