Skip to content

Commit 199966f

Browse files
authored
refactor: extract a new crate arrow-pg (#89)
* feat: initial commit for arrow-pg library * refactor: move more components to arrow-pg * chore: tune workspace dependencies Signed-off-by: Ning Sun <sunning@greptime.com> --------- Signed-off-by: Ning Sun <sunning@greptime.com>
1 parent 4d305bd commit 199966f

File tree

14 files changed

+266
-237
lines changed

14 files changed

+266
-237
lines changed

Cargo.lock

Lines changed: 14 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[workspace]
22
resolver = "2"
3-
members = ["datafusion-postgres", "datafusion-postgres-cli"]
3+
members = ["datafusion-postgres", "datafusion-postgres-cli", "arrow-pg"]
44

55
[workspace.package]
66
version = "0.5.1"
@@ -14,8 +14,14 @@ repository = "https://github.com/datafusion-contrib/datafusion-postgres/"
1414
documentation = "https://docs.rs/crate/datafusion-postgres/"
1515

1616
[workspace.dependencies]
17-
pgwire = "0.30.2"
17+
arrow = "55"
18+
bytes = "1.10.1"
19+
chrono = { version = "0.4", features = ["std"] }
1820
datafusion = { version = "47", default-features = false }
21+
futures = "0.3"
22+
pgwire = "0.30.2"
23+
postgres-types = "0.2"
24+
rust_decimal = { version = "1.37", features = ["db-postgres"] }
1925
tokio = { version = "1", default-features = false }
2026

2127
[profile.release]

arrow-pg/Cargo.toml

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
[package]
2+
name = "arrow-pg"
3+
description = "Arrow data mapping and encoding/decoding for Postgres"
4+
version = "0.0.1"
5+
edition.workspace = true
6+
license.workspace = true
7+
authors.workspace = true
8+
keywords.workspace = true
9+
homepage.workspace = true
10+
repository.workspace = true
11+
documentation.workspace = true
12+
readme = "../README.md"
13+
14+
[dependencies]
15+
arrow.workspace = true
16+
bytes.workspace = true
17+
chrono.workspace = true
18+
futures.workspace = true
19+
pgwire.workspace = true
20+
postgres-types.workspace = true
21+
rust_decimal.workspace = true

arrow-pg/src/datatypes.rs

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
use std::sync::Arc;
2+
3+
use arrow::datatypes::*;
4+
use arrow::record_batch::RecordBatch;
5+
use pgwire::api::portal::Format;
6+
use pgwire::api::results::FieldInfo;
7+
use pgwire::api::Type;
8+
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
9+
use pgwire::messages::data::DataRow;
10+
use postgres_types::Kind;
11+
12+
use crate::row_encoder::RowEncoder;
13+
14+
pub fn into_pg_type(arrow_type: &DataType) -> PgWireResult<Type> {
15+
Ok(match arrow_type {
16+
DataType::Null => Type::UNKNOWN,
17+
DataType::Boolean => Type::BOOL,
18+
DataType::Int8 | DataType::UInt8 => Type::CHAR,
19+
DataType::Int16 | DataType::UInt16 => Type::INT2,
20+
DataType::Int32 | DataType::UInt32 => Type::INT4,
21+
DataType::Int64 | DataType::UInt64 => Type::INT8,
22+
DataType::Timestamp(_, tz) => {
23+
if tz.is_some() {
24+
Type::TIMESTAMPTZ
25+
} else {
26+
Type::TIMESTAMP
27+
}
28+
}
29+
DataType::Time32(_) | DataType::Time64(_) => Type::TIME,
30+
DataType::Date32 | DataType::Date64 => Type::DATE,
31+
DataType::Interval(_) => Type::INTERVAL,
32+
DataType::Binary | DataType::FixedSizeBinary(_) | DataType::LargeBinary => Type::BYTEA,
33+
DataType::Float16 | DataType::Float32 => Type::FLOAT4,
34+
DataType::Float64 => Type::FLOAT8,
35+
DataType::Decimal128(_, _) => Type::NUMERIC,
36+
DataType::Utf8 => Type::VARCHAR,
37+
DataType::LargeUtf8 => Type::TEXT,
38+
DataType::List(field) | DataType::FixedSizeList(field, _) | DataType::LargeList(field) => {
39+
match field.data_type() {
40+
DataType::Boolean => Type::BOOL_ARRAY,
41+
DataType::Int8 | DataType::UInt8 => Type::CHAR_ARRAY,
42+
DataType::Int16 | DataType::UInt16 => Type::INT2_ARRAY,
43+
DataType::Int32 | DataType::UInt32 => Type::INT4_ARRAY,
44+
DataType::Int64 | DataType::UInt64 => Type::INT8_ARRAY,
45+
DataType::Timestamp(_, tz) => {
46+
if tz.is_some() {
47+
Type::TIMESTAMPTZ_ARRAY
48+
} else {
49+
Type::TIMESTAMP_ARRAY
50+
}
51+
}
52+
DataType::Time32(_) | DataType::Time64(_) => Type::TIME_ARRAY,
53+
DataType::Date32 | DataType::Date64 => Type::DATE_ARRAY,
54+
DataType::Interval(_) => Type::INTERVAL_ARRAY,
55+
DataType::FixedSizeBinary(_) | DataType::Binary => Type::BYTEA_ARRAY,
56+
DataType::Float16 | DataType::Float32 => Type::FLOAT4_ARRAY,
57+
DataType::Float64 => Type::FLOAT8_ARRAY,
58+
DataType::Utf8 => Type::VARCHAR_ARRAY,
59+
DataType::LargeUtf8 => Type::TEXT_ARRAY,
60+
struct_type @ DataType::Struct(_) => Type::new(
61+
Type::RECORD_ARRAY.name().into(),
62+
Type::RECORD_ARRAY.oid(),
63+
Kind::Array(into_pg_type(struct_type)?),
64+
Type::RECORD_ARRAY.schema().into(),
65+
),
66+
list_type => {
67+
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
68+
"ERROR".to_owned(),
69+
"XX000".to_owned(),
70+
format!("Unsupported List Datatype {list_type}"),
71+
))));
72+
}
73+
}
74+
}
75+
DataType::Utf8View => Type::TEXT,
76+
DataType::Dictionary(_, value_type) => into_pg_type(value_type)?,
77+
DataType::Struct(fields) => {
78+
let name: String = fields
79+
.iter()
80+
.map(|x| x.name().clone())
81+
.reduce(|a, b| a + ", " + &b)
82+
.map(|x| format!("({x})"))
83+
.unwrap_or("()".to_string());
84+
let kind = Kind::Composite(
85+
fields
86+
.iter()
87+
.map(|x| {
88+
into_pg_type(x.data_type())
89+
.map(|_type| postgres_types::Field::new(x.name().clone(), _type))
90+
})
91+
.collect::<Result<Vec<_>, PgWireError>>()?,
92+
);
93+
Type::new(name, Type::RECORD.oid(), kind, Type::RECORD.schema().into())
94+
}
95+
_ => {
96+
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
97+
"ERROR".to_owned(),
98+
"XX000".to_owned(),
99+
format!("Unsupported Datatype {arrow_type}"),
100+
))));
101+
}
102+
})
103+
}
104+
105+
pub fn arrow_schema_to_pg_fields(schema: &Schema, format: &Format) -> PgWireResult<Vec<FieldInfo>> {
106+
schema
107+
.fields()
108+
.iter()
109+
.enumerate()
110+
.map(|(idx, f)| {
111+
let pg_type = into_pg_type(f.data_type())?;
112+
Ok(FieldInfo::new(
113+
f.name().into(),
114+
None,
115+
None,
116+
pg_type,
117+
format.format_for(idx),
118+
))
119+
})
120+
.collect::<PgWireResult<Vec<FieldInfo>>>()
121+
}
122+
123+
pub fn encode_recordbatch(
124+
fields: Arc<Vec<FieldInfo>>,
125+
record_batch: RecordBatch,
126+
) -> Box<impl Iterator<Item = PgWireResult<DataRow>>> {
127+
let mut row_stream = RowEncoder::new(record_batch, fields);
128+
Box::new(std::iter::from_fn(move || row_stream.next_row()))
129+
}

