Skip to content

refactor: extract a new crate arrow-pg #89

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 15, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 8 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[workspace]
resolver = "2"
members = ["datafusion-postgres", "datafusion-postgres-cli"]
members = ["datafusion-postgres", "datafusion-postgres-cli", "arrow-pg"]

[workspace.package]
version = "0.5.1"
Expand All @@ -14,8 +14,14 @@ repository = "https://github.com/datafusion-contrib/datafusion-postgres/"
documentation = "https://docs.rs/crate/datafusion-postgres/"

[workspace.dependencies]
pgwire = "0.30.2"
arrow = "55"
bytes = "1.10.1"
chrono = { version = "0.4", features = ["std"] }
datafusion = { version = "47", default-features = false }
futures = "0.3"
pgwire = "0.30.2"
postgres-types = "0.2"
rust_decimal = { version = "1.37", features = ["db-postgres"] }
tokio = { version = "1", default-features = false }

[profile.release]
Expand Down
21 changes: 21 additions & 0 deletions arrow-pg/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# Crate metadata: most fields are inherited from the workspace so all member
# crates stay in sync; only name/description/version are crate-specific.
[package]
name = "arrow-pg"
description = "Arrow data mapping and encoding/decoding for Postgres"
version = "0.0.1"
edition.workspace = true
license.workspace = true
authors.workspace = true
keywords.workspace = true
homepage.workspace = true
repository.workspace = true
documentation.workspace = true
readme = "../README.md"

# All versions are pinned centrally in [workspace.dependencies].
[dependencies]
arrow.workspace = true
bytes.workspace = true
chrono.workspace = true
futures.workspace = true
pgwire.workspace = true
postgres-types.workspace = true
rust_decimal.workspace = true
129 changes: 129 additions & 0 deletions arrow-pg/src/datatypes.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
use std::sync::Arc;

use arrow::datatypes::*;
use arrow::record_batch::RecordBatch;
use pgwire::api::portal::Format;
use pgwire::api::results::FieldInfo;
use pgwire::api::Type;
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use pgwire::messages::data::DataRow;
use postgres_types::Kind;

use crate::row_encoder::RowEncoder;

