Skip to content

Commit

Permalink
add node for CteWorkTableScan in datafusion-proto
Browse files Browse the repository at this point in the history
  • Loading branch information
leoyvens committed Nov 11, 2024
1 parent c4df917 commit d0401cf
Show file tree
Hide file tree
Showing 6 changed files with 178 additions and 54 deletions.
12 changes: 10 additions & 2 deletions datafusion/core/src/datasource/cte_worktable.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,6 @@ use crate::datasource::{TableProvider, TableType};
#[derive(Debug)]
pub struct CteWorkTable {
/// The name of the CTE work table
// WIP, see https://github.com/apache/datafusion/issues/462
#[allow(dead_code)]
name: String,
/// This schema must be shared across both the static and recursive terms of a recursive query
table_schema: SchemaRef,
Expand All @@ -56,6 +54,16 @@ impl CteWorkTable {
table_schema,
}
}

/// The user-provided name of the CTE
pub fn name(&self) -> &str {
&self.name
}

/// The schema of the recursive term of the query
pub fn schema(&self) -> SchemaRef {
self.table_schema.clone()
}
}

#[async_trait]
Expand Down
6 changes: 6 additions & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ message LogicalPlanNode {
CopyToNode copy_to = 29;
UnnestNode unnest = 30;
RecursiveQueryNode recursive_query = 31;
CteWorkTableScanNode cte_work_table_scan = 32;
}
}

Expand Down Expand Up @@ -1257,3 +1258,8 @@ message RecursiveQueryNode {
LogicalPlanNode recursive_term = 3;
bool is_distinct = 4;
}

message CteWorkTableScanNode {
string name = 1;
datafusion_common.Schema schema = 2;
}
122 changes: 122 additions & 0 deletions datafusion/proto/src/generated/pbjson.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 10 additions & 1 deletion datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 28 additions & 2 deletions datafusion/proto/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ use std::sync::Arc;

use crate::protobuf::logical_plan_node::LogicalPlanType::CustomScan;
use crate::protobuf::{
ColumnUnnestListItem, ColumnUnnestListRecursion, CustomTableScanNode,
SortExprNodeCollection,
ColumnUnnestListItem, ColumnUnnestListRecursion, CteWorkTableScanNode,
CustomTableScanNode, SortExprNodeCollection,
};
use crate::{
convert_required, into_required,
Expand All @@ -34,6 +34,7 @@ use crate::{

use crate::protobuf::{proto_error, ToProtoError};
use arrow::datatypes::{DataType, Schema, SchemaRef};
use datafusion::datasource::cte_worktable::CteWorkTable;
#[cfg(feature = "parquet")]
use datafusion::datasource::file_format::parquet::ParquetFormat;
use datafusion::datasource::file_format::{
Expand Down Expand Up @@ -931,6 +932,17 @@ impl AsLogicalPlan for LogicalPlanNode {
is_distinct: recursive_query_node.is_distinct,
}))
}
LogicalPlanType::CteWorkTableScan(cte_work_table_scan_node) => {
let CteWorkTableScanNode { name, schema } = cte_work_table_scan_node;
let schema = convert_required!(*schema)?;
let cte_work_table = CteWorkTable::new(name.as_str(), Arc::new(schema));
LogicalPlanBuilder::scan(
name.as_str(),
provider_as_source(Arc::new(cte_work_table)),
None,
)?
.build()
}
}
}

Expand Down Expand Up @@ -1087,6 +1099,20 @@ impl AsLogicalPlan for LogicalPlanNode {
},
))),
})
} else if let Some(cte_work_table) = source.downcast_ref::<CteWorkTable>()
{
let name = cte_work_table.name().to_string();
let schema = cte_work_table.schema();
let schema: protobuf::Schema = schema.as_ref().try_into()?;

Ok(LogicalPlanNode {
logical_plan_type: Some(LogicalPlanType::CteWorkTableScan(
protobuf::CteWorkTableScanNode {
name,
schema: Some(schema),
},
)),
})
} else {
let mut bytes = vec![];
extension_codec
Expand Down
51 changes: 2 additions & 49 deletions datafusion/proto/tests/cases/roundtrip_logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ use arrow::datatypes::{
};
use arrow::util::pretty::pretty_format_batches;
use datafusion::datasource::file_format::json::JsonFormatFactory;
use datafusion::datasource::MemTable;
use datafusion_common::parsers::CompressionTypeVariant;
use prost::Message;
use std::any::Any;
Expand Down Expand Up @@ -2528,48 +2527,6 @@ fn roundtrip_window() {

#[tokio::test]
async fn roundtrip_recursive_query() {
#[derive(Debug)]
pub struct EmptyTableCodec;

impl LogicalExtensionCodec for EmptyTableCodec {
fn try_decode(
&self,
_buf: &[u8],
_inputs: &[LogicalPlan],
_ctx: &SessionContext,
) -> Result<Extension, DataFusionError> {
not_impl_err!("No extension codec provided")
}

fn try_encode(
&self,
_node: &Extension,
_buf: &mut Vec<u8>,
) -> Result<(), DataFusionError> {
not_impl_err!("No extension codec provided")
}

fn try_decode_table_provider(
&self,
_buf: &[u8],
_table_ref: &TableReference,
schema: SchemaRef,
_ctx: &SessionContext,
) -> Result<Arc<dyn TableProvider>, DataFusionError> {
let table = MemTable::try_new(schema, vec![vec![]])?;
Ok(Arc::new(table))
}

fn try_encode_table_provider(
&self,
_table_ref: &TableReference,
_node: Arc<dyn TableProvider>,
_buf: &mut Vec<u8>,
) -> Result<(), DataFusionError> {
Ok(())
}
}

let query = "WITH RECURSIVE cte AS (
SELECT 1 as n
UNION ALL
Expand All @@ -2581,14 +2538,10 @@ async fn roundtrip_recursive_query() {
let dataframe = ctx.sql(query).await.unwrap();
let plan = dataframe.logical_plan().clone();
let output = dataframe.collect().await.unwrap();
let extension_codec = EmptyTableCodec {};
let bytes =
logical_plan_to_bytes_with_extension_codec(&plan, &extension_codec).unwrap();
let bytes = logical_plan_to_bytes(&plan).unwrap();

let ctx = SessionContext::new();
let logical_round_trip =
logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &extension_codec)
.unwrap();
let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx).unwrap();
assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}"));
let dataframe = ctx.execute_logical_plan(logical_round_trip).await.unwrap();
let output_round_trip = dataframe.collect().await.unwrap();
Expand Down

0 comments on commit d0401cf

Please sign in to comment.