Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for recursive CTEs #7581

Closed
wants to merge 10 commits into from
2 changes: 2 additions & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 5 additions & 0 deletions datafusion/common/src/dfschema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -915,6 +915,11 @@ impl DFField {
self.field = f.into();
self
}

pub fn with_qualifier(mut self, qualifier: impl Into<OwnedTableReference>) -> Self {
self.qualifier = Some(qualifier.into());
self
}
}

impl From<FieldRef> for DFField {
Expand Down
100 changes: 76 additions & 24 deletions datafusion/core/src/physical_planner.rs

Large diffs are not rendered by default.

8 changes: 8 additions & 0 deletions datafusion/execution/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,12 @@ object_store = { workspace = true }
parking_lot = { workspace = true }
rand = { workspace = true }
tempfile = { workspace = true }
tokio = { version = "1.28", features = [
"macros",
"rt",
"rt-multi-thread",
"sync",
"fs",
"parking_lot",
] }
url = { workspace = true }
41 changes: 41 additions & 0 deletions datafusion/execution/src/task.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,15 @@ use crate::{
runtime_env::{RuntimeConfig, RuntimeEnv},
};

use arrow::record_batch::RecordBatch;
// use futures::channel::mpsc::Receiver as SingleChannelReceiver;
use tokio::sync::mpsc::Receiver as SingleChannelReceiver;
// use futures::lock::Mutex;
use parking_lot::Mutex;
// use futures::

type RelationHandler = SingleChannelReceiver<Result<RecordBatch>>;

/// Task Execution Context
///
/// A [`TaskContext`] contains the state available during a single
Expand All @@ -56,6 +65,8 @@ pub struct TaskContext {
window_functions: HashMap<String, Arc<WindowUDF>>,
/// Runtime environment associated with this task context
runtime: Arc<RuntimeEnv>,
/// Registered relation handlers
relation_handlers: Mutex<HashMap<String, RelationHandler>>,
}

impl Default for TaskContext {
Expand All @@ -72,6 +83,7 @@ impl Default for TaskContext {
aggregate_functions: HashMap::new(),
window_functions: HashMap::new(),
runtime: Arc::new(runtime),
relation_handlers: Mutex::new(HashMap::new()),
}
}
}
Expand Down Expand Up @@ -99,6 +111,7 @@ impl TaskContext {
aggregate_functions,
window_functions,
runtime,
relation_handlers: Mutex::new(HashMap::new()),
}
}

Expand Down Expand Up @@ -171,6 +184,34 @@ impl TaskContext {
self.runtime = runtime;
self
}

/// Register a new relation handler under `name`.
///
/// Returns an `Internal` error if a handler with the same name is
/// already registered; on success the handler is stored for later
/// retrieval via `pop_relation_handler`.
pub fn push_relation_handler(
    &self,
    name: String,
    handler: RelationHandler,
) -> Result<()> {
    use std::collections::hash_map::Entry;

    // A single `entry` lookup replaces the previous
    // `contains_key` + `insert` pair (one hash/probe instead of two).
    match self.relation_handlers.lock().entry(name) {
        Entry::Occupied(entry) => Err(DataFusionError::Internal(format!(
            "Relation handler {} already registered",
            entry.key()
        ))),
        Entry::Vacant(slot) => {
            slot.insert(handler);
            Ok(())
        }
    }
}

/// Removes and returns the relation handler registered under `name`.
///
/// The handler is taken out of the internal storage, so a second call
/// with the same name will fail unless the handler is re-registered.
pub fn pop_relation_handler(&self, name: String) -> Result<RelationHandler> {
    match self.relation_handlers.lock().remove(&name) {
        Some(handler) => Ok(handler),
        None => Err(DataFusionError::Internal(format!(
            "Relation handler {} not registered",
            name
        ))),
    }
}
}

impl FunctionRegistry for TaskContext {
Expand Down
35 changes: 35 additions & 0 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,8 @@ use datafusion_common::{
ScalarValue, TableReference, ToDFSchema, UnnestOptions,
};

use super::plan::{NamedRelation, RecursiveQuery};

/// Default table name for unnamed table
pub const UNNAMED_TABLE: &str = "?table?";

Expand Down Expand Up @@ -121,6 +123,39 @@ impl LogicalPlanBuilder {
}))
}

/// A named temporary relation with a schema.
///
/// This is used to represent a relation that does not exist at the
matthewgapp marked this conversation as resolved.
Show resolved Hide resolved
/// planning stage, but will be created at execution time with the
/// given schema.
pub fn named_relation(name: &str, schema: DFSchemaRef) -> Self {
Self::from(LogicalPlan::NamedRelation(NamedRelation {
name: name.to_string(),
schema,
}))
}