/// Map an Arrow [`DataType`] to the corresponding Postgres [`Type`].
///
/// Nested types are resolved recursively: dictionaries map to their value
/// type, list-like types map to the matching Postgres array type, and
/// structs become anonymous composite (RECORD) types.
///
/// # Errors
///
/// Returns a `PgWireError::UserError` with SQLSTATE `XX000` for Arrow types
/// that have no Postgres equivalent.
pub fn into_pg_type(arrow_type: &DataType) -> PgWireResult<Type> {
    Ok(match arrow_type {
        DataType::Null => Type::UNKNOWN,
        DataType::Boolean => Type::BOOL,
        DataType::Int8 | DataType::UInt8 => Type::CHAR,
        DataType::Int16 | DataType::UInt16 => Type::INT2,
        DataType::Int32 | DataType::UInt32 => Type::INT4,
        DataType::Int64 | DataType::UInt64 => Type::INT8,
        DataType::Timestamp(_, tz) => {
            if tz.is_some() {
                Type::TIMESTAMPTZ
            } else {
                Type::TIMESTAMP
            }
        }
        DataType::Time32(_) | DataType::Time64(_) => Type::TIME,
        DataType::Date32 | DataType::Date64 => Type::DATE,
        DataType::Interval(_) => Type::INTERVAL,
        DataType::Binary | DataType::FixedSizeBinary(_) | DataType::LargeBinary => Type::BYTEA,
        DataType::Float16 | DataType::Float32 => Type::FLOAT4,
        DataType::Float64 => Type::FLOAT8,
        DataType::Decimal128(_, _) => Type::NUMERIC,
        DataType::Utf8 => Type::VARCHAR,
        DataType::LargeUtf8 => Type::TEXT,
        DataType::List(field) | DataType::FixedSizeList(field, _) | DataType::LargeList(field) => {
            // Element-type mapping mirrors the scalar arms above, but
            // produces the corresponding Postgres array type.
            match field.data_type() {
                DataType::Boolean => Type::BOOL_ARRAY,
                DataType::Int8 | DataType::UInt8 => Type::CHAR_ARRAY,
                DataType::Int16 | DataType::UInt16 => Type::INT2_ARRAY,
                DataType::Int32 | DataType::UInt32 => Type::INT4_ARRAY,
                DataType::Int64 | DataType::UInt64 => Type::INT8_ARRAY,
                DataType::Timestamp(_, tz) => {
                    if tz.is_some() {
                        Type::TIMESTAMPTZ_ARRAY
                    } else {
                        Type::TIMESTAMP_ARRAY
                    }
                }
                DataType::Time32(_) | DataType::Time64(_) => Type::TIME_ARRAY,
                DataType::Date32 | DataType::Date64 => Type::DATE_ARRAY,
                DataType::Interval(_) => Type::INTERVAL_ARRAY,
                // Keep in sync with the scalar BYTEA arm: LargeBinary lists
                // are byte arrays too.
                DataType::FixedSizeBinary(_) | DataType::Binary | DataType::LargeBinary => {
                    Type::BYTEA_ARRAY
                }
                DataType::Float16 | DataType::Float32 => Type::FLOAT4_ARRAY,
                DataType::Float64 => Type::FLOAT8_ARRAY,
                DataType::Utf8 => Type::VARCHAR_ARRAY,
                // Utf8View maps to TEXT at the scalar level, so its list
                // form maps to TEXT_ARRAY.
                DataType::LargeUtf8 | DataType::Utf8View => Type::TEXT_ARRAY,
                struct_type @ DataType::Struct(_) => Type::new(
                    Type::RECORD_ARRAY.name().into(),
                    Type::RECORD_ARRAY.oid(),
                    Kind::Array(into_pg_type(struct_type)?),
                    Type::RECORD_ARRAY.schema().into(),
                ),
                list_type => {
                    return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
                        "ERROR".to_owned(),
                        "XX000".to_owned(),
                        format!("Unsupported List Datatype {list_type}"),
                    ))));
                }
            }
        }
        DataType::Utf8View => Type::TEXT,
        // Dictionary encoding is transparent on the wire: use the value type.
        DataType::Dictionary(_, value_type) => into_pg_type(value_type)?,
        DataType::Struct(fields) => {
            // Build an anonymous composite type named "(f1, f2, ...)" with
            // RECORD's OID; each Arrow field is converted recursively.
            let name: String = fields
                .iter()
                .map(|x| x.name().clone())
                .reduce(|a, b| a + ", " + &b)
                .map(|x| format!("({x})"))
                .unwrap_or("()".to_string());
            let kind = Kind::Composite(
                fields
                    .iter()
                    .map(|x| {
                        into_pg_type(x.data_type())
                            .map(|_type| postgres_types::Field::new(x.name().clone(), _type))
                    })
                    .collect::<Result<Vec<_>, PgWireError>>()?,
            );
            Type::new(name, Type::RECORD.oid(), kind, Type::RECORD.schema().into())
        }
        _ => {
            return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
                "ERROR".to_owned(),
                "XX000".to_owned(),
                format!("Unsupported Datatype {arrow_type}"),
            ))));
        }
    })
}

/// Convert an Arrow [`Schema`] into pgwire [`FieldInfo`] row descriptors.
///
/// Each column keeps its Arrow name, its type is mapped via
/// [`into_pg_type`], and its wire format (text/binary) is taken from the
/// client-requested `format` by column position.
///
/// # Errors
///
/// Propagates the error from [`into_pg_type`] for unsupported column types.
pub fn arrow_schema_to_pg_fields(schema: &Schema, format: &Format) -> PgWireResult<Vec<FieldInfo>> {
    let mut pg_fields = Vec::with_capacity(schema.fields().len());
    for (idx, field) in schema.fields().iter().enumerate() {
        let pg_type = into_pg_type(field.data_type())?;
        pg_fields.push(FieldInfo::new(
            field.name().into(),
            None,
            None,
            pg_type,
            format.format_for(idx),
        ));
    }
    Ok(pg_fields)
}

