Skip to content

Add a Row-like interface to any PostgreSQL composite type #565

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
76 changes: 76 additions & 0 deletions postgres-protocol/src/types/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1045,3 +1045,79 @@ impl Inet {
self.netmask
}
}

/// A fallible iterator over the fields of a composite type.
pub struct CompositeTypeRanges<'a> {
buf: &'a [u8],
len: usize,
remaining: u16,
}

impl<'a> CompositeTypeRanges<'a> {
/// Returns a fallible iterator over the fields of the composite type.
#[inline]
pub fn new(buf: &'a [u8], len: usize, remaining: u16) -> CompositeTypeRanges<'a> {
CompositeTypeRanges {
buf,
len,
remaining,
}
}
}

impl<'a> FallibleIterator for CompositeTypeRanges<'a> {
type Item = Option<std::ops::Range<usize>>;
type Error = std::io::Error;

#[inline]
fn next(&mut self) -> std::io::Result<Option<Option<std::ops::Range<usize>>>> {
if self.remaining == 0 {
if self.buf.is_empty() {
return Ok(None);
} else {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidInput,
"invalid buffer length: compositetyperanges is not empty",
));
}
}

self.remaining -= 1;

// Binary format of a composite type:
// [for each field]
// <OID of field's type: 4 bytes>
// [if value is NULL]
// <-1: 4 bytes>
// [else]
// <length of value: 4 bytes>
// <value: <length> bytes>
// [end if]
// [end for]
// https://www.postgresql.org/message-id/16CCB2D3-197E-4D9F-BC6F-9B123EA0D40D%40phlo.org
// https://github.com/postgres/postgres/blob/29e321cdd63ea48fd0223447d58f4742ad729eb0/src/backend/utils/adt/rowtypes.c#L736

let _oid = self.buf.read_i32::<BigEndian>()?;
let len = self.buf.read_i32::<BigEndian>()?;
if len < 0 {
Ok(Some(None))
} else {
let len = len as usize;
if self.buf.len() < len {
return Err(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
"unexpected EOF",
));
}
let base = self.len - self.buf.len();
self.buf = &self.buf[len as usize..];
Ok(Some(Some(base..base + len)))
}
}

#[inline]
fn size_hint(&self) -> (usize, Option<usize>) {
let len = self.remaining as usize;
(len, Some(len))
}
}
2 changes: 1 addition & 1 deletion tokio-postgres/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ pub use crate::error::Error;
pub use crate::generic_client::GenericClient;
pub use crate::portal::Portal;
pub use crate::query::RowStream;
pub use crate::row::{Row, SimpleQueryRow};
pub use crate::row::{CompositeType, Row, SimpleQueryRow};
pub use crate::simple_query::SimpleQueryStream;
#[cfg(feature = "runtime")]
pub use crate::socket::Socket;
Expand Down
129 changes: 128 additions & 1 deletion tokio-postgres/src/row.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

use crate::row::sealed::{AsName, Sealed};
use crate::statement::Column;
use crate::types::{FromSql, Type, WrongType};
use crate::types::{Field, FromSql, Kind, Type, WrongType};
use crate::{Error, Statement};
use byteorder::{BigEndian, ByteOrder};
use fallible_iterator::FallibleIterator;
use postgres_protocol::message::backend::DataRowBody;
use postgres_protocol::types::CompositeTypeRanges;
use std::fmt;
use std::ops::Range;
use std::str;
Expand All @@ -31,6 +33,12 @@ impl AsName for String {
}
}

impl AsName for Field {
fn as_name(&self) -> &str {
self.name()
}
}

/// A trait implemented by types that can index into columns of a row.
///
/// This cannot be implemented outside of this crate.
Expand Down Expand Up @@ -175,6 +183,125 @@ impl Row {
}
}

/// A PostgreSQL composite type.
/// Fields of a type can be accessed using `CompositeType::get` and `CompositeType::try_get` methods.
pub struct CompositeType<'a> {
type_: Type,
body: &'a [u8],
ranges: Vec<Option<Range<usize>>>,
}

impl<'a> FromSql<'a> for CompositeType<'a> {
fn from_sql(
type_: &Type,
body: &'a [u8],
) -> Result<CompositeType<'a>, Box<dyn std::error::Error + Sync + Send>> {
match *type_.kind() {
Kind::Composite(_) => {
let fields: &[Field] = composite_type_fields(&type_);
if body.len() < 4 {
let message = format!("invalid composite type body length: {}", body.len());
return Err(message.into());
}
let num_fields: i32 = BigEndian::read_i32(&body[0..4]);
if num_fields as usize != fields.len() {
let message =
format!("invalid field count: {} vs {}", num_fields, fields.len());
return Err(message.into());
}
let ranges = CompositeTypeRanges::new(&body[4..], body.len(), num_fields as u16)
.collect()
.map_err(Error::parse)?;
Ok(CompositeType {
type_: type_.clone(),
body,
ranges,
})
}
_ => Err(format!("expected composite type, got {}", type_).into()),
}
}
fn accepts(ty: &Type) -> bool {
match *ty.kind() {
Kind::Composite(_) => true,
_ => false,
}
}
}