datafusion-postgres/src/encoder/mod.rs renamed to arrow-pg/src/encoder.rs

Lines changed: 32 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,27 @@
1+
use std::error::Error;
12
use std::io::Write;
23
use std::str::FromStr;
34
use std::sync::Arc;
45

6+
use arrow::array::*;
7+
use arrow::datatypes::*;
58
use bytes::BufMut;
69
use bytes::BytesMut;
710
use chrono::{NaiveDate, NaiveDateTime};
8-
use datafusion::arrow::array::*;
9-
use datafusion::arrow::datatypes::*;
10-
use list_encoder::encode_list;
1111
use pgwire::api::results::DataRowEncoder;
1212
use pgwire::api::results::FieldFormat;
13-
use pgwire::error::{ErrorInfo, PgWireError, PgWireResult};
13+
use pgwire::error::PgWireError;
14+
use pgwire::error::PgWireResult;
1415
use pgwire::types::ToSqlText;
1516
use postgres_types::{ToSql, Type};
1617
use rust_decimal::Decimal;
17-
use struct_encoder::encode_struct;
1818
use timezone::Tz;
1919

20-
pub mod list_encoder;
21-
pub mod row_encoder;
22-
pub mod struct_encoder;
20+
use crate::error::ToSqlError;
21+
use crate::list_encoder::encode_list;
22+
use crate::struct_encoder::encode_struct;
2323