/// Encode one Arrow [`RecordBatch`] into a lazy iterator of Postgres
/// [`DataRow`] messages, using `fields` for per-column type/format info.
///
/// Rows are produced on demand: each call to the iterator encodes the next
/// row via [`RowEncoder::next_row`], so no row is materialized up front.
pub fn encode_recordbatch(
    fields: Arc<Vec<FieldInfo>>,
    record_batch: RecordBatch,
) -> Box<impl Iterator<Item = PgWireResult<DataRow>>> {
    let mut encoder = RowEncoder::new(record_batch, fields);
    Box::new(std::iter::from_fn(move || encoder.next_row()))
}
80 changes: 32 additions & 48 deletions datafusion-postgres/src/encoder/mod.rs → arrow-pg/src/encoder.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,27 @@
use std::error::Error;
use std::io::Write;
use std::str::FromStr;
use std::sync::Arc;

use arrow::array::*;
use arrow::datatypes::*;
use bytes::BufMut;
use bytes::BytesMut;
use chrono::{NaiveDate, NaiveDateTime};
use datafusion::arrow::array::*;
use datafusion::arrow::datatypes::*;
use list_encoder::encode_list;
use pgwire::api::results::DataRowEncoder;
use pgwire::api::results::FieldFormat;
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
use pgwire::error::PgWireError;
use pgwire::error::PgWireResult;
use pgwire::types::ToSqlText;
use postgres_types::{ToSql, Type};
use rust_decimal::Decimal;
use struct_encoder::encode_struct;
use timezone::Tz;

pub mod list_encoder;
pub mod row_encoder;
pub mod struct_encoder;
use crate::error::ToSqlError;
use crate::list_encoder::encode_list;
use crate::struct_encoder::encode_struct;

trait Encoder {
pub trait Encoder {
fn encode_field_with_type_and_format<T>(
&mut self,
value: &T,
Expand Down Expand Up @@ -61,7 +61,7 @@ impl ToSql for EncodedValue {
&self,
_ty: &Type,
out: &mut BytesMut,
) -> Result<postgres_types::IsNull, Box<dyn std::error::Error + Sync + Send>>
) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>>
where
Self: Sized,
{
Expand All @@ -80,7 +80,7 @@ impl ToSql for EncodedValue {
&self,
ty: &Type,
out: &mut BytesMut,
) -> Result<postgres_types::IsNull, Box<dyn std::error::Error + Sync + Send>> {
) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>> {
self.to_sql(ty, out)
}
}
Expand All @@ -90,7 +90,7 @@ impl ToSqlText for EncodedValue {
&self,
_ty: &Type,
out: &mut BytesMut,
) -> Result<postgres_types::IsNull, Box<dyn std::error::Error + Sync + Send>>
) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>>
where
Self: Sized,
{
Expand Down Expand Up @@ -261,16 +261,13 @@ fn get_numeric_128_value(
}
_ => unreachable!(),
};
PgWireError::UserError(Box::new(ErrorInfo::new(
"ERROR".to_owned(),
"XX000".to_owned(),
message.to_owned(),
)))
// TODO: add error type in PgWireError
PgWireError::ApiError(ToSqlError::from(message))
})
.map(Some)
}

