This repository was archived by the owner on Dec 29, 2021. It is now read-only.

Commit 8ae790a

Merge commit: 2 parents 88e1f77 + f9841b7

File tree

5 files changed (+518, -28 lines)

src/error.rs

Lines changed: 9 additions & 1 deletion
@@ -1,6 +1,7 @@
-use arrow::error::ArrowError;
 use std::error::Error;
 
+use arrow::error::ArrowError;
+
 #[derive(Debug, Clone, PartialEq)]
 pub enum DataFrameError {
     MemoryError(String),
@@ -10,6 +11,7 @@ pub enum DataFrameError {
     IoError(String),
     NoneError,
     ArrowError(String),
+    SqlError(String),
 }
 
 impl From<ArrowError> for DataFrameError {
@@ -36,4 +38,10 @@ impl From<std::str::Utf8Error> for DataFrameError {
     }
 }
 
+impl From<postgres::error::Error> for DataFrameError {
+    fn from(error: postgres::error::Error) -> Self {
+        DataFrameError::SqlError(error.to_string())
+    }
+}
+
 pub type Result<T> = ::std::result::Result<T, DataFrameError>;
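
Note: with the new From<postgres::error::Error> impl, database errors can now be propagated with the ? operator and surface as DataFrameError::SqlError. A minimal sketch of the pattern (the connect helper is hypothetical; Client::connect is the postgres crate's API):

use postgres::{Client, NoTls};

use crate::error::Result;

// Hypothetical helper: a failed connection returns postgres::error::Error,
// which ? converts into DataFrameError::SqlError via the From impl above.
fn connect(connection_string: &str) -> Result<Client> {
    let client = Client::connect(connection_string, NoTls)?;
    Ok(client)
}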

src/io/sql/mod.rs

Lines changed: 8 additions & 0 deletions
@@ -1,5 +1,7 @@
 pub mod postgres;
 
+use std::sync::Arc;
+
 use arrow::datatypes::Schema;
 use arrow::record_batch::RecordBatch;
 
@@ -19,3 +21,9 @@ pub trait SqlDataSource {
         batch_size: usize,
     ) -> Result<Vec<RecordBatch>>;
 }
+
+pub trait SqlDataSink {
+    fn create_table(connection: &str, table_name: &str, schema: &Arc<Schema>) -> Result<()>;
+    fn write_to_table(connection: &str, table_name: &str, batches: &Vec<RecordBatch>)
+        -> Result<()>;
+}
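
Note: the new SqlDataSink trait mirrors SqlDataSource for the write path; both methods are associated functions that take a connection string, so implementors carry no state. A sketch of the shape an implementor takes, using a hypothetical no-op sink (the real Postgres implementation presumably lives in the new writer module, which is not shown in this excerpt):

use std::sync::Arc;

use arrow::datatypes::Schema;
use arrow::record_batch::RecordBatch;

// Hypothetical no-op sink that accepts any schema and discards all batches;
// a real sink would issue CREATE TABLE and COPY ... FROM statements.
struct NullSink;

impl SqlDataSink for NullSink {
    fn create_table(_connection: &str, _table_name: &str, _schema: &Arc<Schema>) -> Result<()> {
        Ok(())
    }

    fn write_to_table(
        _connection: &str,
        _table_name: &str,
        _batches: &Vec<RecordBatch>,
    ) -> Result<()> {
        Ok(())
    }
}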

src/io/sql/postgres/mod.rs

Lines changed: 11 additions & 0 deletions
@@ -0,0 +1,11 @@
+//! An interface for reading and writing record batches to and from PostgreSQL
+
+pub mod reader;
+pub mod writer;
+
+/// PGCOPY header
+pub const MAGIC: &[u8] = b"PGCOPY\n\xff\r\n\0";
+pub const EPOCH_DAYS: i32 = 10957;
+pub const EPOCH_MICROS: i64 = 946684800000000;
+
+pub struct Postgres;
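
Note: these constants bridge the two epochs in play. PostgreSQL's binary COPY format counts dates and timestamps from 2000-01-01, while Arrow's Date32 and Timestamp types count from the Unix epoch, 1970-01-01. EPOCH_DAYS (10957) is the number of days between the two epochs, and EPOCH_MICROS is the same offset in microseconds: 10957 × 86,400 × 1,000,000 = 946,684,800,000,000. A consistency check (a hypothetical test, not part of the commit):

#[test]
fn epoch_constants_agree() {
    // 10957 days expressed in microseconds
    assert_eq!(EPOCH_MICROS, EPOCH_DAYS as i64 * 86_400 * 1_000_000);
}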

src/io/sql/postgres.rs renamed to src/io/sql/postgres/reader.rs

Lines changed: 15 additions & 27 deletions
@@ -1,10 +1,5 @@
-//! An experimental interface for reading and writing record batches to and from PostgreSQL
-
 use std::convert::TryFrom;
-use std::{
-    io::{BufRead, BufReader, Cursor, Read, Seek},
-    sync::Arc,
-};
+use std::{io::Read, sync::Arc};
 
 use arrow::array::*;
 use arrow::buffer::Buffer;
@@ -15,13 +10,10 @@ use chrono::Timelike;
 use postgres::types::*;
 use postgres::{Client, NoTls, Row};
 
+use super::{Postgres, EPOCH_DAYS, EPOCH_MICROS, MAGIC};
 use crate::error::DataFrameError;
 use crate::io::sql::SqlDataSource;
 
-const MAGIC: &[u8] = b"PGCOPY\n\xff\r\n\0";
-
-pub struct Postgres;
-
 impl SqlDataSource for Postgres {
     fn get_table_schema(connection_string: &str, table_name: &str) -> crate::error::Result<Schema> {
         let (table_schema, table_name) = if table_name.contains(".") {
@@ -82,8 +74,7 @@ impl SqlDataSource for Postgres {
         let reader = get_binary_reader(
             &mut client,
             format!("select * from {} limit {}", table_name, limit).as_str(),
-        )
-        .unwrap();
+        )?;
         read_from_binary(reader, &schema).map(|batch| vec![batch])
     }
 
@@ -116,19 +107,16 @@ impl SqlDataSource for Postgres {
         let reader = get_binary_reader(
             &mut client,
             format!("select a.* from ({}) a limit {}", query, limit).as_str(),
-        )
-        .unwrap();
+        )?;
         read_from_binary(reader, &schema).map(|batch| vec![batch])
     }
 }
 
 fn get_binary_reader<'a>(
     client: &'a mut Client,
     query: &str,
-) -> Result<postgres::CopyOutReader<'a>, ()> {
-    client
-        .copy_out(format!("COPY ({}) TO stdout with (format binary)", query).as_str())
-        .map_err(|e| eprintln!("Error: {:?}", e))
+) -> crate::error::Result<postgres::CopyOutReader<'a>> {
+    Ok(client.copy_out(format!("COPY ({}) TO stdout with (format binary)", query).as_str())?)
 }
 
 struct PgDataType {
@@ -145,7 +133,7 @@ impl TryFrom<PgDataType> for Field {
     type Error = ();
     fn try_from(field: PgDataType) -> Result<Self, Self::Error> {
         let data_type = match field.data_type.as_str() {
-            "integer" => match field.numeric_precision {
+            "int" | "integer" => match field.numeric_precision {
                 Some(8) => Ok(DataType::Int8),
                 Some(16) => Ok(DataType::Int16),
                 Some(32) => Ok(DataType::Int32),
@@ -174,9 +162,11 @@ impl TryFrom<PgDataType> for Field {
             "real" => Ok(DataType::Float32),
             "smallint" => Ok(DataType::Int16),
             "text" => Ok(DataType::Utf8),
-            "time without time zone" => Ok(DataType::Time64(TimeUnit::Microsecond)), // TODO: use datetime_precision to determine correct type
+            "time" | "time without time zone" => Ok(DataType::Time64(TimeUnit::Microsecond)), // TODO: use datetime_precision to determine correct type
             "timestamp with time zone" => Ok(DataType::Timestamp(TimeUnit::Microsecond, None)),
-            "timestamp without time zone" => Ok(DataType::Timestamp(TimeUnit::Microsecond, None)),
+            "timestamp" | "timestamp without time zone" => {
+                Ok(DataType::Timestamp(TimeUnit::Microsecond, None))
+            }
             "uuid" => Ok(DataType::Binary), // TODO: use a more specialised data type
             t @ _ => {
                 eprintln!("Conversion not set for data type: {:?}", t);
@@ -195,7 +185,6 @@ impl TryFrom<PgDataType> for Field {
 ///
 /// Not all types are covered, but can be easily added
 fn pg_to_arrow_type(dt: &Type) -> Option<DataType> {
-    dbg!(&dt);
     match dt {
         &Type::BOOL => Some(DataType::Boolean),
         &Type::BYTEA | &Type::CHAR | &Type::BPCHAR | &Type::NAME | &Type::TEXT | &Type::VARCHAR => {
@@ -376,7 +365,7 @@ fn row_to_schema(row: &postgres::Row) -> Result<Schema, ()> {
 
 fn read_from_binary<R>(mut reader: R, schema: &Schema) -> crate::error::Result<RecordBatch>
 where
-    R: Read + BufRead,
+    R: Read,
 {
     // read signature
     let mut bytes = [0u8; 11];
@@ -403,7 +392,7 @@
 /// Read row tuples
 fn read_rows<R>(mut reader: R, schema: &Schema) -> crate::error::Result<RecordBatch>
 where
-    R: Read + BufRead,
+    R: Read,
 {
     let mut is_done = false;
     let field_len = schema.fields().len();
@@ -450,7 +439,6 @@
             null_buffers[i].push(false);
         } else {
             null_buffers[i].push(true);
-            dbg!((schema.field(i), col_length));
             // big endian data, needs to be converted to little endian
             let mut data = read_col(
                 &mut reader,
@@ -695,14 +683,14 @@ fn read_f64<R: Read>(reader: &mut R) -> Result<Vec<u8>, ()> {
 fn read_date32<R: Read>(reader: &mut R) -> Result<Vec<u8>, ()> {
     reader
         .read_i32::<NetworkEndian>()
-        .map(|v| { 10957 + v }.to_le_bytes().to_vec())
+        .map(|v| { EPOCH_DAYS + v }.to_le_bytes().to_vec())
         .map_err(|e| eprintln!("Error: {:?}", e))
 }
 
 fn read_timestamp64<R: Read>(reader: &mut R) -> Result<Vec<u8>, ()> {
     reader
         .read_i64::<NetworkEndian>()
-        .map(|v| { 946684800000000 + v }.to_le_bytes().to_vec())
+        .map(|v| { EPOCH_MICROS + v }.to_le_bytes().to_vec())
        .map_err(|e| eprintln!("Error: {:?}", e))
 }

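Note: for reference, the PGCOPY stream that read_from_binary consumes is laid out as follows (per the PostgreSQL COPY documentation): the 11-byte MAGIC signature, a 32-bit flags field, a 32-bit header-extension length, then one tuple per row, where each tuple is a 16-bit field count followed by a 32-bit byte length and big-endian payload per field (length -1 marks NULL), terminated by a field count of -1. A self-contained sketch of the header handling only, assuming the byteorder types the reader already uses (the read_header helper is hypothetical):

use std::io::Read;

use byteorder::{NetworkEndian, ReadBytesExt};

// Duplicated from the new postgres module for self-containedness.
const MAGIC: &[u8] = b"PGCOPY\n\xff\r\n\0";

// Hypothetical helper: validates the PGCOPY signature and skips the flags
// and header-extension fields, leaving the reader at the first row tuple.
fn read_header<R: Read>(reader: &mut R) -> std::io::Result<()> {
    let mut signature = [0u8; 11];
    reader.read_exact(&mut signature)?;
    assert_eq!(&signature[..], MAGIC, "not a PGCOPY binary stream");
    let _flags = reader.read_i32::<NetworkEndian>()?;
    let ext_len = reader.read_i32::<NetworkEndian>()?;
    // skip the header extension area, if present
    let mut ext = vec![0u8; ext_len as usize];
    reader.read_exact(&mut ext)?;
    Ok(())
}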