Skip to content

Commit a15c149

Browse files
committed
collect CTEs separately in resolve_table_references
1 parent ab7e7ed commit a15c149

File tree

2 files changed

+75
-29
lines changed

2 files changed

+75
-29
lines changed

datafusion/core/src/catalog/mod.rs

Lines changed: 65 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ use crate::catalog::schema::SchemaProvider;
2727
use dashmap::DashMap;
2828
use datafusion_common::{exec_err, not_impl_err, Result};
2929
use std::any::Any;
30+
use std::collections::BTreeSet;
3031
use std::ops::ControlFlow;
3132
use std::sync::Arc;
3233

@@ -296,25 +297,44 @@ impl CatalogProvider for MemoryCatalogProvider {
296297
}
297298
}
298299

299-
/// Resolve all table references in the SQL statement.
300+
/// Collects all tables and views referenced in the SQL statement. CTEs are collected separately.
301+
/// This can be used to determine which tables need to be in the catalog for a query to be planned.
302+
///
303+
/// # Returns
304+
///
305+
/// A `(relations, ctes)` tuple: the first element contains table and view references and the second
306+
/// element contains any CTE aliases that were defined and possibly referenced.
300307
///
301308
/// ## Example
302309
///
303310
/// ```
304-
/// use datafusion_sql::parser::DFParser;
305-
/// use datafusion::catalog::resolve_table_references;
306-
///
311+
/// # use datafusion_sql::parser::DFParser;
312+
/// # use datafusion::catalog::resolve_table_references;
307313
/// let query = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
308314
/// let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap();
309-
/// let table_refs = resolve_table_references(&statement, true).unwrap();
315+
/// let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap();
310316
/// assert_eq!(table_refs.len(), 2);
311-
/// assert_eq!(table_refs[0].to_string(), "foo");
312-
/// assert_eq!(table_refs[1].to_string(), "bar");
317+
/// assert_eq!(table_refs[0].to_string(), "bar");
318+
/// assert_eq!(table_refs[1].to_string(), "foo");
319+
/// assert_eq!(ctes.len(), 0);
320+
/// ```
321+
///
322+
/// ## Example with CTEs
323+
///
324+
/// ```
325+
/// # use datafusion_sql::parser::DFParser;
326+
/// # use datafusion::catalog::resolve_table_references;
327+
/// let query = "with my_cte as (values (1), (2)) SELECT * from my_cte;";
328+
/// let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap();
329+
/// let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap();
330+
/// assert_eq!(table_refs.len(), 0);
331+
/// assert_eq!(ctes.len(), 1);
332+
/// assert_eq!(ctes[0].to_string(), "my_cte");
313333
/// ```
314334
pub fn resolve_table_references(
315335
statement: &datafusion_sql::parser::Statement,
316336
enable_ident_normalization: bool,
317-
) -> datafusion_common::Result<Vec<TableReference>> {
337+
) -> datafusion_common::Result<(Vec<TableReference>, Vec<TableReference>)> {
318338
use crate::sql::planner::object_name_to_table_reference;
319339
use datafusion_sql::parser::{
320340
CopyToSource, CopyToStatement, Statement as DFStatement,
@@ -323,24 +343,25 @@ pub fn resolve_table_references(
323343
use information_schema::INFORMATION_SCHEMA_TABLES;
324344
use sqlparser::ast::*;
325345

326-
// Getting `TableProviders` is async but planning is not -- thus pre-fetch
327-
// table providers for all relations referenced in this query
328-
let mut relations = hashbrown::HashSet::with_capacity(10);
329-
330-
struct RelationVisitor<'a>(&'a mut hashbrown::HashSet<ObjectName>);
346+
struct RelationVisitor {
347+
relations: BTreeSet<ObjectName>,
348+
ctes: BTreeSet<ObjectName>,
349+
}
331350

332-
impl<'a> RelationVisitor<'a> {
333-
/// Record that `relation` was used in this statement
334-
fn insert(&mut self, relation: &ObjectName) {
335-
self.0.get_or_insert_with(relation, |_| relation.clone());
351+
impl RelationVisitor {
352+
/// Record the reference to `relation`, if it's not a CTE reference.
353+
fn insert_relation(&mut self, relation: &ObjectName) {
354+
if !self.relations.contains(&relation) && !self.ctes.contains(&relation) {
355+
self.relations.insert(relation.clone());
356+
}
336357
}
337358
}
338359

339-
impl<'a> Visitor for RelationVisitor<'a> {
360+
impl Visitor for RelationVisitor {
340361
type Break = ();
341362

342363
fn pre_visit_relation(&mut self, relation: &ObjectName) -> ControlFlow<()> {
343-
self.insert(relation);
364+
self.insert_relation(relation);
344365
ControlFlow::Continue(())
345366
}
346367

@@ -350,7 +371,13 @@ pub fn resolve_table_references(
350371
obj_name,
351372
} = statement
352373
{
353-
self.insert(obj_name)
374+
self.insert_relation(obj_name)
375+
}
376+
377+
if let Statement::Query(q) = statement {
378+
for cte in q.with.as_ref().map(|w| &w.cte_tables).into_iter().flatten() {
379+
self.ctes.insert(ObjectName(vec![cte.alias.name.clone()]));
380+
}
354381
}
355382

356383
// SHOW statements will later be rewritten into a SELECT from the information_schema
@@ -367,7 +394,7 @@ pub fn resolve_table_references(
367394
);
368395
if requires_information_schema {
369396
for s in INFORMATION_SCHEMA_TABLES {
370-
self.0.insert(ObjectName(vec![
397+
self.relations.insert(ObjectName(vec![
371398
Ident::new(INFORMATION_SCHEMA),
372399
Ident::new(*s),
373400
]));
@@ -377,20 +404,24 @@ pub fn resolve_table_references(
377404
}
378405
}
379406

380-
let mut visitor = RelationVisitor(&mut relations);
381-
fn visit_statement(statement: &DFStatement, visitor: &mut RelationVisitor<'_>) {
407+
let mut visitor = RelationVisitor {
408+
relations: BTreeSet::new(),
409+
ctes: BTreeSet::new(),
410+
};
411+
412+
fn visit_statement(statement: &DFStatement, visitor: &mut RelationVisitor) {
382413
match statement {
383414
DFStatement::Statement(s) => {
384415
let _ = s.as_ref().visit(visitor);
385416
}
386417
DFStatement::CreateExternalTable(table) => {
387418
visitor
388-
.0
419+
.relations
389420
.insert(ObjectName(vec![Ident::from(table.name.as_str())]));
390421
}
391422
DFStatement::CopyTo(CopyToStatement { source, .. }) => match source {
392423
CopyToSource::Relation(table_name) => {
393-
visitor.insert(table_name);
424+
visitor.insert_relation(table_name);
394425
}
395426
CopyToSource::Query(query) => {
396427
query.visit(visitor);
@@ -402,10 +433,17 @@ pub fn resolve_table_references(
402433

403434
visit_statement(statement, &mut visitor);
404435

405-
relations
436+
let relations = visitor
437+
.relations
438+
.into_iter()
439+
.map(|x| object_name_to_table_reference(x, enable_ident_normalization))
440+
.collect::<datafusion_common::Result<_>>()?;
441+
let ctes = visitor
442+
.ctes
406443
.into_iter()
407444
.map(|x| object_name_to_table_reference(x, enable_ident_normalization))
408-
.collect::<datafusion_common::Result<_>>()
445+
.collect::<datafusion_common::Result<_>>()?;
446+
Ok((relations, ctes))
409447
}
410448

411449
#[cfg(test)]

datafusion/core/src/execution/session_state.rs

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -490,14 +490,22 @@ impl SessionState {
490490
Ok(statement)
491491
}
492492

493-
/// Resolve all table references in the SQL statement.
493+
/// Resolve all table references in the SQL statement. Does not include CTE references.
494+
///
495+
/// See [`catalog::resolve_table_references`] for more information.
496+
///
497+
/// [`catalog::resolve_table_references`]: crate::catalog::resolve_table_references
494498
pub fn resolve_table_references(
495499
&self,
496500
statement: &datafusion_sql::parser::Statement,
497501
) -> datafusion_common::Result<Vec<TableReference>> {
498502
let enable_ident_normalization =
499503
self.config.options().sql_parser.enable_ident_normalization;
500-
crate::catalog::resolve_table_references(statement, enable_ident_normalization)
504+
let (table_refs, _) = crate::catalog::resolve_table_references(
505+
statement,
506+
enable_ident_normalization,
507+
)?;
508+
Ok(table_refs)
501509
}
502510

503511
/// Convert an AST Statement into a LogicalPlan

0 commit comments

Comments
 (0)