fn encode_value<T: Encoder>(
pub fn encode_value<T: Encoder>(
encoder: &mut T,
arr: &Arc<dyn Array>,
idx: usize,
Expand Down Expand Up @@ -387,8 +384,7 @@ fn encode_value<T: Encoder>(
}
let ts_array = arr.as_any().downcast_ref::<TimestampSecondArray>().unwrap();
if let Some(tz) = timezone {
let tz = Tz::from_str(tz.as_ref())
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
let value = ts_array
.value_as_datetime_with_tz(idx, tz)
.map(|d| d.fixed_offset());
Expand All @@ -411,8 +407,7 @@ fn encode_value<T: Encoder>(
.downcast_ref::<TimestampMillisecondArray>()
.unwrap();
if let Some(tz) = timezone {
let tz = Tz::from_str(tz.as_ref())
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
let value = ts_array
.value_as_datetime_with_tz(idx, tz)
.map(|d| d.fixed_offset());
Expand All @@ -435,8 +430,7 @@ fn encode_value<T: Encoder>(
.downcast_ref::<TimestampMicrosecondArray>()
.unwrap();
if let Some(tz) = timezone {
let tz = Tz::from_str(tz.as_ref())
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
let value = ts_array
.value_as_datetime_with_tz(idx, tz)
.map(|d| d.fixed_offset());
Expand All @@ -459,8 +453,7 @@ fn encode_value<T: Encoder>(
.downcast_ref::<TimestampNanosecondArray>()
.unwrap();
if let Some(tz) = timezone {
let tz = Tz::from_str(tz.as_ref())
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
let value = ts_array
.value_as_datetime_with_tz(idx, tz)
.map(|d| d.fixed_offset());
Expand All @@ -483,11 +476,10 @@ fn encode_value<T: Encoder>(
let fields = match type_.kind() {
postgres_types::Kind::Composite(fields) => fields,
_ => {
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
"ERROR".to_owned(),
"XX000".to_owned(),
format!("Failed to unwrap a composite type from type {}", type_),
))))
return Err(PgWireError::ApiError(ToSqlError::from(format!(
"Failed to unwrap a composite type from type {}",
type_
))));
}
};
let value = encode_struct(arr, idx, fields, format)?;
Expand Down Expand Up @@ -517,14 +509,10 @@ fn encode_value<T: Encoder>(
.or_else(|| get_dict_values!(UInt32Type))
.or_else(|| get_dict_values!(UInt64Type))
.ok_or_else(|| {
PgWireError::UserError(Box::new(ErrorInfo::new(
"ERROR".to_owned(),
"XX000".to_owned(),
format!(
"Unsupported dictionary key type for value type {}",
value_type
),
)))
ToSqlError::from(format!(
"Unsupported dictionary key type for value type {}",
value_type
))
})?;

// If the dictionary has only one value, treat it as a primitive
Expand All @@ -536,15 +524,11 @@ fn encode_value<T: Encoder>(
}
}
_ => {
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
"ERROR".to_owned(),
"XX000".to_owned(),
format!(
"Unsupported Datatype {} and array {:?}",
arr.data_type(),
&arr
),
))))
return Err(PgWireError::ApiError(ToSqlError::from(format!(
"Unsupported Datatype {} and array {:?}",
arr.data_type(),
&arr
))));
}
}

Expand Down
1 change: 1 addition & 0 deletions arrow-pg/src/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
/// Boxed dynamic error used for value-encoding failures; matches the error
/// type expected by `postgres_types::ToSql` and `PgWireError::ApiError`.
pub type ToSqlError = Box<dyn std::error::Error + Sync + Send>;
6 changes: 6 additions & 0 deletions arrow-pg/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
//! Arrow data mapping and encoding/decoding for the Postgres wire protocol.

/// Arrow schema/type to Postgres type mapping and record-batch encoding.
pub mod datatypes;
/// Per-value encoding of Arrow arrays into Postgres wire values.
pub mod encoder;
// Internal shared error alias; not part of the public API.
mod error;
/// Encoding of Arrow list-like arrays into Postgres array values.
pub mod list_encoder;
/// Row-by-row encoder over a `RecordBatch`.
pub mod row_encoder;
/// Encoding of Arrow struct arrays into Postgres composite values.
pub mod struct_encoder;
Loading
Loading