Skip to content

Commit

Permalink
chore: rust sdk testing and interface (#3195)
Browse files Browse the repository at this point in the history
  • Loading branch information
tychoish authored Sep 23, 2024
1 parent 972e6cc commit 1359ddd
Show file tree
Hide file tree
Showing 3 changed files with 141 additions and 4 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion crates/glaredb/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ workspace = true
[lib]
# no tests currently implemented in this package; skip, then
doctest = false
test = false

[dependencies]
sqlexec = { path = "../sqlexec" }
Expand All @@ -23,3 +22,6 @@ anyhow = { workspace = true }
thiserror = { workspace = true }
derive_builder = "0.20.1"
indexmap = "2.5.0"

[dev-dependencies]
tokio = { workspace = true }
140 changes: 137 additions & 3 deletions crates/glaredb/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ pub struct ConnectOptions {
/// embedded cases so that queries can bindings to extract tables
/// from variables in the binding's scope that data frames, or the
/// output of a query.
#[builder(default = "None")]
#[builder(setter(strip_option))]
pub environment_reader: Option<Arc<dyn EnvironmentReader>>,
}
Expand Down Expand Up @@ -405,6 +406,15 @@ impl RowMapBatch {
}
}


pub type RowStream<'a> = Pin<Box<dyn Stream<Item = Result<RowMap, DatabaseError>> + 'a>>;

impl<'a> From<&'a mut RecordStream> for RowStream<'a> {
fn from(val: &'a mut RecordStream) -> Self {
val.to_row_stream()
}
}

impl RecordStream {
// Collects all of the record batches in a stream, aborting if
// there are any errors.
Expand All @@ -415,9 +425,8 @@ impl RecordStream {

/// Collects all of the record batches and rotates the results for
/// a map-based row-oriented format.
pub async fn to_rows(&mut self) -> Result<Vec<RowMapBatch>, DatabaseError> {
let stream = &mut self.0;
stream.map(RowMapBatch::try_from).try_collect().await
pub async fn to_rows(&mut self) -> Result<Vec<RowMap>, DatabaseError> {
self.to_row_stream().try_collect().await
}

// Iterates through the stream, ensuring propagating any errors,
Expand All @@ -431,6 +440,21 @@ impl RecordStream {

Ok(())
}

pub fn to_row_stream(&mut self) -> RowStream<'_> {
let stream = &mut self.0;
stream::once(async move {
stream
.map(|v| RowMapBatch::try_from(v?).map_err(DatabaseError::from))
.map(|v| match v {
Ok(batch) => stream::iter(batch.iter().map(Ok)).boxed(),
Err(e) => stream::once(async move { Err(e) }).boxed(),
})
.flatten()
})
.flatten()
.boxed()
}
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -813,3 +837,113 @@ impl From<DatabaseError> for DataFusionError {
}
}
}

#[cfg(test)]
mod test {
use super::*;

async fn db() -> Result<Connection, DatabaseError> {
let conn = ConnectOptionsBuilder::new_in_memory()
.build()?
.connect()
.await?;

let _res = conn
.execute("CREATE TEMP TABLE test_fixture (title text, total int);")
.evaluate()
.await?
.call();

Ok(conn)
}

#[tokio::test]
async fn simple_query_check() {
db().await
.unwrap()
.sql("select 1;")
.evaluate()
.await
.unwrap()
.call()
.check()
.await
.unwrap();
}

#[tokio::test]
async fn simple_query_results() {
let mut op = db()
.await
.unwrap()
.sql("select 1;")
.evaluate()
.await
.unwrap()
.call();

let mut stream = op.to_row_stream();

let mut res = Vec::new();
while let Some(b) = stream.next().await {
res.push(b.unwrap());
}

assert_eq!(res.len(), 1);
let row = res.get(0).unwrap();
assert_eq!(row.len(), 1);
}

#[tokio::test]
async fn longer_local_query() {
let db = db().await.unwrap();
let results = db
.execute("INSERT INTO test_fixture VALUES ('a', 1), ('b', 42)")
.evaluate()
.await
.unwrap()
.call()
.to_rows()
.await
.unwrap();

assert_eq!(results.len(), 1);
let row = results.get(0).unwrap().to_owned();
assert_eq!(
row.get(&"count".to_string()).unwrap().to_owned(),
ScalarValue::UInt64(Some(2))
);

let results = db
.sql("SELECT * FROM test_fixture")
.evaluate()
.await
.unwrap()
.call()
.to_rows()
.await
.unwrap();

assert_eq!(results.len(), 2);
let first = results.get(0).unwrap().to_owned();
let second = results.get(1).unwrap().to_owned();

assert_eq!(
first.get(&"title".to_string()).unwrap().to_owned(),
ScalarValue::new_utf8("a")
);
assert_eq!(
first.get(&"total".to_string()).unwrap().to_owned(),
ScalarValue::Int32(Some(1))
);

assert_eq!(
second.get(&"title".to_string()).unwrap().to_owned(),
ScalarValue::new_utf8("b")
);
assert_eq!(
second.get(&"total".to_string()).unwrap().to_owned(),
ScalarValue::Int32(Some(42))
);
}
}

0 comments on commit 1359ddd

Please sign in to comment.