24-
trait Encoder {
24+
pub trait Encoder {
2525
fn encode_field_with_type_and_format<T>(
2626
&mut self,
2727
value: &T,
@@ -61,7 +61,7 @@ impl ToSql for EncodedValue {
6161
&self,
6262
_ty: &Type,
6363
out: &mut BytesMut,
64-
) -> Result<postgres_types::IsNull, Box<dyn std::error::Error + Sync + Send>>
64+
) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>>
6565
where
6666
Self: Sized,
6767
{
@@ -80,7 +80,7 @@ impl ToSql for EncodedValue {
8080
&self,
8181
ty: &Type,
8282
out: &mut BytesMut,
83-
) -> Result<postgres_types::IsNull, Box<dyn std::error::Error + Sync + Send>> {
83+
) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>> {
8484
self.to_sql(ty, out)
8585
}
8686
}
@@ -90,7 +90,7 @@ impl ToSqlText for EncodedValue {
9090
&self,
9191
_ty: &Type,
9292
out: &mut BytesMut,
93-
) -> Result<postgres_types::IsNull, Box<dyn std::error::Error + Sync + Send>>
93+
) -> Result<postgres_types::IsNull, Box<dyn Error + Send + Sync>>
9494
where
9595
Self: Sized,
9696
{
@@ -261,16 +261,13 @@ fn get_numeric_128_value(
261261
}
262262
_ => unreachable!(),
263263
};
264-
PgWireError::UserError(Box::new(ErrorInfo::new(
265-
"ERROR".to_owned(),
266-
"XX000".to_owned(),
267-
message.to_owned(),
268-
)))
264+
// TODO: add error type in PgWireError
265+
PgWireError::ApiError(ToSqlError::from(message))
269266
})
270267
.map(Some)
271268
}
272269