fn composite_type_fields(type_: &Type) -> &[Field] {
match type_.kind() {
Kind::Composite(ref fields) => fields,
_ => unreachable!(),
}
}

impl<'a> CompositeType<'a> {
/// Returns information about the fields of the composite type.
pub fn fields(&self) -> &[Field] {
composite_type_fields(&self.type_)
}

/// Determines if the composite contains no values.
pub fn is_empty(&self) -> bool {
self.len() == 0
}

/// Returns the number of fields of the composite type.
pub fn len(&self) -> usize {
self.fields().len()
}

/// Deserializes a value from the composite type.
///
/// The value can be specified either by its numeric index, or by its field name.
///
/// # Panics
///
/// Panics if the index is out of bounds or if the value cannot be converted to the specified type.
pub fn get<'b, I, T>(&'b self, idx: I) -> T
where
I: RowIndex + fmt::Display,
T: FromSql<'b>,
{
match self.get_inner(&idx) {
Ok(ok) => ok,
Err(err) => panic!("error retrieving column {}: {}", idx, err),
}
}

/// Like `CompositeType::get`, but returns a `Result` rather than panicking.
pub fn try_get<'b, I, T>(&'b self, idx: I) -> Result<T, Error>
where
I: RowIndex + fmt::Display,
T: FromSql<'b>,
{
self.get_inner(&idx)
}

fn get_inner<'b, I, T>(&'b self, idx: &I) -> Result<T, Error>
where
I: RowIndex + fmt::Display,
T: FromSql<'b>,
{
let idx = match idx.__idx(self.fields()) {
Some(idx) => idx,
None => return Err(Error::column(idx.to_string())),
};

let ty = self.fields()[idx].type_();
if !T::accepts(ty) {
return Err(Error::from_sql(
Box::new(WrongType::new::<T>(ty.clone())),
idx,
));
}

let buf = self.ranges[idx].clone().map(|r| &self.body[r]);
FromSql::from_sql_nullable(ty, buf).map_err(|e| Error::from_sql(e, idx))
}
}

/// A row of data returned from the database by a simple query.
pub struct SimpleQueryRow {
columns: Arc<[String]>,
Expand Down
55 changes: 54 additions & 1 deletion tokio-postgres/tests/test/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ use tokio_postgres::error::SqlState;
use tokio_postgres::tls::{NoTls, NoTlsStream};
use tokio_postgres::types::{Kind, Type};
use tokio_postgres::{
AsyncMessage, Client, Config, Connection, Error, IsolationLevel, SimpleQueryMessage,
AsyncMessage, Client, CompositeType, Config, Connection, Error, IsolationLevel,
SimpleQueryMessage,
};

mod binary_copy;
Expand Down Expand Up @@ -762,3 +763,55 @@ async fn query_opt() {
.err()
.unwrap();
}

#[tokio::test]
async fn composite_type() {
let client = connect("user=postgres").await;

client
.batch_execute(
"
CREATE TYPE pg_temp.message AS (
id INTEGER,
content TEXT,
link TEXT
);
CREATE TYPE pg_temp.person AS (
id INTEGER,
name TEXT,
messages message[],
email TEXT
);

",
)
.await
.unwrap();

let row = client
.query_one(
"select (123,'alice',ARRAY[(1,'message1',NULL)::message,(2,'message2',NULL)::message],NULL)::person",
&[],
)
.await
.unwrap();

let person: CompositeType<'_> = row.get(0);

assert_eq!(person.get::<_, Option<i32>>("id"), Some(123));
assert_eq!(person.get::<_, Option<&str>>("name"), Some("alice"));
assert_eq!(person.get::<_, Option<&str>>("email"), None);

let messages: Vec<CompositeType<'_>> = person.get("messages");

assert_eq!(messages.len(), 2);
assert_eq!(messages[0].get::<_, Option<i32>>("id"), Some(1));
assert_eq!(
messages[0].get::<_, Option<&str>>("content"),
Some("message1")
);
assert_eq!(messages[0].get::<_, Option<&str>>("link"), None);
assert_eq!(messages[1].get::<_, Option<i32>>(0), Some(2));
assert_eq!(messages[1].get::<_, Option<&str>>(1), Some("message2"));
assert_eq!(messages[1].get::<_, Option<&str>>(2), None);
}