diff --git a/datafusion/core/src/test_util.rs b/datafusion/core/src/test_util.rs index 03adc33fbfd66..a60f318a54abe 100644 --- a/datafusion/core/src/test_util.rs +++ b/datafusion/core/src/test_util.rs @@ -21,7 +21,6 @@ use std::any::Any; use std::collections::BTreeMap; use std::{env, error::Error, path::PathBuf, sync::Arc}; -use crate::datasource::custom::CustomTable; use crate::datasource::datasource::TableProviderFactory; use crate::datasource::{empty::EmptyTable, provider_as_source, TableProvider}; use crate::execution::context::SessionState; @@ -331,39 +330,31 @@ pub struct TestTableFactory {} impl TableProviderFactory for TestTableFactory { async fn create( &self, - _name: &str, url: &str, ) -> datafusion_common::Result> { - Ok(Arc::new(CustomTable::new( - "deltatable", - url, - Arc::new(TestTableProvider {}), - ))) + Ok(Arc::new(TestTableProvider { url: url.to_string() })) } fn with_schema( &self, _schema: SchemaRef, - table_type: &str, url: &str, ) -> datafusion_common::Result> { - Ok(Arc::new(CustomTable::new( - table_type, - url, - Arc::new(TestTableProvider {}), - ))) + Ok(Arc::new(TestTableProvider { url: url.to_string() })) } } /// TableProvider for testing purposes -pub struct TestTableProvider {} +pub struct TestTableProvider { + pub url: String +} impl TestTableProvider {} #[async_trait] impl TableProvider for TestTableProvider { fn as_any(&self) -> &dyn Any { - unimplemented!("TestTableProvider is a stub for testing.") + self } fn schema(&self) -> SchemaRef { diff --git a/datafusion/proto/src/lib.rs b/datafusion/proto/src/lib.rs index 624c801730acb..92634bac436a6 100644 --- a/datafusion/proto/src/lib.rs +++ b/datafusion/proto/src/lib.rs @@ -51,7 +51,7 @@ mod roundtrip_tests { logical_plan_to_bytes, logical_plan_to_bytes_with_extension_codec, }; use crate::logical_plan::LogicalExtensionCodec; - use arrow::datatypes::Schema; + use arrow::datatypes::{Schema, SchemaRef}; use arrow::{ array::ArrayRef, datatypes::{ @@ -65,7 +65,7 @@ mod roundtrip_tests { use datafusion::prelude::{ create_udf, CsvReadOptions, SessionConfig, SessionContext, }; - use datafusion::test_util::TestTableFactory; + use datafusion::test_util::{TestTableFactory, TestTableProvider}; use datafusion_common::{DFSchemaRef, DataFusionError, ScalarValue}; use datafusion_expr::create_udaf; use datafusion_expr::expr::{Between, BinaryExpr, Case, GroupingSet, Like}; @@ -81,6 +81,7 @@ mod roundtrip_tests { use std::fmt::Debug; use std::fmt::Formatter; use std::sync::Arc; + use datafusion::datasource::TableProvider; #[cfg(feature = "json")] fn roundtrip_json_test(proto: &protobuf::LogicalExprNode) { @@ -135,22 +136,73 @@ mod roundtrip_tests { Ok(()) } + #[derive(Clone, PartialEq, ::prost::Message)] + pub struct TestTableProto { + /// URL of the table root + #[prost(string, tag = "1")] + pub url: String, + } + + #[derive(Debug)] + pub struct TestTableProviderCodec {} + + impl LogicalExtensionCodec for TestTableProviderCodec { + fn try_decode(&self, buf: &[u8], inputs: &[LogicalPlan], ctx: &SessionContext) -> Result { + Err(DataFusionError::NotImplemented( + "No extension codec provided".to_string(), + )) + } + + fn try_encode(&self, node: &Extension, buf: &mut Vec) -> Result<(), DataFusionError> { + Err(DataFusionError::NotImplemented( + "No extension codec provided".to_string(), + )) + } + + fn try_decode_table_provider(&self, buf: &[u8], schema: SchemaRef, ctx: &SessionContext) -> Result, DataFusionError> { + let msg = TestTableProto::decode(buf) + .map_err(|_| DataFusionError::Internal("Error encoding test table".to_string()))?; + let state = ctx.state.read(); + let factory = state + .runtime_env + .table_factories + .get("testtable") + .ok_or_else(|| { + DataFusionError::Plan(format!( + "Unable to find testtable factory", + )) + })?; + let provider = (*factory).with_schema(schema, msg.url.as_str())?; + Ok(provider) + } + + fn try_encode_table_provider(&self, node: Arc, buf: &mut Vec) -> Result<(), DataFusionError> { + let table = node.as_ref().as_any().downcast_ref::() + .ok_or(DataFusionError::Internal("Can't encode non-test tables".to_string()))?; + let msg = TestTableProto { + url: table.url.clone() + }; + msg.encode(buf).map_err(|_| DataFusionError::Internal("Error encoding test table".to_string())) + } + } + #[tokio::test] async fn roundtrip_custom_tables() -> Result<(), DataFusionError> { let mut table_factories: HashMap> = HashMap::new(); - table_factories.insert("deltatable".to_string(), Arc::new(TestTableFactory {})); + table_factories.insert("testtable".to_string(), Arc::new(TestTableFactory {})); let cfg = RuntimeConfig::new().with_table_factories(table_factories); let env = RuntimeEnv::new(cfg).unwrap(); let ses = SessionConfig::new(); let ctx = SessionContext::with_config_rt(ses, Arc::new(env)); - let sql = "CREATE EXTERNAL TABLE dt STORED AS DELTATABLE LOCATION 's3://bucket/schema/table';"; + let sql = "CREATE EXTERNAL TABLE t STORED AS testtable LOCATION 's3://bucket/schema/table';"; ctx.sql(sql).await.unwrap(); - let scan = ctx.table("dt")?.to_logical_plan()?; - let bytes = logical_plan_to_bytes(&scan)?; - let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?; + let codec = TestTableProviderCodec {}; + let scan = ctx.table("t")?.to_logical_plan()?; + let bytes = logical_plan_to_bytes_with_extension_codec(&scan, &codec)?; + let logical_round_trip = logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &codec)?; assert_eq!(format!("{:?}", scan), format!("{:?}", logical_round_trip)); Ok(()) } @@ -350,6 +402,18 @@ mod roundtrip_tests { )) } } + + fn try_decode_table_provider(&self, buf: &[u8], schema: SchemaRef, ctx: &SessionContext) -> Result, DataFusionError> { + Err(DataFusionError::Internal( + "unsupported plan type".to_string(), + )) + } + + fn try_encode_table_provider(&self, node: Arc, buf: &mut Vec) -> Result<(), DataFusionError> { + Err(DataFusionError::Internal( + "unsupported plan type".to_string(), + )) + } } #[test]