Skip to content

Commit

Permalink
Passing serde test
Browse files Browse the repository at this point in the history
  • Loading branch information
Brent Gardner committed Oct 20, 2022
1 parent a307430 commit 722a056
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 22 deletions.
21 changes: 6 additions & 15 deletions datafusion/core/src/test_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ use std::any::Any;
use std::collections::BTreeMap;
use std::{env, error::Error, path::PathBuf, sync::Arc};

use crate::datasource::custom::CustomTable;
use crate::datasource::datasource::TableProviderFactory;
use crate::datasource::{empty::EmptyTable, provider_as_source, TableProvider};
use crate::execution::context::SessionState;
Expand Down Expand Up @@ -331,39 +330,31 @@ pub struct TestTableFactory {}
impl TableProviderFactory for TestTableFactory {
async fn create(
&self,
_name: &str,
url: &str,
) -> datafusion_common::Result<Arc<dyn TableProvider>> {
Ok(Arc::new(CustomTable::new(
"deltatable",
url,
Arc::new(TestTableProvider {}),
)))
Ok(Arc::new(TestTableProvider { url: url.to_string() }))
}

fn with_schema(
&self,
_schema: SchemaRef,
table_type: &str,
url: &str,
) -> datafusion_common::Result<Arc<dyn TableProvider>> {
Ok(Arc::new(CustomTable::new(
table_type,
url,
Arc::new(TestTableProvider {}),
)))
Ok(Arc::new(TestTableProvider { url: url.to_string() }))
}
}

/// TableProvider for testing purposes
pub struct TestTableProvider {}
pub struct TestTableProvider {
pub url: String
}

impl TestTableProvider {}

#[async_trait]
impl TableProvider for TestTableProvider {
fn as_any(&self) -> &dyn Any {
unimplemented!("TestTableProvider is a stub for testing.")
self
}

fn schema(&self) -> SchemaRef {
Expand Down
78 changes: 71 additions & 7 deletions datafusion/proto/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ mod roundtrip_tests {
logical_plan_to_bytes, logical_plan_to_bytes_with_extension_codec,
};
use crate::logical_plan::LogicalExtensionCodec;
use arrow::datatypes::Schema;
use arrow::datatypes::{Schema, SchemaRef};
use arrow::{
array::ArrayRef,
datatypes::{
Expand All @@ -65,7 +65,7 @@ mod roundtrip_tests {
use datafusion::prelude::{
create_udf, CsvReadOptions, SessionConfig, SessionContext,
};
use datafusion::test_util::TestTableFactory;
use datafusion::test_util::{TestTableFactory, TestTableProvider};
use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue};
use datafusion_expr::create_udaf;
use datafusion_expr::expr::{Between, BinaryExpr, Case, GroupingSet, Like};
Expand All @@ -81,6 +81,7 @@ mod roundtrip_tests {
use std::fmt::Debug;
use std::fmt::Formatter;
use std::sync::Arc;
use datafusion::datasource::TableProvider;

#[cfg(feature = "json")]
fn roundtrip_json_test(proto: &protobuf::LogicalExprNode) {
Expand Down Expand Up @@ -135,22 +136,73 @@ mod roundtrip_tests {
Ok(())
}

#[derive(Clone, PartialEq, ::prost::Message)]
pub struct TestTableProto {
/// URL of the table root
#[prost(string, tag = "1")]
pub url: String,
}

#[derive(Debug)]
pub struct TestTableProviderCodec {}

impl LogicalExtensionCodec for TestTableProviderCodec {
fn try_decode(&self, buf: &[u8], inputs: &[LogicalPlan], ctx: &SessionContext) -> Result<Extension, DataFusionError> {
Err(DataFusionError::NotImplemented(
"No extension codec provided".to_string(),
))
}

fn try_encode(&self, node: &Extension, buf: &mut Vec<u8>) -> Result<(), DataFusionError> {
Err(DataFusionError::NotImplemented(
"No extension codec provided".to_string(),
))
}

fn try_decode_table_provider(&self, buf: &[u8], schema: SchemaRef, ctx: &SessionContext) -> Result<Arc<dyn TableProvider>, DataFusionError> {
let msg = TestTableProto::decode(buf)
.map_err(|_| DataFusionError::Internal("Error encoding test table".to_string()))?;
let state = ctx.state.read();
let factory = state
.runtime_env
.table_factories
.get("testtable")
.ok_or_else(|| {
DataFusionError::Plan(format!(
"Unable to find testtable factory",
))
})?;
let provider = (*factory).with_schema(schema, msg.url.as_str())?;
Ok(provider)
}

fn try_encode_table_provider(&self, node: Arc<dyn TableProvider>, buf: &mut Vec<u8>) -> Result<(), DataFusionError> {
let table = node.as_ref().as_any().downcast_ref::<TestTableProvider>()
.ok_or(DataFusionError::Internal("Can't encode non-test tables".to_string()))?;
let msg = TestTableProto {
url: table.url.clone()
};
msg.encode(buf).map_err(|_| DataFusionError::Internal("Error encoding test table".to_string()))
}
}

#[tokio::test]
async fn roundtrip_custom_tables() -> Result<(), DataFusionError> {
let mut table_factories: HashMap<String, Arc<dyn TableProviderFactory>> =
HashMap::new();
table_factories.insert("deltatable".to_string(), Arc::new(TestTableFactory {}));
table_factories.insert("testtable".to_string(), Arc::new(TestTableFactory {}));
let cfg = RuntimeConfig::new().with_table_factories(table_factories);
let env = RuntimeEnv::new(cfg).unwrap();
let ses = SessionConfig::new();
let ctx = SessionContext::with_config_rt(ses, Arc::new(env));

let sql = "CREATE EXTERNAL TABLE dt STORED AS DELTATABLE LOCATION 's3://bucket/schema/table';";
let sql = "CREATE EXTERNAL TABLE t STORED AS testtable LOCATION 's3://bucket/schema/table';";
ctx.sql(sql).await.unwrap();

let scan = ctx.table("dt")?.to_logical_plan()?;
let bytes = logical_plan_to_bytes(&scan)?;
let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?;
let codec = TestTableProviderCodec {};
let scan = ctx.table("t")?.to_logical_plan()?;
let bytes = logical_plan_to_bytes_with_extension_codec(&scan, &codec)?;
let logical_round_trip = logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)?;
assert_eq!(format!("{:?}", scan), format!("{:?}", logical_round_trip));
Ok(())
}
Expand Down Expand Up @@ -350,6 +402,18 @@ mod roundtrip_tests {
))
}
}

fn try_decode_table_provider(&self, buf: &[u8], schema: SchemaRef, ctx: &SessionContext) -> Result<Arc<dyn TableProvider>, DataFusionError> {
Err(DataFusionError::Internal(
"unsupported plan type".to_string(),
))
}

fn try_encode_table_provider(&self, node: Arc<dyn TableProvider>, buf: &mut Vec<u8>) -> Result<(), DataFusionError> {
Err(DataFusionError::Internal(
"unsupported plan type".to_string(),
))
}
}

#[test]
Expand Down

0 comments on commit 722a056

Please sign in to comment.