Skip to content

Commit a483a82

Browse files
committed
WIP - add rdbc-tokio-postgres as first impl of new traits
1 parent 2adb34a commit a483a82

File tree

4 files changed

+238
-22
lines changed

4 files changed

+238
-22
lines changed

Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ members = [
55
# "rdbc-mysql",
66
# "rdbc-postgres",
77
# "rdbc-sqlite",
8+
"rdbc-tokio-postgres",
89
# "rdbc-odbc",
910
"rdbc-cli",
1011
]

rdbc-tokio-postgres/Cargo.toml

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
[package]
2+
name = "rdbc-tokio-postgres"
3+
version = "0.1.0"
4+
authors = ["Ben Sully <ben@bsull.io>"]
5+
edition = "2018"
6+
7+
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
8+
9+
[dependencies]
10+
async-trait = "0.1.22"
11+
rdbc = { path = "../rdbc", version = "0.1.6" }
12+
sqlparser = "0.5.0"
13+
tokio = "0.2.10"
14+
tokio-postgres = "0.5.1"

rdbc-tokio-postgres/src/lib.rs

Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
use std::{pin::Pin, sync::Arc};
2+
3+
use async_trait::async_trait;
4+
use sqlparser::{
5+
dialect::PostgreSqlDialect,
6+
tokenizer::{Token, Tokenizer, Word},
7+
};
8+
use tokio::stream::Stream;
9+
use tokio_postgres::{types::Type, Client, NoTls, Row, Statement};
10+
11+
#[derive(Debug)]
12+
pub enum Error {
13+
TokioPostgres(tokio_postgres::Error),
14+
}
15+
16+
impl From<tokio_postgres::Error> for Error {
17+
fn from(other: tokio_postgres::Error) -> Self {
18+
Self::TokioPostgres(other)
19+
}
20+
}
21+
22+
pub struct TokioPostgresDriver;
23+
24+
#[async_trait]
25+
impl rdbc::Driver for TokioPostgresDriver {
26+
type Connection = TokioPostgresConnection;
27+
type Error = Error;
28+
29+
async fn connect(url: &str) -> Result<Self::Connection, Self::Error> {
30+
let (client, conn) = tokio_postgres::connect(url, NoTls).await?;
31+
tokio::spawn(conn);
32+
Ok(TokioPostgresConnection {
33+
inner: Arc::new(client),
34+
})
35+
}
36+
}
37+
38+
pub struct TokioPostgresConnection {
39+
inner: Arc<Client>,
40+
}
41+
42+
#[async_trait]
43+
impl rdbc::Connection for TokioPostgresConnection {
44+
type Statement = TokioPostgresStatement;
45+
type Error = Error;
46+
47+
async fn create(&mut self, sql: &str) -> Result<Self::Statement, Self::Error> {
48+
let sql = {
49+
let dialect = PostgreSqlDialect {};
50+
let mut tokenizer = Tokenizer::new(&dialect, sql);
51+
let tokens = tokenizer.tokenize().unwrap();
52+
let mut i = 0_usize;
53+
let tokens: Vec<Token> = tokens
54+
.iter()
55+
.map(|t| match t {
56+
Token::Char(c) if *c == '?' => {
57+
i += 1;
58+
Token::Word(Word {
59+
value: format!("${}", i),
60+
quote_style: None,
61+
keyword: "".to_owned(),
62+
})
63+
}
64+
_ => t.clone(),
65+
})
66+
.collect();
67+
tokens
68+
.iter()
69+
.map(|t| format!("{}", t))
70+
.collect::<Vec<String>>()
71+
.join("")
72+
};
73+
let statement = self.inner.prepare(&sql).await?;
74+
Ok(TokioPostgresStatement {
75+
client: Arc::clone(&self.inner),
76+
statement,
77+
})
78+
}
79+
80+
async fn prepare(&mut self, sql: &str) -> Result<Self::Statement, Self::Error> {
81+
self.create(sql).await
82+
}
83+
}
84+
85+
pub struct TokioPostgresStatement {
86+
client: Arc<Client>,
87+
statement: Statement,
88+
}
89+
90+
fn to_rdbc_type(ty: &Type) -> rdbc::DataType {
91+
match ty {
92+
&Type::BOOL => rdbc::DataType::Bool,
93+
&Type::CHAR => rdbc::DataType::Char,
94+
//TODO all types
95+
_ => rdbc::DataType::Utf8,
96+
}
97+
}
98+
99+
fn to_postgres_params(values: &[rdbc::Value]) -> Vec<Box<dyn tokio_postgres::types::ToSql + Sync>> {
100+
values
101+
.iter()
102+
.map(|v| match v {
103+
rdbc::Value::String(s) => {
104+
Box::new(s.clone()) as Box<dyn tokio_postgres::types::ToSql + Sync>
105+
}
106+
rdbc::Value::Int32(n) => Box::new(*n) as Box<dyn tokio_postgres::types::ToSql + Sync>,
107+
rdbc::Value::UInt32(n) => Box::new(*n) as Box<dyn tokio_postgres::types::ToSql + Sync>, //TODO all types
108+
})
109+
.collect()
110+
}
111+
112+
#[async_trait]
113+
impl rdbc::Statement for TokioPostgresStatement {
114+
type ResultSet = TokioPostgresResultSet;
115+
type Error = Error;
116+
async fn execute_query(
117+
&mut self,
118+
params: &[rdbc::Value],
119+
) -> Result<Self::ResultSet, Self::Error> {
120+
let params = to_postgres_params(params);
121+
let params: Vec<_> = params.into_iter().map(|p| p.as_ref()).collect();
122+
let rows = self
123+
.client
124+
.query(&self.statement, params.as_slice())
125+
.await?
126+
.into_iter()
127+
.map(|row| TokioPostgresRow { inner: row })
128+
.collect();
129+
let meta = self
130+
.statement
131+
.columns()
132+
.iter()
133+
.map(|c| rdbc::Column::new(c.name(), to_rdbc_type(c.type_())))
134+
.collect();
135+
Ok(TokioPostgresResultSet { rows, meta })
136+
}
137+
async fn execute_update(&mut self, params: &[rdbc::Value]) -> Result<u64, Self::Error> {
138+
todo!()
139+
}
140+
}
141+
142+
pub struct TokioPostgresResultSet {
143+
meta: Vec<rdbc::Column>,
144+
rows: Vec<TokioPostgresRow>,
145+
}
146+
147+
#[async_trait]
148+
impl rdbc::ResultSet for TokioPostgresResultSet {
149+
type MetaData = Vec<rdbc::Column>;
150+
type Row = TokioPostgresRow;
151+
type Error = Error;
152+
153+
fn meta_data(&self) -> Result<&Self::MetaData, Self::Error> {
154+
Ok(&self.meta)
155+
}
156+
157+
async fn batches(
158+
&mut self,
159+
) -> Result<Pin<Box<dyn Stream<Item = Vec<Self::Row>>>>, Self::Error> {
160+
let rows = std::mem::take(&mut self.rows);
161+
Ok(Box::pin(tokio::stream::once(rows)))
162+
}
163+
}
164+
165+
pub struct TokioPostgresRow {
166+
inner: Row,
167+
}
168+
169+
macro_rules! impl_resultset_fns {
170+
($($fn: ident -> $ty: ty),*) => {
171+
$(
172+
fn $fn(&self, i: u64) -> Result<Option<$ty>, Self::Error> {
173+
Some(self.inner.try_get((i - 1) as usize)).transpose().map_err(Into::into)
174+
}
175+
)*
176+
}
177+
}
178+
179+
impl rdbc::Row for TokioPostgresRow {
180+
type Error = Error;
181+
impl_resultset_fns! {
182+
get_i8 -> i8,
183+
get_i16 -> i16,
184+
get_i32 -> i32,
185+
get_i64 -> i64,
186+
get_f32 -> f32,
187+
get_f64 -> f64,
188+
get_string -> String,
189+
get_bytes -> Vec<u8>
190+
}
191+
}

