diff --git a/datafusion/core/src/datasource/cte_worktable.rs b/datafusion/core/src/datasource/cte_worktable.rs index 23f57b12ae08..371c635c6e79 100644 --- a/datafusion/core/src/datasource/cte_worktable.rs +++ b/datafusion/core/src/datasource/cte_worktable.rs @@ -39,8 +39,6 @@ use crate::datasource::{TableProvider, TableType}; #[derive(Debug)] pub struct CteWorkTable { /// The name of the CTE work table - // WIP, see https://github.com/apache/datafusion/issues/462 - #[allow(dead_code)] name: String, /// This schema must be shared across both the static and recursive terms of a recursive query table_schema: SchemaRef, @@ -56,6 +54,16 @@ impl CteWorkTable { table_schema, } } + + /// The user-provided name of the CTE + pub fn name(&self) -> &str { + &self.name + } + + /// The schema of the recursive term of the query + pub fn schema(&self) -> SchemaRef { + self.table_schema.clone() + } } #[async_trait] diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 52b91ba377ec..b9a1cff94d05 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -60,6 +60,7 @@ message LogicalPlanNode { CopyToNode copy_to = 29; UnnestNode unnest = 30; RecursiveQueryNode recursive_query = 31; + CteWorkTableScanNode cte_work_table_scan = 32; } } @@ -1257,3 +1258,8 @@ message RecursiveQueryNode { LogicalPlanNode recursive_term = 3; bool is_distinct = 4; } + +message CteWorkTableScanNode { + string name = 1; + datafusion_common.Schema schema = 2; +} diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 7b454a32628b..52ba1ea8aa79 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -4153,6 +4153,114 @@ impl<'de> serde::Deserialize<'de> for CsvSinkExecNode { deserializer.deserialize_struct("datafusion.CsvSinkExecNode", FIELDS, GeneratedVisitor) } } +impl serde::Serialize for CteWorkTableScanNode { + #[allow(deprecated)] + fn serialize(&self, serializer: S) -> std::result::Result + where + S: serde::Serializer, + { + use serde::ser::SerializeStruct; + let mut len = 0; + if !self.name.is_empty() { + len += 1; + } + if self.schema.is_some() { + len += 1; + } + let mut struct_ser = serializer.serialize_struct("datafusion.CteWorkTableScanNode", len)?; + if !self.name.is_empty() { + struct_ser.serialize_field("name", &self.name)?; + } + if let Some(v) = self.schema.as_ref() { + struct_ser.serialize_field("schema", v)?; + } + struct_ser.end() + } +} +impl<'de> serde::Deserialize<'de> for CteWorkTableScanNode { + #[allow(deprecated)] + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + const FIELDS: &[&str] = &[ + "name", + "schema", + ]; + + #[allow(clippy::enum_variant_names)] + enum GeneratedField { + Name, + Schema, + } + impl<'de> serde::Deserialize<'de> for GeneratedField { + fn deserialize(deserializer: D) -> std::result::Result + where + D: serde::Deserializer<'de>, + { + struct GeneratedVisitor; + + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = GeneratedField; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "expected one of: {:?}", &FIELDS) + } + + #[allow(unused_variables)] + fn visit_str(self, value: &str) -> std::result::Result + where + E: serde::de::Error, + { + match value { + "name" => Ok(GeneratedField::Name), + "schema" => Ok(GeneratedField::Schema), + _ => Err(serde::de::Error::unknown_field(value, FIELDS)), + } + } + } + deserializer.deserialize_identifier(GeneratedVisitor) + } + } + struct GeneratedVisitor; + impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { + type Value = CteWorkTableScanNode; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct datafusion.CteWorkTableScanNode") + } + + fn visit_map(self, mut map_: V) -> std::result::Result + where + V: serde::de::MapAccess<'de>, + { + let mut name__ = None; + let mut schema__ = None; + while let Some(k) = map_.next_key()? { + match k { + GeneratedField::Name => { + if name__.is_some() { + return Err(serde::de::Error::duplicate_field("name")); + } + name__ = Some(map_.next_value()?); + } + GeneratedField::Schema => { + if schema__.is_some() { + return Err(serde::de::Error::duplicate_field("schema")); + } + schema__ = map_.next_value()?; + } + } + } + Ok(CteWorkTableScanNode { + name: name__.unwrap_or_default(), + schema: schema__, + }) + } + } + deserializer.deserialize_struct("datafusion.CteWorkTableScanNode", FIELDS, GeneratedVisitor) + } +} impl serde::Serialize for CubeNode { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -10605,6 +10713,9 @@ impl serde::Serialize for LogicalPlanNode { logical_plan_node::LogicalPlanType::RecursiveQuery(v) => { struct_ser.serialize_field("recursiveQuery", v)?; } + logical_plan_node::LogicalPlanType::CteWorkTableScan(v) => { + struct_ser.serialize_field("cteWorkTableScan", v)?; + } } } struct_ser.end() @@ -10661,6 +10772,8 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { "unnest", "recursive_query", "recursiveQuery", + "cte_work_table_scan", + "cteWorkTableScan", ]; #[allow(clippy::enum_variant_names)] @@ -10695,6 +10808,7 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { CopyTo, Unnest, RecursiveQuery, + CteWorkTableScan, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -10746,6 +10860,7 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { "copyTo" | "copy_to" => Ok(GeneratedField::CopyTo), "unnest" => Ok(GeneratedField::Unnest), "recursiveQuery" | "recursive_query" => Ok(GeneratedField::RecursiveQuery), + "cteWorkTableScan" | "cte_work_table_scan" => Ok(GeneratedField::CteWorkTableScan), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -10976,6 +11091,13 @@ impl<'de> serde::Deserialize<'de> for LogicalPlanNode { return Err(serde::de::Error::duplicate_field("recursiveQuery")); } logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::RecursiveQuery) +; + } + GeneratedField::CteWorkTableScan => { + if logical_plan_type__.is_some() { + return Err(serde::de::Error::duplicate_field("cteWorkTableScan")); + } + logical_plan_type__ = map_.next_value::<::std::option::Option<_>>()?.map(logical_plan_node::LogicalPlanType::CteWorkTableScan) ; } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 9ca4836fe4c0..c7f5606049c0 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -5,7 +5,7 @@ pub struct LogicalPlanNode { #[prost( oneof = "logical_plan_node::LogicalPlanType", - tags = "1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31" + tags = "1, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32" )] pub logical_plan_type: ::core::option::Option, } @@ -73,6 +73,8 @@ pub mod logical_plan_node { Unnest(::prost::alloc::boxed::Box), #[prost(message, tag = "31")] RecursiveQuery(::prost::alloc::boxed::Box), + #[prost(message, tag = "32")] + CteWorkTableScan(super::CteWorkTableScanNode), } } #[derive(Clone, PartialEq, ::prost::Message)] @@ -1826,6 +1828,13 @@ pub struct RecursiveQueryNode { #[prost(bool, tag = "4")] pub is_distinct: bool, } +#[derive(Clone, PartialEq, ::prost::Message)] +pub struct CteWorkTableScanNode { + #[prost(string, tag = "1")] + pub name: ::prost::alloc::string::String, + #[prost(message, optional, tag = "2")] + pub schema: ::core::option::Option, +} #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum BuiltInWindowFunction { diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index a89b61684141..50636048ebc9 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -21,8 +21,8 @@ use std::sync::Arc; use crate::protobuf::logical_plan_node::LogicalPlanType::CustomScan; use crate::protobuf::{ - ColumnUnnestListItem, ColumnUnnestListRecursion, CustomTableScanNode, - SortExprNodeCollection, + ColumnUnnestListItem, ColumnUnnestListRecursion, CteWorkTableScanNode, + CustomTableScanNode, SortExprNodeCollection, }; use crate::{ convert_required, into_required, @@ -34,6 +34,7 @@ use crate::{ use crate::protobuf::{proto_error, ToProtoError}; use arrow::datatypes::{DataType, Schema, SchemaRef}; +use datafusion::datasource::cte_worktable::CteWorkTable; #[cfg(feature = "parquet")] use datafusion::datasource::file_format::parquet::ParquetFormat; use datafusion::datasource::file_format::{ @@ -931,6 +932,17 @@ impl AsLogicalPlan for LogicalPlanNode { is_distinct: recursive_query_node.is_distinct, })) } + LogicalPlanType::CteWorkTableScan(cte_work_table_scan_node) => { + let CteWorkTableScanNode { name, schema } = cte_work_table_scan_node; + let schema = convert_required!(*schema)?; + let cte_work_table = CteWorkTable::new(name.as_str(), Arc::new(schema)); + LogicalPlanBuilder::scan( + name.as_str(), + provider_as_source(Arc::new(cte_work_table)), + None, + )? + .build() + } } } @@ -1087,6 +1099,20 @@ impl AsLogicalPlan for LogicalPlanNode { }, ))), }) + } else if let Some(cte_work_table) = source.downcast_ref::() + { + let name = cte_work_table.name().to_string(); + let schema = cte_work_table.schema(); + let schema: protobuf::Schema = schema.as_ref().try_into()?; + + Ok(LogicalPlanNode { + logical_plan_type: Some(LogicalPlanType::CteWorkTableScan( + protobuf::CteWorkTableScanNode { + name, + schema: Some(schema), + }, + )), + }) } else { let mut bytes = vec![]; extension_codec diff --git a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs index ca391cd33d4a..8445cdc761ed 100644 --- a/datafusion/proto/tests/cases/roundtrip_logical_plan.rs +++ b/datafusion/proto/tests/cases/roundtrip_logical_plan.rs @@ -25,7 +25,6 @@ use arrow::datatypes::{ }; use arrow::util::pretty::pretty_format_batches; use datafusion::datasource::file_format::json::JsonFormatFactory; -use datafusion::datasource::MemTable; use datafusion_common::parsers::CompressionTypeVariant; use prost::Message; use std::any::Any; @@ -2528,48 +2527,6 @@ fn roundtrip_window() { #[tokio::test] async fn roundtrip_recursive_query() { - #[derive(Debug)] - pub struct EmptyTableCodec; - - impl LogicalExtensionCodec for EmptyTableCodec { - fn try_decode( - &self, - _buf: &[u8], - _inputs: &[LogicalPlan], - _ctx: &SessionContext, - ) -> Result { - not_impl_err!("No extension codec provided") - } - - fn try_encode( - &self, - _node: &Extension, - _buf: &mut Vec, - ) -> Result<(), DataFusionError> { - not_impl_err!("No extension codec provided") - } - - fn try_decode_table_provider( - &self, - _buf: &[u8], - _table_ref: &TableReference, - schema: SchemaRef, - _ctx: &SessionContext, - ) -> Result, DataFusionError> { - let table = MemTable::try_new(schema, vec![vec![]])?; - Ok(Arc::new(table)) - } - - fn try_encode_table_provider( - &self, - _table_ref: &TableReference, - _node: Arc, - _buf: &mut Vec, - ) -> Result<(), DataFusionError> { - Ok(()) - } - } - let query = "WITH RECURSIVE cte AS ( SELECT 1 as n UNION ALL @@ -2581,14 +2538,10 @@ async fn roundtrip_recursive_query() { let dataframe = ctx.sql(query).await.unwrap(); let plan = dataframe.logical_plan().clone(); let output = dataframe.collect().await.unwrap(); - let extension_codec = EmptyTableCodec {}; - let bytes = - logical_plan_to_bytes_with_extension_codec(&plan, &extension_codec).unwrap(); + let bytes = logical_plan_to_bytes(&plan).unwrap(); let ctx = SessionContext::new(); - let logical_round_trip = - logical_plan_from_bytes_with_extension_codec(&bytes, &ctx, &extension_codec) - .unwrap(); + let logical_round_trip = logical_plan_from_bytes(&bytes, &ctx).unwrap(); assert_eq!(format!("{plan:?}"), format!("{logical_round_trip:?}")); let dataframe = ctx.execute_logical_plan(logical_round_trip).await.unwrap(); let output_round_trip = dataframe.collect().await.unwrap();