Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add serde for plans with tables from TableProviderFactorys #3907

Merged
merged 10 commits into from
Oct 24, 2022
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions datafusion/core/src/datasource/datasource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,6 @@ pub trait TableProvider: Sync + Send {
/// from a directory of files only when that name is referenced.
#[async_trait]
pub trait TableProviderFactory: Sync + Send {
/// Create a TableProvider given name and url
async fn create(&self, name: &str, url: &str) -> Result<Arc<dyn TableProvider>>;
/// Create a TableProvider with the given url
async fn create(&self, url: &str) -> Result<Arc<dyn TableProvider>>;
}
7 changes: 3 additions & 4 deletions datafusion/core/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -418,19 +418,18 @@ impl SessionContext {
cmd: &CreateExternalTable,
) -> Result<Arc<DataFrame>> {
let state = self.state.read().clone();
let file_type = cmd.file_type.to_lowercase();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If there is an assumption that the file_types are always lower case in table_factories perhaps would it make sense to update the comment to that effect?

https://github.com/apache/arrow-datafusion/blob/6e0097d35391fea0d57c1d2ecfdef18437f681f4/datafusion/core/src/execution/runtime_env.rs#L48

let factory = &state
.runtime_env
.table_factories
.get(&cmd.file_type)
.get(file_type.as_str())
.ok_or_else(|| {
DataFusionError::Execution(format!(
"Unable to find factory for {}",
cmd.file_type
))
})?;
let table = (*factory)
.create(cmd.name.as_str(), cmd.location.as_str())
.await?;
let table = (*factory).create(cmd.location.as_str()).await?;
self.register_table(cmd.name.as_str(), table)?;
let plan = LogicalPlanBuilder::empty(false).build()?;
Ok(Arc::new(DataFrame::new(self.state.clone(), &plan)))
Expand Down
60 changes: 59 additions & 1 deletion datafusion/core/src/test_util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,19 @@

//! Utility functions to make testing DataFusion based crates easier

use std::any::Any;
use std::collections::BTreeMap;
use std::{env, error::Error, path::PathBuf, sync::Arc};

use crate::datasource::{empty::EmptyTable, provider_as_source};
use crate::datasource::datasource::TableProviderFactory;
use crate::datasource::{empty::EmptyTable, provider_as_source, TableProvider};
use crate::execution::context::SessionState;
use crate::logical_expr::{LogicalPlanBuilder, UNNAMED_TABLE};
use crate::physical_plan::ExecutionPlan;
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use async_trait::async_trait;
use datafusion_common::DataFusionError;
use datafusion_expr::{Expr, TableType};

/// Compares formatted output of a record batch with an expected
/// vector of strings, with the result of pretty formatting record
Expand Down Expand Up @@ -317,6 +323,58 @@ pub fn aggr_test_schema_with_missing_col() -> SchemaRef {
Arc::new(schema)
}

/// TableFactory for tests
pub struct TestTableFactory {}

#[async_trait]
impl TableProviderFactory for TestTableFactory {
async fn create(
&self,
url: &str,
) -> datafusion_common::Result<Arc<dyn TableProvider>> {
Ok(Arc::new(TestTableProvider {
url: url.to_string(),
}))
}
}

/// TableProvider for testing purposes
pub struct TestTableProvider {
/// URL of table files or folder
pub url: String,
}

impl TestTableProvider {}

#[async_trait]
impl TableProvider for TestTableProvider {
fn as_any(&self) -> &dyn Any {
self
}

fn schema(&self) -> SchemaRef {
let schema = Schema::new(vec![
Field::new("a", DataType::Int64, true),
Field::new("b", DataType::Decimal128(15, 2), true),
]);
Arc::new(schema)
}

fn table_type(&self) -> TableType {
unimplemented!("TestTableProvider is a stub for testing.")
}

async fn scan(
&self,
_ctx: &SessionState,
_projection: &Option<Vec<usize>>,
_filters: &[Expr],
_limit: Option<usize>,
) -> datafusion_common::Result<Arc<dyn ExecutionPlan>> {
unimplemented!("TestTableProvider is a stub for testing.")
}
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
45 changes: 2 additions & 43 deletions datafusion/core/tests/sql/create_drop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,12 @@
// specific language governing permissions and limitations
// under the License.

use async_trait::async_trait;
use std::any::Any;
use std::collections::HashMap;
use std::io::Write;

use datafusion::datasource::datasource::TableProviderFactory;
use datafusion::execution::context::SessionState;
use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
use datafusion_expr::TableType;
use datafusion::test_util::TestTableFactory;
use tempfile::TempDir;

use super::*;
Expand Down Expand Up @@ -369,49 +366,11 @@ async fn create_pipe_delimited_csv_table() -> Result<()> {
Ok(())
}

struct TestTableProvider {}

impl TestTableProvider {}

#[async_trait]
impl TableProvider for TestTableProvider {
fn as_any(&self) -> &dyn Any {
unimplemented!("TestTableProvider is a stub for testing.")
}

fn schema(&self) -> SchemaRef {
unimplemented!("TestTableProvider is a stub for testing.")
}

fn table_type(&self) -> TableType {
unimplemented!("TestTableProvider is a stub for testing.")
}

async fn scan(
&self,
_ctx: &SessionState,
_projection: &Option<Vec<usize>>,
_filters: &[Expr],
_limit: Option<usize>,
) -> Result<Arc<dyn ExecutionPlan>> {
unimplemented!("TestTableProvider is a stub for testing.")
}
}

struct TestTableFactory {}

#[async_trait]
impl TableProviderFactory for TestTableFactory {
async fn create(&self, _name: &str, _url: &str) -> Result<Arc<dyn TableProvider>> {
Ok(Arc::new(TestTableProvider {}))
}
}

#[tokio::test]
async fn create_custom_table() -> Result<()> {
let mut table_factories: HashMap<String, Arc<dyn TableProviderFactory>> =
HashMap::new();
table_factories.insert("DELTATABLE".to_string(), Arc::new(TestTableFactory {}));
table_factories.insert("deltatable".to_string(), Arc::new(TestTableFactory {}));
let cfg = RuntimeConfig::new().with_table_factories(table_factories);
let env = RuntimeEnv::new(cfg).unwrap();
let ses = SessionConfig::new();
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ json = ["pbjson", "pbjson-build", "serde", "serde_json"]

[dependencies]
arrow = "25.0.0"
async-trait = "0.1.41"
datafusion = { path = "../core", version = "13.0.0" }
datafusion-common = { path = "../common", version = "13.0.0" }
datafusion-expr = { path = "../expr", version = "13.0.0" }
Expand Down
2 changes: 1 addition & 1 deletion datafusion/proto/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ async fn main() -> Result<()> {
?;
let plan = ctx.table("t1")?.to_logical_plan()?;
let bytes = logical_plan_to_bytes(&plan)?;
let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?;
let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx).await?;
assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip));
Ok(())
}
Expand Down
2 changes: 1 addition & 1 deletion datafusion/proto/examples/plan_serde.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ async fn main() -> Result<()> {
.await?;
let plan = ctx.table("t1")?.to_logical_plan()?;
let bytes = logical_plan_to_bytes(&plan)?;
let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx)?;
let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx).await?;
assert_eq!(format!("{:?}", plan), format!("{:?}", logical_round_trip));
Ok(())
}
10 changes: 10 additions & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ message LogicalPlanNode {
CreateViewNode create_view = 22;
DistinctNode distinct = 23;
ViewTableScanNode view_scan = 24;
CustomTableScanNode custom_scan = 25;
}
}