273-
fn encode_value<T: Encoder>(
270+
pub fn encode_value<T: Encoder>(
274271
encoder: &mut T,
275272
arr: &Arc<dyn Array>,
276273
idx: usize,
@@ -387,8 +384,7 @@ fn encode_value<T: Encoder>(
387384
}
388385
let ts_array = arr.as_any().downcast_ref::<TimestampSecondArray>().unwrap();
389386
if let Some(tz) = timezone {
390-
let tz = Tz::from_str(tz.as_ref())
391-
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
387+
let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
392388
let value = ts_array
393389
.value_as_datetime_with_tz(idx, tz)
394390
.map(|d| d.fixed_offset());
@@ -411,8 +407,7 @@ fn encode_value<T: Encoder>(
411407
.downcast_ref::<TimestampMillisecondArray>()
412408
.unwrap();
413409
if let Some(tz) = timezone {
414-
let tz = Tz::from_str(tz.as_ref())
415-
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
410+
let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
416411
let value = ts_array
417412
.value_as_datetime_with_tz(idx, tz)
418413
.map(|d| d.fixed_offset());
@@ -435,8 +430,7 @@ fn encode_value<T: Encoder>(
435430
.downcast_ref::<TimestampMicrosecondArray>()
436431
.unwrap();
437432
if let Some(tz) = timezone {
438-
let tz = Tz::from_str(tz.as_ref())
439-
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
433+
let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
440434
let value = ts_array
441435
.value_as_datetime_with_tz(idx, tz)
442436
.map(|d| d.fixed_offset());
@@ -459,8 +453,7 @@ fn encode_value<T: Encoder>(
459453
.downcast_ref::<TimestampNanosecondArray>()
460454
.unwrap();
461455
if let Some(tz) = timezone {
462-
let tz = Tz::from_str(tz.as_ref())
463-
.map_err(|e| PgWireError::ApiError(Box::new(e)))?;
456+
let tz = Tz::from_str(tz.as_ref()).map_err(ToSqlError::from)?;
464457
let value = ts_array
465458
.value_as_datetime_with_tz(idx, tz)
466459
.map(|d| d.fixed_offset());
@@ -483,11 +476,10 @@ fn encode_value<T: Encoder>(
483476
let fields = match type_.kind() {
484477
postgres_types::Kind::Composite(fields) => fields,
485478
_ => {
486-
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
487-
"ERROR".to_owned(),
488-
"XX000".to_owned(),
489-
format!("Failed to unwrap a composite type from type {}", type_),
490-
))))
479+
return Err(PgWireError::ApiError(ToSqlError::from(format!(
480+
"Failed to unwrap a composite type from type {}",
481+
type_
482+
))));
491483
}
492484
};
493485
let value = encode_struct(arr, idx, fields, format)?;
@@ -517,14 +509,10 @@ fn encode_value<T: Encoder>(
517509
.or_else(|| get_dict_values!(UInt32Type))
518510
.or_else(|| get_dict_values!(UInt64Type))
519511
.ok_or_else(|| {
520-
PgWireError::UserError(Box::new(ErrorInfo::new(
521-
"ERROR".to_owned(),
522-
"XX000".to_owned(),
523-
format!(
524-
"Unsupported dictionary key type for value type {}",
525-
value_type
526-
),
527-
)))
512+
ToSqlError::from(format!(
513+
"Unsupported dictionary key type for value type {}",
514+
value_type
515+
))
528516
})?;
529517

530518
// If the dictionary has only one value, treat it as a primitive
@@ -536,15 +524,11 @@ fn encode_value<T: Encoder>(
536524
}
537525
}
538526
_ => {
539-
return Err(PgWireError::UserError(Box::new(ErrorInfo::new(
540-
"ERROR".to_owned(),
541-
"XX000".to_owned(),
542-
format!(
543-
"Unsupported Datatype {} and array {:?}",
544-
arr.data_type(),
545-
&arr
546-
),
547-
))))
527+
return Err(PgWireError::ApiError(ToSqlError::from(format!(
528+
"Unsupported Datatype {} and array {:?}",
529+
arr.data_type(),
530+
&arr
531+
))));
548532
}
549533
}
550534

arrow-pg/src/error.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
pub type ToSqlError = Box<dyn std::error::Error + Sync + Send>;

arrow-pg/src/lib.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
pub mod datatypes;
2+
pub mod encoder;
3+
mod error;
4+
pub mod list_encoder;
5+
pub mod row_encoder;
6+
pub mod struct_encoder;

0 commit comments

Comments
 (0)