Skip to content

Commit

Permalink
Fix default collation deserialization
Browse files Browse the repository at this point in the history
We've run into an issue where the database default collation doesn't
necessarily exist as a "real" collation in pg_collation. This code
specializes the deserialization code to handle this: if we don't find
the collation we're trying to deserialize, check if the name matches
that of the default collation and if so assume it's the default
collation.
  • Loading branch information
JLockerman committed Jun 22, 2022
1 parent 5beb961 commit 8a4ec73
Showing 1 changed file with 27 additions and 17 deletions.
44 changes: 27 additions & 17 deletions extension/src/serialization/collations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ use flat_serialize::{impl_flat_serializable, FlatSerializable, WrapErr};

use serde::{Deserialize, Serialize};

use once_cell::sync::Lazy;

use pg_sys::{Datum, Oid};
use pgx::*;

Expand Down Expand Up @@ -74,7 +76,21 @@ struct FormData_pg_database {
#[allow(non_camel_case_types)]
type Form_pg_database = *mut FormData_pg_database;

static DEFAULT_COLLATION_NAME: Lazy<CString> = Lazy::new(|| unsafe {
let tuple = pg_sys::SearchSysCache1(pg_sys::SysCacheIdentifier_DATABASEOID as _, pg_sys::MyDatabaseId as _);
if tuple.is_null() {
pgx::error!("no database info");
}

let database_tuple: Form_pg_database = get_struct(tuple);
let collation_name = (*database_tuple).datcollate.data.as_ptr();
let collation_name_len = CStr::from_ptr(collation_name).to_bytes().len();
let collation_name = pg_sys::pg_server_to_any(collation_name, collation_name_len as _, pg_sys::pg_enc_PG_UTF8 as _);
let collation_name = CStr::from_ptr(collation_name);
let collation_name = collation_name.clone();
pg_sys::ReleaseSysCache(tuple);
CString::from(collation_name)
});

impl Serialize for PgCollationId {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
Expand Down Expand Up @@ -107,32 +123,21 @@ impl Serialize for PgCollationId {

// the 'default' collation isn't really a collation, and we need to
// look in pg_database to discover what real collation is
let mut db_tuple = None;
let collation_name =
if self.0 == DEFAULT_COLLATION_OID {
let tuple = pg_sys::SearchSysCache1(pg_sys::SysCacheIdentifier_DATABASEOID as _, pg_sys::MyDatabaseId as _);
if tuple.is_null() {
pgx::error!("no database info");
}
db_tuple = Some(tuple);

let database_tuple: Form_pg_database = get_struct(tuple);
(*database_tuple).datcollate.data.as_ptr()
&*DEFAULT_COLLATION_NAME
} else {
(*collation_tuple).collname.data.as_ptr()
let collation_name = (*collation_tuple).collname.data.as_ptr();
let collation_name_len = CStr::from_ptr(collation_name).to_bytes().len();
let collation_name = pg_sys::pg_server_to_any(collation_name, collation_name_len as _, pg_sys::pg_enc_PG_UTF8 as _);
CStr::from_ptr(collation_name)
};

let collation_name_len = CStr::from_ptr(collation_name).to_bytes().len();
let collation_name = pg_sys::pg_server_to_any(collation_name, collation_name_len as _, pg_sys::pg_enc_PG_UTF8 as _);
let collation_name = CStr::from_ptr(collation_name);
let collation_name = collation_name.to_str().unwrap();

let qualified_name: (&str, &str) = (namespace, collation_name);
layout = Some(qualified_name);
let res = layout.serialize(serializer);
if let Some(db_tuple) = db_tuple {
pg_sys::ReleaseSysCache(db_tuple);
}

pg_sys::ReleaseSysCache(tuple);
res
}
Expand Down Expand Up @@ -201,6 +206,11 @@ impl<'de> Deserialize<'de> for PgCollationId {
}

if collation_id == pg_sys::InvalidOid {
// The default collation doesn't necessarily exist in the
// collations catalog, so check that specially
if name == &**DEFAULT_COLLATION_NAME {
return Ok(PgCollationId(100))
}
return Err(D::Error::custom(format!(
"invalid collation {:?}.{:?}",
namespace, name
Expand Down

0 comments on commit 8a4ec73

Please sign in to comment.