Expand Down Expand Up @@ -118,6 +119,15 @@ message ViewTableScanNode {
string definition = 5;
}

// Logical Plan to Scan a CustomTableProvider registered at runtime
message CustomTableScanNode {
string table_name = 1;
ProjectionColumns projection = 2;
datafusion.Schema schema = 3;
repeated datafusion.LogicalExprNode filters = 4;
bytes custom_table_data = 5;
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The above 4 fields are unrelated to the TableProvider - they come from the scan. The ListingTableScan doesn't have this problem because it can happily combine the two logical units into one message, but unless we want each custom table type to reimplement this logic, we should probably keep it here.

}

message ProjectionNode {
LogicalPlanNode input = 1;
repeated datafusion.LogicalExprNode expr = 2;
Expand Down
47 changes: 38 additions & 9 deletions datafusion/proto/src/bytes/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,16 @@
//! Serialization / Deserialization to Bytes
use crate::logical_plan::{AsLogicalPlan, LogicalExtensionCodec};
use crate::{from_proto::parse_expr, protobuf};
use arrow::datatypes::SchemaRef;
use async_trait::async_trait;
use datafusion::datasource::TableProvider;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::{Expr, Extension, LogicalPlan};
use prost::{
bytes::{Bytes, BytesMut},
Message,
};
use std::sync::Arc;

// Reexport Bytes which appears in the API
use datafusion::execution::registry::FunctionRegistry;
Expand Down Expand Up @@ -132,37 +136,41 @@ pub fn logical_plan_to_bytes_with_extension_codec(

/// Deserialize a LogicalPlan from json
#[cfg(feature = "json")]
pub fn logical_plan_from_json(json: &str, ctx: &SessionContext) -> Result<LogicalPlan> {
pub async fn logical_plan_from_json(
json: &str,
ctx: &SessionContext,
) -> Result<LogicalPlan> {
let back: protobuf::LogicalPlanNode = serde_json::from_str(json)
.map_err(|e| DataFusionError::Plan(format!("Error serializing plan: {}", e)))?;
let extension_codec = DefaultExtensionCodec {};
back.try_into_logical_plan(ctx, &extension_codec)
back.try_into_logical_plan(ctx, &extension_codec).await
}

/// Deserialize a LogicalPlan from bytes
pub fn logical_plan_from_bytes(
pub async fn logical_plan_from_bytes(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It makes sense to me that these functions must become async in order to (potentially) instantiate a table provider 👍

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To follow up -- becomes async so that a table provider can be instantiated. This table provider, such as delta-rs, might have to do remote IO

bytes: &[u8],
ctx: &SessionContext,
) -> Result<LogicalPlan> {
let extension_codec = DefaultExtensionCodec {};
logical_plan_from_bytes_with_extension_codec(bytes, ctx, &extension_codec)
logical_plan_from_bytes_with_extension_codec(bytes, ctx, &extension_codec).await
}

/// Deserialize a LogicalPlan from bytes
pub fn logical_plan_from_bytes_with_extension_codec(
pub async fn logical_plan_from_bytes_with_extension_codec(
bytes: &[u8],
ctx: &SessionContext,
extension_codec: &dyn LogicalExtensionCodec,
) -> Result<LogicalPlan> {
let protobuf = protobuf::LogicalPlanNode::decode(bytes).map_err(|e| {
DataFusionError::Plan(format!("Error decoding expr as protobuf: {}", e))
})?;
protobuf.try_into_logical_plan(ctx, extension_codec)
protobuf.try_into_logical_plan(ctx, extension_codec).await
}

#[derive(Debug)]
struct DefaultExtensionCodec {}

#[async_trait]
impl LogicalExtensionCodec for DefaultExtensionCodec {
fn try_decode(
&self,
Expand All @@ -180,6 +188,27 @@ impl LogicalExtensionCodec for DefaultExtensionCodec {
"No extension codec provided".to_string(),
))
}

async fn try_decode_table_provider(
&self,
_buf: &[u8],
_schema: SchemaRef,
_ctx: &SessionContext,
) -> std::result::Result<Arc<dyn TableProvider>, DataFusionError> {
Err(DataFusionError::NotImplemented(
"No extension codec provided".to_string(),
avantgardnerio marked this conversation as resolved.
Show resolved Hide resolved
))
}

fn try_encode_table_provider(
&self,
_node: Arc<dyn TableProvider>,
_buf: &mut Vec<u8>,
) -> std::result::Result<(), DataFusionError> {
Err(DataFusionError::NotImplemented(
"No extension codec provided".to_string(),
avantgardnerio marked this conversation as resolved.
Show resolved Hide resolved
))
}
}

#[cfg(test)]
Expand Down Expand Up @@ -214,12 +243,12 @@ mod test {
assert_eq!(actual, expected);
}

#[test]
#[tokio::test]
#[cfg(feature = "json")]
fn json_to_plan() {
async fn json_to_plan() {
let input = r#"{"emptyRelation":{}}"#.to_string();
let ctx = SessionContext::new();
let actual = logical_plan_from_json(&input, &ctx).unwrap();
let actual = logical_plan_from_json(&input, &ctx).await.unwrap();
let result = matches!(actual, LogicalPlan::EmptyRelation(_));
assert!(result, "Should parse empty relation");
}
Expand Down
Loading