rdbc/src/lib.rs

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -49,42 +49,48 @@ impl ToString for Value {
4949
}
5050
}
5151

52-
/// RDBC Result type
53-
pub type Result<T> = std::result::Result<T, Error>;
54-
5552
/// Represents database driver that can be shared between threads, and can therefore implement
5653
/// a connection pool
54+
#[async_trait]
5755
pub trait Driver: Sync + Send {
5856
/// The type of connection created by this driver.
5957
type Connection: Connection;
6058

59+
type Error;
60+
6161
/// Create a connection to the database. Note that connections are intended to be used
6262
/// in a single thread since most database connections are not thread-safe
63-
fn connect(url: &str) -> Result<Self::Connection>;
63+
async fn connect(url: &str) -> Result<Self::Connection, Self::Error>;
6464
}
6565

6666
/// Represents a connection to a database
67+
#[async_trait]
6768
pub trait Connection {
6869
/// The type of statement produced by this connection.
6970
type Statement: Statement;
7071

72+
type Error;
73+
7174
/// Create a statement for execution
72-
fn create(&mut self, sql: &str) -> Result<Self::Statement>;
75+
async fn create(&mut self, sql: &str) -> Result<Self::Statement, Self::Error>;
7376

7477
/// Create a prepared statement for execution
75-
fn prepare(&mut self, sql: &str) -> Result<Self::Statement>;
78+
async fn prepare(&mut self, sql: &str) -> Result<Self::Statement, Self::Error>;
7679
}
7780