/// Convert a regular plan into a recursive query, with this plan
/// acting as the static (non-recursive) term.
pub fn to_recursive_query(
    &self,
    name: String,
    recursive_term: LogicalPlan,
    is_distinct: bool,
) -> Result<Self> {
    // TODO: we need to do a bunch of validation here. Maybe more.
    if is_distinct {
        return Err(DataFusionError::NotImplemented(
            "Recursive queries with distinct is not supported".to_string(),
        ));
    }
    let recursive_query = RecursiveQuery {
        name,
        static_term: Arc::new(self.plan.clone()),
        recursive_term: Arc::new(recursive_term),
        is_distinct,
    };
    Ok(Self::from(LogicalPlan::RecursiveQuery(recursive_query)))
}

/// Create a values list based relation, and the schema is inferred from data, consuming
/// `value`. See the [Postgres VALUES](https://www.postgresql.org/docs/current/queries-values.html)
/// documentation for more details.
Expand Down
6 changes: 3 additions & 3 deletions datafusion/expr/src/logical_plan/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ pub use dml::{DmlStatement, WriteOp};
pub use plan::{
projection_schema, Aggregate, Analyze, CrossJoin, DescribeTable, Distinct,
DistinctOn, EmptyRelation, Explain, Extension, Filter, Join, JoinConstraint,
JoinType, Limit, LogicalPlan, Partitioning, PlanType, Prepare, Projection,
Repartition, Sort, StringifiedPlan, Subquery, SubqueryAlias, TableScan,
ToStringifiedPlan, Union, Unnest, Values, Window,
JoinType, Limit, LogicalPlan, NamedRelation, Partitioning, PlanType, Prepare,
Projection, RecursiveQuery, Repartition, Sort, StringifiedPlan, Subquery,
SubqueryAlias, TableScan, ToStringifiedPlan, Union, Unnest, Values, Window,
};
pub use statement::{
SetVariable, Statement, TransactionAccessMode, TransactionConclusion, TransactionEnd,
Expand Down
69 changes: 69 additions & 0 deletions datafusion/expr/src/logical_plan/plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,8 @@ pub enum LogicalPlan {
/// produces 0 or 1 row. This is used to implement SQL `SELECT`
/// that has no values in the `FROM` clause.
EmptyRelation(EmptyRelation),
/// A named temporary relation with a schema.
NamedRelation(NamedRelation),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am considering whether the NamedRelation and RecursiveQuery could be implemented as two TableSources, one being CTESelfRefTable and the other being CTERecursiveTable, and then use TableScan to read them.

Use CTESelfRefTable within the recursive term and CTERecursiveTable in the outer query.

But this idea is in its early stages and may be wrong.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@jonahgao, could you provide the rationale for your suggested strategy? I'm interested in understanding why it might be more effective than the current implementation. Performance is critical to our use case. And the implementation for recursion is very sensitive to performance considerations, as the setup for execution and stream management isn't amortized over all input record batches. Instead, it's incurred with each iteration. For instance, we've observed a substantial performance boost—up to 30 times faster—by eliminating certain intermediate nodes, like coalesce, from our plan (as evidenced in this PR). I've drafted another PR that appears to again double the speed of execution merely by omitting metric collection in recursive sub-graphs.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

One rationale might be to make the implementation simpler -- if we could implement the recursive relation as a table provider, it would likely allow the changes to be more localized / smaller (e.g. maybe we could reuse MemTable::load to update the batches on each iteration)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Basically I understand the need to have LogicalPlan::RecursiveQuery but I don't (yet) understand the need to have the NamedRelation

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NamedRelation is primarily a way to mirror batches back to the RecursiveQuery via its physical counterpart, ContinuanceExec

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@matthewgapp Another rationale might be to support pushing down filters to the working table, which may be useful if we support spilling the working table to disk in the future. I think the performance should not be affected, the execution of physical plans is almost the same as it is now.

I implemented a demo on this branch and in this commit. GitHub does not allow forking a repository twice, so I directly pushed it to another repository for convenience.

In this demo, I attempted to replace the NamedRelation with a TableProvider, namely CteWorkTable. The benefit of this is that it can avoid maintaining a new logical plan.

Another change is that I used a structure called WorkTable to connect the RecursiveQueryExec and the WorkTableExec (it was previously ContinuanceExec). The advantage of this is that it avoids maintaining some external context information, such as relation_handlers in TaskContext, and the ctx in create_initial_plan.

The WorkTable is a shared table, it will be scanned by the WorkTableExec during the execution of the recursive term, and after the execution is completed, the results will be written back to it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wow, tyty! I was in the process of implementing the shared table and my implementation turned out very similar to yours although I ended up working around the crate dependency graph constraints a bit differently by introducing a couple new traits. But I did end up exposing a method on the context to generate a table. I like your approach better.

I tested out your poc and performance remains about the same between my previous implementation and your new worktable approach! (which makes sense).

I'm going to go ahead and work based on your POC toward the list of PRs that Andrew wants to get this landed.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thank you for your work and for the upcoming contributions! @matthewgapp

/// Produces the output of running another query. This is used to
/// implement SQL subqueries
Subquery(Subquery),
Expand Down Expand Up @@ -154,6 +156,8 @@ pub enum LogicalPlan {
/// Unnest a column that contains a nested list type such as an
/// ARRAY. This is used to implement SQL `UNNEST`
Unnest(Unnest),
/// A variadic query (e.g. "Recursive CTEs")
RecursiveQuery(RecursiveQuery),
}

impl LogicalPlan {
Expand Down Expand Up @@ -191,6 +195,11 @@ impl LogicalPlan {
LogicalPlan::Copy(CopyTo { input, .. }) => input.schema(),
LogicalPlan::Ddl(ddl) => ddl.schema(),
LogicalPlan::Unnest(Unnest { schema, .. }) => schema,
LogicalPlan::NamedRelation(NamedRelation { schema, .. }) => schema,
LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => {
// we take the schema of the static term as the schema of the entire recursive query
static_term.schema()
}
}
}

Expand Down Expand Up @@ -233,6 +242,7 @@ impl LogicalPlan {
LogicalPlan::Explain(_)
| LogicalPlan::Analyze(_)
| LogicalPlan::EmptyRelation(_)
| LogicalPlan::NamedRelation(_)
| LogicalPlan::Ddl(_)
| LogicalPlan::Dml(_)
| LogicalPlan::Copy(_)
Expand All @@ -243,6 +253,10 @@ impl LogicalPlan {
| LogicalPlan::TableScan(_) => {
vec![self.schema()]
}
LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => {
// return only the schema of the static term
static_term.all_schemas()
}
// return children schemas
LogicalPlan::Limit(_)
| LogicalPlan::Subquery(_)
Expand Down Expand Up @@ -384,6 +398,9 @@ impl LogicalPlan {
.try_for_each(f),
// plans without expressions
LogicalPlan::EmptyRelation(_)
| LogicalPlan::NamedRelation(_)
// TODO: not sure if this should go here
| LogicalPlan::RecursiveQuery(_)
| LogicalPlan::Subquery(_)
| LogicalPlan::SubqueryAlias(_)
| LogicalPlan::Limit(_)
Expand Down Expand Up @@ -430,8 +447,14 @@ impl LogicalPlan {
LogicalPlan::Ddl(ddl) => ddl.inputs(),
LogicalPlan::Unnest(Unnest { input, .. }) => vec![input],
LogicalPlan::Prepare(Prepare { input, .. }) => vec![input],
LogicalPlan::RecursiveQuery(RecursiveQuery {
static_term,
recursive_term,
..
}) => vec![static_term, recursive_term],
// plans without inputs
LogicalPlan::TableScan { .. }
| LogicalPlan::NamedRelation(_)
| LogicalPlan::Statement { .. }
| LogicalPlan::EmptyRelation { .. }
| LogicalPlan::Values { .. }
Expand Down Expand Up @@ -510,6 +533,9 @@ impl LogicalPlan {
cross.left.head_output_expr()
}
}
LogicalPlan::RecursiveQuery(RecursiveQuery { static_term, .. }) => {
static_term.head_output_expr()
}
LogicalPlan::Union(union) => Ok(Some(Expr::Column(
union.schema.fields()[0].qualified_column(),
))),
Expand All @@ -529,6 +555,7 @@ impl LogicalPlan {
}
LogicalPlan::Subquery(_) => Ok(None),
LogicalPlan::EmptyRelation(_)
| LogicalPlan::NamedRelation(_)
| LogicalPlan::Prepare(_)
| LogicalPlan::Statement(_)
| LogicalPlan::Values(_)
Expand Down Expand Up @@ -835,6 +862,14 @@ impl LogicalPlan {
};
Ok(LogicalPlan::Distinct(distinct))
}
LogicalPlan::RecursiveQuery(RecursiveQuery {
name, is_distinct, ..
}) => Ok(LogicalPlan::RecursiveQuery(RecursiveQuery {
name: name.clone(),
static_term: Arc::new(inputs[0].clone()),
recursive_term: Arc::new(inputs[1].clone()),
is_distinct: *is_distinct,
})),
LogicalPlan::Analyze(a) => {
assert!(expr.is_empty());
assert_eq!(inputs.len(), 1);
Expand Down Expand Up @@ -873,6 +908,7 @@ impl LogicalPlan {
}))
}
LogicalPlan::EmptyRelation(_)
| LogicalPlan::NamedRelation(_)
| LogicalPlan::Ddl(_)
| LogicalPlan::Statement(_) => {
// All of these plan types have no inputs / exprs so should not be called
Expand Down Expand Up @@ -1073,6 +1109,9 @@ impl LogicalPlan {
}),
LogicalPlan::TableScan(TableScan { fetch, .. }) => *fetch,
LogicalPlan::EmptyRelation(_) => Some(0),
// TODO: not sure if this is correct
LogicalPlan::NamedRelation(_) => None,
LogicalPlan::RecursiveQuery(_) => None,
LogicalPlan::Subquery(_) => None,
LogicalPlan::SubqueryAlias(SubqueryAlias { input, .. }) => input.max_rows(),
LogicalPlan::Limit(Limit { fetch, .. }) => *fetch,
Expand Down Expand Up @@ -1408,6 +1447,14 @@ impl LogicalPlan {
fn fmt(&self, f: &mut Formatter) -> fmt::Result {
match self.0 {
LogicalPlan::EmptyRelation(_) => write!(f, "EmptyRelation"),
LogicalPlan::NamedRelation(NamedRelation { name, .. }) => {
write!(f, "NamedRelation: {}", name)
}
LogicalPlan::RecursiveQuery(RecursiveQuery {
is_distinct, ..
}) => {
write!(f, "RecursiveQuery: is_distinct={}", is_distinct)
}
LogicalPlan::Values(Values { ref values, .. }) => {
let str_values: Vec<_> = values
.iter()
Expand Down Expand Up @@ -1718,6 +1765,28 @@ pub struct EmptyRelation {
pub schema: DFSchemaRef,
}

/// A named temporary relation with a known schema.
///
/// The relation does not exist at planning time; it is created at
/// execution time with the schema recorded here (see
/// `LogicalPlanBuilder::named_relation`).
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct NamedRelation {
/// The relation name, used to identify it at execution time
pub name: String,
/// The schema the relation will have when it is materialized
pub schema: DFSchemaRef,
}

/// A variadic query operation (e.g. a recursive CTE).
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct RecursiveQuery {
/// Name of the recursive query (the CTE name)
pub name: String,
/// The static term: the non-recursive base case. Its schema is used
/// as the schema of the entire recursive query.
pub static_term: Arc<LogicalPlan>,
/// The recursive term — presumably re-evaluated against the previous
/// iteration's output; confirm against the physical implementation.
pub recursive_term: Arc<LogicalPlan>,
/// Whether duplicate rows are eliminated between iterations
/// (`UNION` vs `UNION ALL` semantics); `true` is currently rejected
/// as not implemented by `LogicalPlanBuilder::to_recursive_query`.
pub is_distinct: bool,
}

/// Values expression. See
/// [Postgres VALUES](https://www.postgresql.org/docs/current/queries-values.html)
/// documentation for more details.
Expand Down
2 changes: 2 additions & 0 deletions datafusion/optimizer/src/common_subexpr_eliminate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -364,6 +364,8 @@ impl OptimizerRule for CommonSubexprEliminate {
| LogicalPlan::Dml(_)
| LogicalPlan::Copy(_)
| LogicalPlan::Unnest(_)
| LogicalPlan::NamedRelation(_)
| LogicalPlan::RecursiveQuery(_)
| LogicalPlan::Prepare(_) => {
// apply the optimization to all inputs of the plan
utils::optimize_children(self, plan, config)?
Expand Down
2 changes: 2 additions & 0 deletions datafusion/optimizer/src/optimize_projections.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,8 @@ fn optimize_projections(
.collect::<Vec<_>>()
}
LogicalPlan::EmptyRelation(_)
| LogicalPlan::NamedRelation(_)
| LogicalPlan::RecursiveQuery(_)
| LogicalPlan::Statement(_)
| LogicalPlan::Values(_)
| LogicalPlan::Extension(_)
Expand Down
5 changes: 4 additions & 1 deletion datafusion/physical-plan/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ name = "datafusion_physical_plan"
path = "src/lib.rs"

[dependencies]
ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] }
ahash = { version = "0.8", default-features = false, features = [
"runtime-rng",
] }
arrow = { workspace = true }
arrow-array = { workspace = true }
arrow-buffer = { workspace = true }
Expand All @@ -55,6 +57,7 @@ parking_lot = { workspace = true }
pin-project-lite = "^0.2.7"
rand = { workspace = true }
tokio = { version = "1.28", features = ["sync", "fs", "parking_lot"] }
tokio-stream = { version = "0.1.14" }
uuid = { version = "^1.2", features = ["v4"] }

[dev-dependencies]
Expand Down
Loading
Loading