2 changes: 1 addition & 1 deletion .github/workflows/test.yml
@@ -10,7 +10,7 @@ jobs:
       fail-fast: false
       matrix:
         os:
-          - ubuntu-20.04
+          - ubuntu-24.04
         toolchain:
           - 1.68.0
 
1 change: 1 addition & 0 deletions .gitignore
@@ -1,2 +1,3 @@
 /target
 Cargo.lock
+.idea/
11 changes: 5 additions & 6 deletions convergence-arrow/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "convergence-arrow"
-version = "0.16.0"
+version = "0.17.0"
 authors = ["Ruan Pearce-Authers <ruanpa@outlook.com>"]
 edition = "2018"
 description = "Utils for bridging Apache Arrow and PostgreSQL's wire protocol"
@@ -10,9 +10,8 @@ repository = "https://github.com/returnString/convergence"
 [dependencies]
 tokio = { version = "1" }
 async-trait = "0.1"
-datafusion = "38"
-convergence = { path = "../convergence", version = "0.16.0" }
-chrono = "0.4"
+datafusion = "43"
+convergence = { path = "../convergence", version = "0.17.0" }
 
 [dev-dependencies]
+chrono = "=0.4.39"
 tokio-postgres = { version = "0.7", features = [ "with-chrono-0_4" ] }
+rust_decimal = { version = "1.36.0", features = ["default", "db-postgres"] }
9 changes: 5 additions & 4 deletions convergence-arrow/examples/datafusion.rs
@@ -2,8 +2,9 @@ use convergence::server::{self, BindOptions};
 use convergence_arrow::datafusion::DataFusionEngine;
 use convergence_arrow::metadata::Catalog;
 use datafusion::arrow::datatypes::DataType;
-use datafusion::catalog::schema::MemorySchemaProvider;
-use datafusion::catalog::{CatalogProvider, MemoryCatalogProvider};
+use datafusion::catalog_common::memory::MemorySchemaProvider;
+use datafusion::catalog::CatalogProvider;
+use datafusion::catalog_common::MemoryCatalogProvider;
 use datafusion::logical_expr::Volatility;
 use datafusion::physical_plan::ColumnarValue;
 use datafusion::prelude::*;
@@ -35,15 +36,15 @@ async fn new_engine() -> DataFusionEngine {
 	ctx.register_udf(create_udf(
 		"pg_backend_pid",
 		vec![],
-		Arc::new(DataType::Int32),
+		DataType::Int32,
 		Volatility::Stable,
 		Arc::new(|_| Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(0))))),
 	));
 
 	ctx.register_udf(create_udf(
 		"current_schema",
 		vec![],
-		Arc::new(DataType::Utf8),
+		DataType::Utf8,
 		Volatility::Stable,
 		Arc::new(|_| Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some("public".to_owned()))))),
 	));
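Note on the two hunks above: DataFusion 43 takes a UDF's return type by value (`DataType`) rather than `Arc<DataType>`. A minimal sketch of the updated call shape, with imports mirroring the example file (an illustration, not part of the diff):

```rust
use std::sync::Arc;

use datafusion::arrow::datatypes::DataType;
use datafusion::logical_expr::Volatility;
use datafusion::physical_plan::ColumnarValue;
use datafusion::prelude::*;
use datafusion::scalar::ScalarValue;

fn main() {
	let ctx = SessionContext::new();
	// The return type is now a plain DataType rather than Arc<DataType>.
	ctx.register_udf(create_udf(
		"pg_backend_pid",
		vec![],
		DataType::Int32,
		Volatility::Stable,
		Arc::new(|_| Ok(ColumnarValue::Scalar(ScalarValue::Int32(Some(0))))),
	));
}
```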
4 changes: 3 additions & 1 deletion convergence-arrow/src/metadata.rs
@@ -3,8 +3,9 @@
 use datafusion::arrow::array::{ArrayRef, Int32Builder, StringBuilder, UInt32Builder};
 use datafusion::arrow::datatypes::{Field, Schema};
 use datafusion::arrow::record_batch::RecordBatch;
-use datafusion::catalog::schema::{MemorySchemaProvider, SchemaProvider};
 use datafusion::catalog::CatalogProvider;
+use datafusion::catalog::SchemaProvider;
+use datafusion::catalog_common::memory::MemorySchemaProvider;
 use datafusion::datasource::{MemTable, TableProvider};
 use datafusion::error::DataFusionError;
 use std::convert::TryInto;
@@ -153,6 +154,7 @@ impl MetadataBuilder {
 }
 
 /// Wrapper catalog supporting generation of pg metadata (e.g. pg_catalog schema).
+#[derive(Debug)]
 pub struct Catalog {
 	wrapped: Arc<dyn CatalogProvider>,
 }
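The new `#[derive(Debug)]` is presumably required because newer DataFusion adds a `Debug` bound to `CatalogProvider` (inferred from this diff, not verified against the release notes). A standalone sketch of the mechanism, with hypothetical trait and struct names:

```rust
use std::fmt::Debug;

// When a trait lists Debug as a supertrait, every implementor must be Debug,
// so wrapper structs holding trait objects typically derive it as well.
trait CatalogLike: Debug + Send + Sync {}

#[derive(Debug)]
struct Wrapper {
	inner: Box<dyn CatalogLike>,
}

fn main() {
	#[derive(Debug)]
	struct Mem;
	impl CatalogLike for Mem {}
	println!("{:?}", Wrapper { inner: Box::new(Mem) });
}
```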
12 changes: 8 additions & 4 deletions convergence-arrow/src/table.rs
@@ -3,9 +3,10 @@
 use convergence::protocol::{DataTypeOid, ErrorResponse, FieldDescription, SqlState};
 use convergence::protocol_ext::DataRowBatch;
 use datafusion::arrow::array::{
-	BooleanArray, Date32Array, Date64Array, Float16Array, Float32Array, Float64Array, Int16Array, Int32Array,
-	Int64Array, Int8Array, StringArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
-	TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
+	BooleanArray, Date32Array, Date64Array, Decimal128Array, Float16Array,
+	Float32Array, Float64Array, Int16Array, Int32Array, Int64Array, Int8Array, StringArray,
+	StringViewArray, TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
+	TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array
 };
 use datafusion::arrow::datatypes::{DataType, Schema, TimeUnit};
 use datafusion::arrow::record_batch::RecordBatch;
@@ -47,7 +48,9 @@ pub fn record_batch_to_rows(arrow_batch: &RecordBatch, pg_batch: &mut DataRowBat
 			DataType::Float16 => row.write_float4(array_val!(Float16Array, col, row_idx).to_f32()),
 			DataType::Float32 => row.write_float4(array_val!(Float32Array, col, row_idx)),
 			DataType::Float64 => row.write_float8(array_val!(Float64Array, col, row_idx)),
+			DataType::Decimal128(p, s) => row.write_numeric_16(array_val!(Decimal128Array, col, row_idx), p, s),
 			DataType::Utf8 => row.write_string(array_val!(StringArray, col, row_idx)),
+			DataType::Utf8View => row.write_string(array_val!(StringViewArray, col, row_idx)),
 			DataType::Date32 => {
 				row.write_date(array_val!(Date32Array, col, row_idx, value_as_date).ok_or_else(|| {
 					ErrorResponse::error(SqlState::InvalidDatetimeFormat, "unsupported date type")
@@ -102,7 +105,8 @@ pub fn data_type_to_oid(ty: &DataType) -> Result<DataTypeOid, ErrorResponse> {
 		DataType::UInt64 => DataTypeOid::Int8,
 		DataType::Float16 | DataType::Float32 => DataTypeOid::Float4,
 		DataType::Float64 => DataTypeOid::Float8,
-		DataType::Utf8 => DataTypeOid::Text,
+		DataType::Decimal128(_, _) => DataTypeOid::Numeric,
+		DataType::Utf8 | DataType::Utf8View => DataTypeOid::Text,
 		DataType::Date32 | DataType::Date64 => DataTypeOid::Date,
 		DataType::Timestamp(_, None) => DataTypeOid::Timestamp,
 		other => {
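Arrow's `Decimal128` embeds precision and scale in the `DataType` itself, which is why the new match arm can forward `p` and `s` straight to the row writer; `Utf8View` is just an alternative string layout, so it shares the `Text` OID. A trimmed-down sketch of the mapping (hypothetical helper, not the PR's code):

```rust
use datafusion::arrow::datatypes::DataType;

// Hypothetical reduction of data_type_to_oid: decimal carries (precision,
// scale) in the type, and both string layouts map to the same SQL type.
fn pg_type_name(ty: &DataType) -> &'static str {
	match ty {
		DataType::Decimal128(_, _) => "numeric",
		DataType::Utf8 | DataType::Utf8View => "text",
		_ => "unsupported",
	}
}

fn main() {
	assert_eq!(pg_type_name(&DataType::Decimal128(2, 0)), "numeric");
	assert_eq!(pg_type_name(&DataType::Utf8View), "text");
}
```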
19 changes: 15 additions & 4 deletions convergence-arrow/tests/test_arrow.rs
@@ -6,10 +6,11 @@ use convergence::protocol_ext::DataRowBatch;
 use convergence::server::{self, BindOptions};
 use convergence::sqlparser::ast::Statement;
 use convergence_arrow::table::{record_batch_to_rows, schema_to_field_desc};
-use datafusion::arrow::array::{ArrayRef, Date32Array, Float32Array, Int32Array, StringArray, TimestampSecondArray};
+use datafusion::arrow::array::{ArrayRef, Date32Array, Decimal128Array, Float32Array, Int32Array, StringArray, StringViewArray, TimestampSecondArray};
 use datafusion::arrow::datatypes::{DataType, Field, Schema, TimeUnit};
 use datafusion::arrow::record_batch::RecordBatch;
 use std::sync::Arc;
+use rust_decimal::Decimal;
 use tokio_postgres::{connect, NoTls};
 
 struct ArrowPortal {
@@ -31,20 +32,24 @@ impl ArrowEngine {
 	fn new() -> Self {
 		let int_col = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef;
 		let float_col = Arc::new(Float32Array::from(vec![1.5, 2.5, 3.5])) as ArrayRef;
+		let decimal_col = Arc::new(Decimal128Array::from(vec![11, 22, 33]).with_precision_and_scale(2, 0).unwrap()) as ArrayRef;
 		let string_col = Arc::new(StringArray::from(vec!["a", "b", "c"])) as ArrayRef;
+		let string_view_col = Arc::new(StringViewArray::from(vec!["aa", "bb", "cc"])) as ArrayRef;
 		let ts_col = Arc::new(TimestampSecondArray::from(vec![1577836800, 1580515200, 1583020800])) as ArrayRef;
 		let date_col = Arc::new(Date32Array::from(vec![0, 1, 2])) as ArrayRef;
 
 		let schema = Schema::new(vec![
 			Field::new("int_col", DataType::Int32, true),
 			Field::new("float_col", DataType::Float32, true),
+			Field::new("decimal_col", DataType::Decimal128(2, 0), true),
 			Field::new("string_col", DataType::Utf8, true),
+			Field::new("string_view_col", DataType::Utf8View, true),
 			Field::new("ts_col", DataType::Timestamp(TimeUnit::Second, None), true),
 			Field::new("date_col", DataType::Date32, true),
 		]);
 
 		Self {
-			batch: RecordBatch::try_new(Arc::new(schema), vec![int_col, float_col, string_col, ts_col, date_col])
+			batch: RecordBatch::try_new(Arc::new(schema), vec![int_col, float_col, decimal_col, string_col, string_view_col, ts_col, date_col])
 				.expect("failed to create batch"),
 		}
 	}
@@ -89,8 +94,8 @@ async fn basic_data_types() {
 	let rows = client.query("select 1", &[]).await.unwrap();
 	let get_row = |idx: usize| {
 		let row = &rows[idx];
-		let cols: (i32, f32, &str, NaiveDateTime, NaiveDate) =
-			(row.get(0), row.get(1), row.get(2), row.get(3), row.get(4));
+		let cols: (i32, f32, Decimal, &str, &str, NaiveDateTime, NaiveDate) =
+			(row.get(0), row.get(1), row.get(2), row.get(3), row.get(4), row.get(5), row.get(6));
 		cols
 	};
 
@@ -99,7 +104,9 @@
 		(
 			1,
 			1.5,
+			Decimal::from(11),
 			"a",
+			"aa",
 			NaiveDate::from_ymd_opt(2020, 1, 1)
 				.unwrap()
 				.and_hms_opt(0, 0, 0)
@@ -112,7 +119,9 @@
 		(
 			2,
 			2.5,
+			Decimal::from(22),
 			"b",
+			"bb",
 			NaiveDate::from_ymd_opt(2020, 2, 1)
 				.unwrap()
 				.and_hms_opt(0, 0, 0)
@@ -125,7 +134,9 @@
 		(
 			3,
 			3.5,
+			Decimal::from(33),
 			"c",
+			"cc",
 			NaiveDate::from_ymd_opt(2020, 3, 1)
 				.unwrap()
 				.and_hms_opt(0, 0, 0)
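One detail in the test fixture worth noting: `Decimal128Array::from(vec![...])` builds the array from raw `i128` values, and `with_precision_and_scale` re-types it, returning a `Result` because the requested precision/scale pair must be valid for the decimal type (hence the `.unwrap()` above). A minimal standalone check, assuming arrow's usual API:

```rust
use datafusion::arrow::array::Decimal128Array;

fn main() {
	// Re-type the raw values as Decimal128(2, 0); the call errors for an
	// invalid pair (e.g. zero precision, or scale exceeding precision).
	let arr = Decimal128Array::from(vec![11_i128, 22, 33])
		.with_precision_and_scale(2, 0)
		.expect("precision 2, scale 0 is a valid decimal type");
	assert_eq!(arr.value_as_string(0), "11");
}
```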
7 changes: 3 additions & 4 deletions convergence/Cargo.toml
@@ -1,6 +1,6 @@
 [package]
 name = "convergence"
-version = "0.16.0"
+version = "0.17.0"
 authors = ["Ruan Pearce-Authers <ruanpa@outlook.com>"]
 edition = "2018"
 description = "Write servers that speak PostgreSQL's wire protocol"
@@ -15,7 +15,6 @@ bytes = "1"
 futures = "0.3"
 sqlparser = "0.46"
 async-trait = "0.1"
-chrono = "0.4"
 
 [dev-dependencies]
+chrono = "=0.4.39"
+rust_decimal = { version = "1.36.0", features = ["default", "db-postgres"] }
 tokio-postgres = "0.7"
2 changes: 2 additions & 0 deletions convergence/src/protocol.rs
@@ -75,6 +75,8 @@ data_types! {
 	Float4 = 700, 4
 	Float8 = 701, 8
 
+	Numeric = 1700, -1
+
 	Date = 1082, 4
 	Timestamp = 1114, 8
 
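For context (from PostgreSQL's catalog, not this diff): 1700 is the `pg_type` OID for `numeric`, and the `-1` length marks it as variable-length, unlike the fixed-width entries around it. The client library used in the tests agrees:

```rust
use tokio_postgres::types::Type;

fn main() {
	// OID 1700 resolves to the built-in NUMERIC type in tokio-postgres.
	assert_eq!(Type::from_oid(1700), Some(Type::NUMERIC));
}
```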
22 changes: 21 additions & 1 deletion convergence/src/protocol_ext.rs
@@ -3,6 +3,8 @@
 use crate::protocol::{ConnectionCodec, FormatCode, ProtocolError, RowDescription};
 use bytes::{BufMut, BytesMut};
 use chrono::{NaiveDate, NaiveDateTime};
+use rust_decimal::Decimal;
+use tokio_postgres::types::{ToSql, Type};
 use tokio_util::codec::Encoder;
 
 /// Supports batched rows for e.g. returning portal result sets.
@@ -131,14 +133,32 @@ impl<'a> DataRowWriter<'a> {
 		}
 	}
 
+	/// Writes a numeric value for the next column.
+	pub fn write_numeric_16(&mut self, val: i128, _p: &u8, s: &i8) {
+		let decimal = Decimal::from_i128_with_scale(val, *s as u32);
+		match self.parent.format_code {
+			FormatCode::Text => {
+				self.write_string(&decimal.to_string())
+			}
+			FormatCode::Binary => {
+				let numeric_type = Type::from_oid(1700).expect("failed to create numeric type");
+				let mut buf = BytesMut::new();
+				decimal.to_sql(&numeric_type, &mut buf)
+					.expect("failed to write numeric");
+
+				self.write_value(&buf.freeze())
+			}
+		};
+	}
+
 	primitive_write!(write_int2, i16);
 	primitive_write!(write_int4, i32);
 	primitive_write!(write_int8, i64);
 	primitive_write!(write_float4, f32);
 	primitive_write!(write_float8, f64);
 }
 
-impl<'a> Drop for DataRowWriter<'a> {
+impl Drop for DataRowWriter<'_> {
 	fn drop(&mut self) {
 		assert_eq!(
 			self.parent.num_cols, self.current_col,
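The text-format branch of `write_numeric_16` reduces to reinterpreting the raw `Decimal128` integer with its Arrow scale and stringifying it; the binary branch delegates PostgreSQL's numeric wire encoding to rust_decimal's `ToSql` implementation. A small sketch of the scale handling:

```rust
use rust_decimal::Decimal;

fn main() {
	// A Decimal128 raw value of 12345 with scale 2 represents 123.45.
	let raw: i128 = 12345;
	let scale: i8 = 2;
	let decimal = Decimal::from_i128_with_scale(raw, scale as u32);
	assert_eq!(decimal.to_string(), "123.45");
}
```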
6 changes: 3 additions & 3 deletions convergence/tests/test_connection.rs
@@ -79,16 +79,16 @@ async fn extended_query_flow() {
 async fn simple_query_flow() {
 	let client = setup().await;
 	let messages = client.simple_query("select 1").await.unwrap();
-	assert_eq!(messages.len(), 2);
+	assert_eq!(messages.len(), 3);
 
-	let row = match &messages[0] {
+	let row = match &messages[1] {
 		SimpleQueryMessage::Row(row) => row,
 		_ => panic!("expected row"),
 	};
 
 	assert_eq!(row.get(0), Some("1"));
 
-	let num_rows = match &messages[1] {
+	let num_rows = match &messages[2] {
 		SimpleQueryMessage::CommandComplete(rows) => *rows,
 		_ => panic!("expected command complete"),
 	};
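The index shift here reflects newer tokio-postgres versions surfacing a `RowDescription` message ahead of the data rows in simple-query results, so `select 1` now yields three messages (an inference from the updated assertions, not stated in the diff). A sketch of order-tolerant matching:

```rust
use tokio_postgres::SimpleQueryMessage;

// Matching on variants instead of hard-coded indices; SimpleQueryMessage is
// #[non_exhaustive], so a catch-all arm is required.
fn check(messages: &[SimpleQueryMessage]) {
	for msg in messages {
		match msg {
			SimpleQueryMessage::RowDescription(_) => { /* column metadata */ }
			SimpleQueryMessage::Row(row) => assert_eq!(row.get(0), Some("1")),
			SimpleQueryMessage::CommandComplete(n) => assert_eq!(*n, 1),
			_ => {}
		}
	}
}

fn main() {
	check(&[]);
}
```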