7881
/// Represents an executable statement
82+
#[async_trait]
7983
pub trait Statement {
8084
/// The type of ResultSet returned by this statement.
8185
type ResultSet: ResultSet;
8286

87+
type Error;
88+
8389
/// Execute a query that is expected to return a result set, such as a `SELECT` statement
84-
fn execute_query(&mut self, params: &[Value]) -> Result<Self::ResultSet>;
90+
async fn execute_query(&mut self, params: &[Value]) -> Result<Self::ResultSet, Self::Error>;
8591

8692
/// Execute a query that is expected to update some rows.
87-
fn execute_update(&mut self, params: &[Value]) -> Result<u64>;
93+
async fn execute_update(&mut self, params: &[Value]) -> Result<u64, Self::Error>;
8894
}
8995

9096
/// Result set from executing a query against a statement
@@ -95,36 +101,40 @@ pub trait ResultSet {
95101
/// The type of row included in this result set.
96102
type Row: Row;
97103

104+
type Error;
105+
98106
/// get meta data about this result set
99-
fn meta_data(&self) -> Result<Self::MetaData>;
107+
fn meta_data(&self) -> Result<&Self::MetaData, Self::Error>;
100108

101109
/// Get a stream where each item is a batch of rows.
102-
async fn batches(&mut self) -> Result<Pin<Box<dyn Stream<Item = Vec<Self::Row>>>>>;
110+
async fn batches(&mut self)
111+
-> Result<Pin<Box<dyn Stream<Item = Vec<Self::Row>>>>, Self::Error>;
103112

104113
/// Get a stream of rows.
105114
///
106115
/// Note that the rows are actually returned from the database in batches;
107116
/// this just flattens the batches to provide a (possibly) simpler API.
108-
async fn rows<'a>(&'a mut self) -> Result<Box<dyn Stream<Item = Self::Row> + 'a>> {
117+
async fn rows<'a>(&'a mut self) -> Result<Box<dyn Stream<Item = Self::Row> + 'a>, Self::Error> {
109118
Ok(Box::new(self.batches().await?.map(iter).flatten()))
110119
}
111120
}
112121

113122
pub trait Row {
114-
fn get_i8(&self, i: u64) -> Result<Option<i8>>;
115-
fn get_i16(&self, i: u64) -> Result<Option<i16>>;
116-
fn get_i32(&self, i: u64) -> Result<Option<i32>>;
117-
fn get_i64(&self, i: u64) -> Result<Option<i64>>;
118-
fn get_f32(&self, i: u64) -> Result<Option<f32>>;
119-
fn get_f64(&self, i: u64) -> Result<Option<f64>>;
120-
fn get_string(&self, i: u64) -> Result<Option<String>>;
121-
fn get_bytes(&self, i: u64) -> Result<Option<Vec<u8>>>;
123+
type Error;
124+
fn get_i8(&self, i: u64) -> Result<Option<i8>, Self::Error>;
125+
fn get_i16(&self, i: u64) -> Result<Option<i16>, Self::Error>;
126+
fn get_i32(&self, i: u64) -> Result<Option<i32>, Self::Error>;
127+
fn get_i64(&self, i: u64) -> Result<Option<i64>, Self::Error>;
128+
fn get_f32(&self, i: u64) -> Result<Option<f32>, Self::Error>;
129+
fn get_f64(&self, i: u64) -> Result<Option<f64>, Self::Error>;
130+
fn get_string(&self, i: u64) -> Result<Option<String>, Self::Error>;
131+
fn get_bytes(&self, i: u64) -> Result<Option<Vec<u8>>, Self::Error>;
122132
}
123133

124134
/// Meta data for result set
125135
pub trait MetaData {
126136
fn num_columns(&self) -> u64;
127-
fn column_name(&self, i: u64) -> String;
137+
fn column_name(&self, i: u64) -> &str;
128138
fn column_type(&self, i: u64) -> DataType;
129139
}
130140

@@ -166,8 +176,8 @@ impl MetaData for Vec<Column> {
166176
self.len() as u64
167177
}
168178

169-
fn column_name(&self, i: u64) -> String {
170-
self[i as usize].name.clone()
179+
fn column_name(&self, i: u64) -> &str {
180+
&self[i as usize].name
171181
}
172182

173183
fn column_type(&self, i: u64) -> DataType {

0 commit comments

Comments
 (0)