From 531d541c540f2bdcd95fdae878b5459e8ee68b2b Mon Sep 17 00:00:00 2001 From: Matthew Gapp <61894094+matthewgapp@users.noreply.github.com> Date: Wed, 10 Jan 2024 16:45:42 -0800 Subject: [PATCH] add sql -> logical plan support * impl cte as work table * move SharedState to continuance * impl WorkTableState wip: readying pr to implement only logical plan --- datafusion-cli/Cargo.lock | 2 + datafusion/common/src/dfschema.rs | 5 + datafusion/core/src/datasource/cte.rs | 89 +++++++++++ datafusion/core/src/datasource/mod.rs | 1 + datafusion/core/src/execution/context/mod.rs | 10 ++ datafusion/core/src/physical_planner.rs | 7 +- datafusion/expr/src/logical_plan/builder.rs | 23 +++ datafusion/expr/src/logical_plan/mod.rs | 4 +- datafusion/expr/src/logical_plan/plan.rs | 46 ++++++ .../optimizer/src/common_subexpr_eliminate.rs | 1 + .../optimizer/src/optimize_projections.rs | 1 + datafusion/proto/src/logical_plan/mod.rs | 3 + datafusion/sql/src/planner.rs | 9 ++ datafusion/sql/src/query.rs | 149 +++++++++++++++--- datafusion/sql/tests/sql_integration.rs | 2 +- 15 files changed, 323 insertions(+), 29 deletions(-) create mode 100644 datafusion/core/src/datasource/cte.rs diff --git a/datafusion-cli/Cargo.lock b/datafusion-cli/Cargo.lock index 252b00ca0adc..6f1c934c0855 100644 --- a/datafusion-cli/Cargo.lock +++ b/datafusion-cli/Cargo.lock @@ -1208,6 +1208,7 @@ dependencies = [ "parking_lot", "rand", "tempfile", + "tokio", "url", ] @@ -1299,6 +1300,7 @@ dependencies = [ "pin-project-lite", "rand", "tokio", + "tokio-stream", "uuid", ] diff --git a/datafusion/common/src/dfschema.rs b/datafusion/common/src/dfschema.rs index 85b97aac037d..a33973790c60 100644 --- a/datafusion/common/src/dfschema.rs +++ b/datafusion/common/src/dfschema.rs @@ -915,6 +915,11 @@ impl DFField { self.field = f.into(); self } + + pub fn with_qualifier(mut self, qualifier: impl Into) -> Self { + self.qualifier = Some(qualifier.into()); + self + } } impl From for DFField { diff --git a/datafusion/core/src/datasource/cte.rs b/datafusion/core/src/datasource/cte.rs new file mode 100644 index 000000000000..9fb241f49db3 --- /dev/null +++ b/datafusion/core/src/datasource/cte.rs @@ -0,0 +1,89 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! CteWorkTable implementation used for recursive queries + +use std::any::Any; +use std::sync::Arc; + +use arrow::datatypes::SchemaRef; +use async_trait::async_trait; +use datafusion_common::not_impl_err; + +use crate::{ + error::Result, + logical_expr::{Expr, LogicalPlan, TableProviderFilterPushDown}, + physical_plan::ExecutionPlan, +}; + +use datafusion_common::DataFusionError; + +use crate::datasource::{TableProvider, TableType}; +use crate::execution::context::SessionState; + +/// TODO: add docs +pub struct CteWorkTable { + name: String, + table_schema: SchemaRef, +} + +impl CteWorkTable { + /// TODO: add doc + pub fn new(name: &str, table_schema: SchemaRef) -> Self { + Self { + name: name.to_owned(), + table_schema, + } + } +} + +#[async_trait] +impl TableProvider for CteWorkTable { + fn as_any(&self) -> &dyn Any { + self + } + + fn get_logical_plan(&self) -> Option<&LogicalPlan> { + None + } + + fn schema(&self) -> SchemaRef { + self.table_schema.clone() + } + + fn table_type(&self) -> TableType { + TableType::Temporary + } + + async fn scan( + &self, + _state: &SessionState, + _projection: Option<&Vec>, + _filters: &[Expr], + _limit: Option, + ) -> Result> { + not_impl_err!("scan not implemented for CteWorkTable yet") + } + + fn supports_filter_pushdown( + &self, + _filter: &Expr, + ) -> Result { + // TODO: should we support filter pushdown? + Ok(TableProviderFilterPushDown::Unsupported) + } +} diff --git a/datafusion/core/src/datasource/mod.rs b/datafusion/core/src/datasource/mod.rs index 2e516cc36a01..93f197ec9438 100644 --- a/datafusion/core/src/datasource/mod.rs +++ b/datafusion/core/src/datasource/mod.rs @@ -20,6 +20,7 @@ //! [`ListingTable`]: crate::datasource::listing::ListingTable pub mod avro_to_arrow; +pub mod cte; pub mod default_table_source; pub mod empty; pub mod file_format; diff --git a/datafusion/core/src/execution/context/mod.rs b/datafusion/core/src/execution/context/mod.rs index d6b7f046f3e3..221faa3019ab 100644 --- a/datafusion/core/src/execution/context/mod.rs +++ b/datafusion/core/src/execution/context/mod.rs @@ -26,6 +26,7 @@ mod parquet; use crate::{ catalog::{CatalogList, MemoryCatalogList}, datasource::{ + cte::CteWorkTable, function::{TableFunction, TableFunctionImpl}, listing::{ListingOptions, ListingTable}, provider::TableProviderFactory, @@ -1899,6 +1900,15 @@ impl<'a> ContextProvider for SessionContextProvider<'a> { Ok(provider_as_source(provider)) } + fn create_cte_work_table( + &self, + name: &str, + schema: SchemaRef, + ) -> Result> { + let table = Arc::new(CteWorkTable::new(name, schema)); + Ok(provider_as_source(table)) + } + fn get_function_meta(&self, name: &str) -> Option> { self.state.scalar_functions().get(name).cloned() } diff --git a/datafusion/core/src/physical_planner.rs b/datafusion/core/src/physical_planner.rs index d696c55a8c13..46c18b3c7b33 100644 --- a/datafusion/core/src/physical_planner.rs +++ b/datafusion/core/src/physical_planner.rs @@ -87,8 +87,8 @@ use datafusion_expr::expr::{ use datafusion_expr::expr_rewriter::unnormalize_cols; use datafusion_expr::logical_plan::builder::wrap_projection_for_join_if_necessary; use datafusion_expr::{ - DescribeTable, DmlStatement, ScalarFunctionDefinition, StringifiedPlan, WindowFrame, - WindowFrameBound, WriteOp, + DescribeTable, DmlStatement, RecursiveQuery, ScalarFunctionDefinition, + StringifiedPlan, WindowFrame, WindowFrameBound, WriteOp, }; use datafusion_physical_expr::expressions::Literal; use datafusion_physical_plan::placeholder_row::PlaceholderRowExec; @@ -1311,6 +1311,9 @@ impl DefaultPhysicalPlanner { Ok(plan) } } + LogicalPlan::RecursiveQuery(RecursiveQuery { name: _, static_term: _, recursive_term: _, is_distinct: _,.. }) => { + not_impl_err!("Physical counterpart of RecursiveQuery is not implemented yet") + } }; exec_plan }.boxed() diff --git a/datafusion/expr/src/logical_plan/builder.rs b/datafusion/expr/src/logical_plan/builder.rs index 847fbbbf61c7..ef70d30f77b4 100644 --- a/datafusion/expr/src/logical_plan/builder.rs +++ b/datafusion/expr/src/logical_plan/builder.rs @@ -55,6 +55,8 @@ use datafusion_common::{ ScalarValue, TableReference, ToDFSchema, UnnestOptions, }; +use super::plan::RecursiveQuery; + /// Default table name for unnamed table pub const UNNAMED_TABLE: &str = "?table?"; @@ -121,6 +123,27 @@ impl LogicalPlanBuilder { })) } + /// Convert a regular plan into a recursive query. + pub fn to_recursive_query( + &self, + name: String, + recursive_term: LogicalPlan, + is_distinct: bool, + ) -> Result { + // TODO: we need to do a bunch of validation here. Maybe more. + if is_distinct { + return Err(DataFusionError::NotImplemented( + "Recursive queries with distinct is not supported".to_string(), + )); + } + Ok(Self::from(LogicalPlan::RecursiveQuery(RecursiveQuery { + name, + static_term: Arc::new(self.plan.clone()), + recursive_term: Arc::new(recursive_term), + is_distinct, + }))) + } + /// Create a values list based relation, and the schema is inferred from data, consuming /// `value`. See the [Postgres VALUES](https://www.postgresql.org/docs/current/queries-values.html) /// documentation for more details. diff --git a/datafusion/expr/src/logical_plan/mod.rs b/datafusion/expr/src/logical_plan/mod.rs index bc722dd69ace..f6e6000897a5 100644 --- a/datafusion/expr/src/logical_plan/mod.rs +++ b/datafusion/expr/src/logical_plan/mod.rs @@ -36,8 +36,8 @@ pub use plan::{ projection_schema, Aggregate, Analyze, CrossJoin, DescribeTable, Distinct, DistinctOn, EmptyRelation, Explain, Extension, Filter, Join, JoinConstraint, JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, Projection, - Repartition, Sort, StringifiedPlan, Subquery, SubqueryAlias, TableScan, - ToStringifiedPlan, Union, Unnest, Values, Window, + RecursiveQuery, Repartition, Sort, StringifiedPlan, Subquery, SubqueryAlias, + TableScan, ToStringifiedPlan, Union, Unnest, Values, Window, }; pub use statement::{ SetVariable, Statement, TransactionAccessMode, TransactionConclusion, TransactionEnd, diff --git a/datafusion/expr/src/logical_plan/plan.rs b/datafusion/expr/src/logical_plan/plan.rs index 93a38fb40df5..1cf325caa90c 100644 --- a/datafusion/expr/src/logical_plan/plan.rs +++ b/datafusion/expr/src/logical_plan/plan.rs @@ -154,6 +154,8 @@ pub enum LogicalPlan { /// Unnest a column that contains a nested list type such as an /// ARRAY. This is used to implement SQL `UNNEST` Unnest(Unnest), + /// A variadic query (e.g. "Recursive CTEs") + RecursiveQuery(RecursiveQuery), } impl LogicalPlan { @@ -191,6 +193,10 @@ impl LogicalPlan { LogicalPlan::Copy(CopyTo { input, .. }) => input.schema(), LogicalPlan::Ddl(ddl) => ddl.schema(), LogicalPlan::Unnest(Unnest { schema, .. }) => schema, + LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => { + // we take the schema of the static term as the schema of the entire recursive query + static_term.schema() + } } } @@ -243,6 +249,10 @@ impl LogicalPlan { | LogicalPlan::TableScan(_) => { vec![self.schema()] } + LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => { + // return only the schema of the static term + static_term.all_schemas() + } // return children schemas LogicalPlan::Limit(_) | LogicalPlan::Subquery(_) @@ -384,6 +394,7 @@ impl LogicalPlan { .try_for_each(f), // plans without expressions LogicalPlan::EmptyRelation(_) + | LogicalPlan::RecursiveQuery(_) | LogicalPlan::Subquery(_) | LogicalPlan::SubqueryAlias(_) | LogicalPlan::Limit(_) @@ -430,6 +441,11 @@ impl LogicalPlan { LogicalPlan::Ddl(ddl) => ddl.inputs(), LogicalPlan::Unnest(Unnest { input, .. }) => vec![input], LogicalPlan::Prepare(Prepare { input, .. }) => vec![input], + LogicalPlan::RecursiveQuery(RecursiveQuery { + static_term, + recursive_term, + .. + }) => vec![static_term, recursive_term], // plans without inputs LogicalPlan::TableScan { .. } | LogicalPlan::Statement { .. } @@ -510,6 +526,9 @@ impl LogicalPlan { cross.left.head_output_expr() } } + LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => { + static_term.head_output_expr() + } LogicalPlan::Union(union) => Ok(Some(Expr::Column( union.schema.fields()[0].qualified_column(), ))), @@ -835,6 +854,14 @@ impl LogicalPlan { }; Ok(LogicalPlan::Distinct(distinct)) } + LogicalPlan::RecursiveQuery(RecursiveQuery { + name, is_distinct, .. + }) => Ok(LogicalPlan::RecursiveQuery(RecursiveQuery { + name: name.clone(), + static_term: Arc::new(inputs[0].clone()), + recursive_term: Arc::new(inputs[1].clone()), + is_distinct: *is_distinct, + })), LogicalPlan::Analyze(a) => { assert!(expr.is_empty()); assert_eq!(inputs.len(), 1); @@ -1073,6 +1100,7 @@ impl LogicalPlan { }), LogicalPlan::TableScan(TableScan { fetch, .. }) => *fetch, LogicalPlan::EmptyRelation(_) => Some(0), + LogicalPlan::RecursiveQuery(_) => None, LogicalPlan::Subquery(_) => None, LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => input.max_rows(), LogicalPlan::Limit(Limit { fetch, .. }) => *fetch, @@ -1408,6 +1436,11 @@ impl LogicalPlan { fn fmt(&self, f: &mut Formatter) -> fmt::Result { match self.0 { LogicalPlan::EmptyRelation(_) => write!(f, "EmptyRelation"), + LogicalPlan::RecursiveQuery(RecursiveQuery { + is_distinct, .. + }) => { + write!(f, "RecursiveQuery: is_distinct={}", is_distinct) + } LogicalPlan::Values(Values { ref values, .. }) => { let str_values: Vec<_> = values .iter() @@ -1718,6 +1751,19 @@ pub struct EmptyRelation { pub schema: DFSchemaRef, } +/// A variadic query operation +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct RecursiveQuery { + /// Name of the query + pub name: String, + /// The static term + pub static_term: Arc, + /// The recursive term + pub recursive_term: Arc, + /// Distinction + pub is_distinct: bool, +} + /// Values expression. See /// [Postgres VALUES](https://www.postgresql.org/docs/current/queries-values.html) /// documentation for more details. diff --git a/datafusion/optimizer/src/common_subexpr_eliminate.rs b/datafusion/optimizer/src/common_subexpr_eliminate.rs index 1e089257c61a..1367962da68c 100644 --- a/datafusion/optimizer/src/common_subexpr_eliminate.rs +++ b/datafusion/optimizer/src/common_subexpr_eliminate.rs @@ -364,6 +364,7 @@ impl OptimizerRule for CommonSubexprEliminate { | LogicalPlan::Dml(_) | LogicalPlan::Copy(_) | LogicalPlan::Unnest(_) + | LogicalPlan::RecursiveQuery(_) | LogicalPlan::Prepare(_) => { // apply the optimization to all inputs of the plan utils::optimize_children(self, plan, config)? diff --git a/datafusion/optimizer/src/optimize_projections.rs b/datafusion/optimizer/src/optimize_projections.rs index 1d4eda0bd23e..2e63dfc2b793 100644 --- a/datafusion/optimizer/src/optimize_projections.rs +++ b/datafusion/optimizer/src/optimize_projections.rs @@ -162,6 +162,7 @@ fn optimize_projections( .collect::>() } LogicalPlan::EmptyRelation(_) + | LogicalPlan::RecursiveQuery(_) | LogicalPlan::Statement(_) | LogicalPlan::Values(_) | LogicalPlan::Extension(_) diff --git a/datafusion/proto/src/logical_plan/mod.rs b/datafusion/proto/src/logical_plan/mod.rs index e8a38784481b..c06f84f9f698 100644 --- a/datafusion/proto/src/logical_plan/mod.rs +++ b/datafusion/proto/src/logical_plan/mod.rs @@ -1702,6 +1702,9 @@ impl AsLogicalPlan for LogicalPlanNode { LogicalPlan::DescribeTable(_) => Err(proto_error( "LogicalPlan serde is not yet implemented for DescribeTable", )), + LogicalPlan::RecursiveQuery(_) => Err(proto_error( + "LogicalPlan serde is not yet implemented for RecursiveQuery", + )), } } } diff --git a/datafusion/sql/src/planner.rs b/datafusion/sql/src/planner.rs index a04df5589b85..2b11943324da 100644 --- a/datafusion/sql/src/planner.rs +++ b/datafusion/sql/src/planner.rs @@ -61,6 +61,15 @@ pub trait ContextProvider { not_impl_err!("Table Functions are not supported") } + /// TODO: add doc + fn create_cte_work_table( + &self, + _name: &str, + _schema: SchemaRef, + ) -> Result> { + not_impl_err!("Recursive CTE is not supported") + } + /// Getter for a UDF description fn get_function_meta(&self, name: &str) -> Option>; /// Getter for a UDAF description diff --git a/datafusion/sql/src/query.rs b/datafusion/sql/src/query.rs index 388377e3ee6b..a69752187542 100644 --- a/datafusion/sql/src/query.rs +++ b/datafusion/sql/src/query.rs @@ -19,6 +19,7 @@ use std::sync::Arc; use crate::planner::{ContextProvider, PlannerContext, SqlToRel}; +use arrow::datatypes::Schema; use datafusion_common::{ not_impl_err, plan_err, sql_err, Constraints, DataFusionError, Result, ScalarValue, }; @@ -26,7 +27,8 @@ use datafusion_expr::{ CreateMemoryTable, DdlStatement, Distinct, Expr, LogicalPlan, LogicalPlanBuilder, }; use sqlparser::ast::{ - Expr as SQLExpr, Offset as SQLOffset, OrderByExpr, Query, SetExpr, Value, + Expr as SQLExpr, Offset as SQLOffset, OrderByExpr, Query, SetExpr, SetOperator, + SetQuantifier, Value, }; use sqlparser::parser::ParserError::ParserError; @@ -52,21 +54,8 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { let set_expr = query.body; if let Some(with) = query.with { // Process CTEs from top to bottom - // do not allow self-references - if with.recursive { - if self - .context_provider - .options() - .execution - .enable_recursive_ctes - { - return plan_err!( - "Recursive CTEs are enabled but are not yet supported" - ); - } else { - return not_impl_err!("Recursive CTEs are not supported"); - } - } + + let is_recursive = with.recursive; for cte in with.cte_tables { // A `WITH` block can't use the same name more than once @@ -76,16 +65,128 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> { "WITH query name {cte_name:?} specified more than once" ))); } - // create logical plan & pass backreferencing CTEs - // CTE expr don't need extend outer_query_schema - let logical_plan = - self.query_to_plan(*cte.query, &mut planner_context.clone())?; + let cte_query = cte.query; + + if is_recursive { + if !self + .context_provider + .options() + .execution + .enable_recursive_ctes + { + return not_impl_err!("Recursive CTEs are not enabled"); + } + + match *cte_query.body { + SetExpr::SetOperation { + op: SetOperator::Union, + left, + right, + set_quantifier, + } => { + let distinct = set_quantifier != SetQuantifier::All; + + // Each recursive CTE consists from two parts in the logical plan: + // 1. A static term (the left hand side on the SQL, where the + // referencing to the same CTE is not allowed) + // + // 2. A recursive term (the right hand side, and the recursive + // part) + + // Since static term does not have any specific properties, it can + // be compiled as if it was a regular expression. This will + // allow us to infer the schema to be used in the recursive term. + + // ---------- Step 1: Compile the static term ------------------ + let static_plan = self + .set_expr_to_plan(*left, &mut planner_context.clone())?; + + // Since the recursive CTEs include a component that references a + // table with its name, like the example below: + // + // WITH RECURSIVE values(n) AS ( + // SELECT 1 as n -- static term + // UNION ALL + // SELECT n + 1 + // FROM values -- self reference + // WHERE n < 100 + // ) + // + // We need a temporary 'relation' to be referenced and used. PostgreSQL + // calls this a 'working table', but it is entirely an implementation + // detail and a 'real' table with that name might not even exist (as + // in the case of DataFusion). + // + // Since we can't simply register a table during planning stage (it is + // an execution problem), we'll use a relation object that preserves the + // schema of the input perfectly and also knows which recursive CTE it is + // bound to. - // Each `WITH` block can change the column names in the last - // projection (e.g. "WITH table(t1, t2) AS SELECT 1, 2"). - let logical_plan = self.apply_table_alias(logical_plan, cte.alias)?; + // ---------- Step 2: Create a temporary relation ------------------ + // Step 2.1: Create a table source for the temporary relation + let work_table_source = + self.context_provider.create_cte_work_table( + &cte_name, + Arc::new(Schema::from(static_plan.schema().as_ref())), + )?; - planner_context.insert_cte(cte_name, logical_plan); + // Step 2.2: Create a temporary relation logical plan that will be used + // as the input to the recursive term + let work_table_plan = LogicalPlanBuilder::scan( + cte_name.to_string(), + work_table_source, + None, + )? + .build()?; + + let name = cte_name.clone(); + + // Step 2.3: Register the temporary relation in the planning context + // For all the self references in the variadic term, we'll replace it + // with the temporary relation we created above by temporarily registering + // it as a CTE. This temporary relation in the planning context will be + // replaced by the actual CTE plan once we're done with the planning. + planner_context.insert_cte(cte_name.clone(), work_table_plan); + + // ---------- Step 3: Compile the recursive term ------------------ + // this uses the named_relation we inserted above to resolve the + // relation. This ensures that the recursive term uses the named relation logical plan + // and thus the 'continuance' physical plan as its input and source + let recursive_plan = self + .set_expr_to_plan(*right, &mut planner_context.clone())?; + + // ---------- Step 4: Create the final plan ------------------ + // Step 4.1: Compile the final plan + let logical_plan = LogicalPlanBuilder::from(static_plan) + .to_recursive_query(name, recursive_plan, distinct)? + .build()?; + + let final_plan = + self.apply_table_alias(logical_plan, cte.alias)?; + + // Step 4.2: Remove the temporary relation from the planning context and replace it + // with the final plan. + planner_context.insert_cte(cte_name.clone(), final_plan); + } + _ => { + return Err(DataFusionError::SQL( + ParserError("Invalid recursive CTE".to_string()), + None, + )); + } + }; + } else { + // create logical plan & pass backreferencing CTEs + // CTE expr don't need extend outer_query_schema + let logical_plan = + self.query_to_plan(*cte_query, &mut planner_context.clone())?; + + // Each `WITH` block can change the column names in the last + // projection (e.g. "WITH table(t1, t2) AS SELECT 1, 2"). + let logical_plan = self.apply_table_alias(logical_plan, cte.alias)?; + + planner_context.insert_cte(cte_name, logical_plan); + } } } let plan = self.set_expr_to_plan(*(set_expr.clone()), planner_context)?; diff --git a/datafusion/sql/tests/sql_integration.rs b/datafusion/sql/tests/sql_integration.rs index 4de08a7124cf..14909e038f13 100644 --- a/datafusion/sql/tests/sql_integration.rs +++ b/datafusion/sql/tests/sql_integration.rs @@ -1398,7 +1398,7 @@ fn recursive_ctes() { select * from numbers;"; let err = logical_plan(sql).expect_err("query should have failed"); assert_eq!( - "This feature is not implemented: Recursive CTEs are not supported", + "Recursive CTEs are not enabled", err.strip_backtrace() ); }