Skip to content

Commit 80ae34f

Browse files
committed
handle CTE name shadowing in resolve_table_references
1 parent 582834f commit 80ae34f

File tree

1 file changed

+33
-9
lines changed
  • datafusion/core/src/catalog

1 file changed

+33
-9
lines changed

datafusion/core/src/catalog/mod.rs

Lines changed: 33 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -302,7 +302,7 @@ impl CatalogProvider for MemoryCatalogProvider {
302302
///
303303
/// # Returns
304304
///
305-
/// A `(relations, ctes)` tuple, the first element contains table and view references and the second
305+
/// A `(table_refs, ctes)` tuple, the first element contains table and view references and the second
306306
/// element contains any CTE aliases that were defined and possibly referenced.
307307
///
308308
/// ## Example
@@ -365,6 +365,21 @@ pub fn resolve_table_references(
365365
ControlFlow::Continue(())
366366
}
367367

368+
fn pre_visit_query(&mut self, q: &Query) -> ControlFlow<Self::Break> {
369+
// The CTE name is not in scope when evaluating the CTE itself, so this is valid:
370+
// `WITH t AS (SELECT * FROM t) SELECT * FROM t`
371+
// Where the first `t` refers to a predefined table. So we are careful here
372+
// to visit the CTE first, before adding it to the set of known CTEs.
373+
//
374+
// This is a bit hackish as the CTE will be visited again as part of visiting `q`,
375+
// ideally there would be a `Visitor::post_visit_cte` hook.
376+
for cte in q.with.as_ref().map(|w| &w.cte_tables).into_iter().flatten() {
377+
cte.visit(self);
378+
self.ctes.insert(ObjectName(vec![cte.alias.name.clone()]));
379+
}
380+
ControlFlow::Continue(())
381+
}
382+
368383
fn pre_visit_statement(&mut self, statement: &Statement) -> ControlFlow<()> {
369384
if let Statement::ShowCreate {
370385
obj_type: ShowCreateObject::Table | ShowCreateObject::View,
@@ -374,12 +389,6 @@ pub fn resolve_table_references(
374389
self.insert_relation(obj_name)
375390
}
376391

377-
if let Statement::Query(q) = statement {
378-
for cte in q.with.as_ref().map(|w| &w.cte_tables).into_iter().flatten() {
379-
self.ctes.insert(ObjectName(vec![cte.alias.name.clone()]));
380-
}
381-
}
382-
383392
// SHOW statements will later be rewritten into a SELECT from the information_schema
384393
let requires_information_schema = matches!(
385394
statement,
@@ -433,7 +442,7 @@ pub fn resolve_table_references(
433442

434443
visit_statement(statement, &mut visitor);
435444

436-
let relations = visitor
445+
let table_refs = visitor
437446
.relations
438447
.into_iter()
439448
.map(|x| object_name_to_table_reference(x, enable_ident_normalization))
@@ -443,7 +452,7 @@ pub fn resolve_table_references(
443452
.into_iter()
444453
.map(|x| object_name_to_table_reference(x, enable_ident_normalization))
445454
.collect::<datafusion_common::Result<_>>()?;
446-
Ok((relations, ctes))
455+
Ok((table_refs, ctes))
447456
}
448457

449458
#[cfg(test)]
@@ -514,4 +523,19 @@ mod tests {
514523
let cat = Arc::new(MemoryCatalogProvider::new()) as Arc<dyn CatalogProvider>;
515524
assert!(cat.deregister_schema("foo", false).unwrap().is_none());
516525
}
526+
527+
#[test]
528+
fn resolve_table_references_shadowed_cte() {
529+
use datafusion_sql::parser::DFParser;
530+
531+
// An interesting edge case where the `t` name is used both as an ordinary table reference
532+
// and as a CTE reference.
533+
let query = "WITH t AS (SELECT * FROM t) SELECT * FROM t";
534+
let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap();
535+
let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap();
536+
assert_eq!(table_refs.len(), 1);
537+
assert_eq!(ctes.len(), 1);
538+
assert_eq!(ctes[0].to_string(), "t");
539+
assert_eq!(table_refs[0].to_string(), "t");
540+
}
517541
}

0 commit comments

Comments
 (0)