@@ -27,6 +27,8 @@ use crate::catalog::schema::SchemaProvider;
2727use dashmap:: DashMap ;
2828use datafusion_common:: { exec_err, not_impl_err, Result } ;
2929use std:: any:: Any ;
30+ use std:: collections:: BTreeSet ;
31+ use std:: ops:: ControlFlow ;
3032use std:: sync:: Arc ;
3133
3234/// Represent a list of named [`CatalogProvider`]s.
@@ -157,11 +159,11 @@ impl CatalogProviderList for MemoryCatalogProviderList {
157159/// access required to read table details (e.g. statistics).
158160///
159161/// The pattern that DataFusion itself uses to plan SQL queries is to walk over
160- /// the query to [find all schema / table references in an `async` function ],
162+ /// the query to [find all table references],
161163/// performing the required remote catalog lookups in parallel, and then plan the query
162164/// using that snapshot.
163165///
164- /// [find all schema / table references in an `async` function ]: crate::execution::context::SessionState:: resolve_table_references
166+ /// [find all table references]: resolve_table_references
165167///
166168/// # Example Catalog Implementations
167169///
@@ -295,6 +297,182 @@ impl CatalogProvider for MemoryCatalogProvider {
295297 }
296298}
297299
300+ /// Collects all tables and views referenced in the SQL statement. CTEs are collected separately.
301+ /// This can be used to determine which tables need to be in the catalog for a query to be planned.
302+ ///
303+ /// # Returns
304+ ///
305+ /// A `(table_refs, ctes)` tuple, the first element contains table and view references and the second
306+ /// element contains any CTE aliases that were defined and possibly referenced.
307+ ///
308+ /// ## Example
309+ ///
310+ /// ```
311+ /// # use datafusion_sql::parser::DFParser;
312+ /// # use datafusion::catalog::resolve_table_references;
313+ /// let query = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
314+ /// let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap();
315+ /// let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap();
316+ /// assert_eq!(table_refs.len(), 2);
317+ /// assert_eq!(table_refs[0].to_string(), "bar");
318+ /// assert_eq!(table_refs[1].to_string(), "foo");
319+ /// assert_eq!(ctes.len(), 0);
320+ /// ```
321+ ///
322+ /// ## Example with CTEs
323+ ///
324+ /// ```
325+ /// # use datafusion_sql::parser::DFParser;
326+ /// # use datafusion::catalog::resolve_table_references;
327+ /// let query = "with my_cte as (values (1), (2)) SELECT * from my_cte;";
328+ /// let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap();
329+ /// let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap();
330+ /// assert_eq!(table_refs.len(), 0);
331+ /// assert_eq!(ctes.len(), 1);
332+ /// assert_eq!(ctes[0].to_string(), "my_cte");
333+ /// ```
334+ pub fn resolve_table_references (
335+ statement : & datafusion_sql:: parser:: Statement ,
336+ enable_ident_normalization : bool ,
337+ ) -> datafusion_common:: Result < ( Vec < TableReference > , Vec < TableReference > ) > {
338+ use crate :: sql:: planner:: object_name_to_table_reference;
339+ use datafusion_sql:: parser:: {
340+ CopyToSource , CopyToStatement , Statement as DFStatement ,
341+ } ;
342+ use information_schema:: INFORMATION_SCHEMA ;
343+ use information_schema:: INFORMATION_SCHEMA_TABLES ;
344+ use sqlparser:: ast:: * ;
345+
346+ struct RelationVisitor {
347+ relations : BTreeSet < ObjectName > ,
348+ all_ctes : BTreeSet < ObjectName > ,
349+ ctes_in_scope : Vec < ObjectName > ,
350+ }
351+
352+ impl RelationVisitor {
353+ /// Record the reference to `relation`, if it's not a CTE reference.
354+ fn insert_relation ( & mut self , relation : & ObjectName ) {
355+ if !self . relations . contains ( relation)
356+ && !self . ctes_in_scope . contains ( relation)
357+ {
358+ self . relations . insert ( relation. clone ( ) ) ;
359+ }
360+ }
361+ }
362+
363+ impl Visitor for RelationVisitor {
364+ type Break = ( ) ;
365+
366+ fn pre_visit_relation ( & mut self , relation : & ObjectName ) -> ControlFlow < ( ) > {
367+ self . insert_relation ( relation) ;
368+ ControlFlow :: Continue ( ( ) )
369+ }
370+
371+ fn pre_visit_query ( & mut self , q : & Query ) -> ControlFlow < Self :: Break > {
372+ if let Some ( with) = & q. with {
373+ for cte in & with. cte_tables {
374+ // The non-recursive CTE name is not in scope when evaluating the CTE itself, so this is valid:
375+ // `WITH t AS (SELECT * FROM t) SELECT * FROM t`
376+ // Where the first `t` refers to a predefined table. So we are careful here
377+ // to visit the CTE first, before putting it in scope.
378+ if !with. recursive {
379+ // This is a bit hackish as the CTE will be visited again as part of visiting `q`,
380+ // but thankfully `insert_relation` is idempotent.
381+ cte. visit ( self ) ;
382+ }
383+ self . ctes_in_scope
384+ . push ( ObjectName ( vec ! [ cte. alias. name. clone( ) ] ) ) ;
385+ }
386+ }
387+ ControlFlow :: Continue ( ( ) )
388+ }
389+
390+ fn post_visit_query ( & mut self , q : & Query ) -> ControlFlow < Self :: Break > {
391+ if let Some ( with) = & q. with {
392+ for _ in & with. cte_tables {
393+ // Unwrap: We just pushed these in `pre_visit_query`
394+ self . all_ctes . insert ( self . ctes_in_scope . pop ( ) . unwrap ( ) ) ;
395+ }
396+ }
397+ ControlFlow :: Continue ( ( ) )
398+ }
399+
400+ fn pre_visit_statement ( & mut self , statement : & Statement ) -> ControlFlow < ( ) > {
401+ if let Statement :: ShowCreate {
402+ obj_type : ShowCreateObject :: Table | ShowCreateObject :: View ,
403+ obj_name,
404+ } = statement
405+ {
406+ self . insert_relation ( obj_name)
407+ }
408+
409+ // SHOW statements will later be rewritten into a SELECT from the information_schema
410+ let requires_information_schema = matches ! (
411+ statement,
412+ Statement :: ShowFunctions { .. }
413+ | Statement :: ShowVariable { .. }
414+ | Statement :: ShowStatus { .. }
415+ | Statement :: ShowVariables { .. }
416+ | Statement :: ShowCreate { .. }
417+ | Statement :: ShowColumns { .. }
418+ | Statement :: ShowTables { .. }
419+ | Statement :: ShowCollation { .. }
420+ ) ;
421+ if requires_information_schema {
422+ for s in INFORMATION_SCHEMA_TABLES {
423+ self . relations . insert ( ObjectName ( vec ! [
424+ Ident :: new( INFORMATION_SCHEMA ) ,
425+ Ident :: new( * s) ,
426+ ] ) ) ;
427+ }
428+ }
429+ ControlFlow :: Continue ( ( ) )
430+ }
431+ }
432+
433+ let mut visitor = RelationVisitor {
434+ relations : BTreeSet :: new ( ) ,
435+ all_ctes : BTreeSet :: new ( ) ,
436+ ctes_in_scope : vec ! [ ] ,
437+ } ;
438+
439+ fn visit_statement ( statement : & DFStatement , visitor : & mut RelationVisitor ) {
440+ match statement {
441+ DFStatement :: Statement ( s) => {
442+ let _ = s. as_ref ( ) . visit ( visitor) ;
443+ }
444+ DFStatement :: CreateExternalTable ( table) => {
445+ visitor
446+ . relations
447+ . insert ( ObjectName ( vec ! [ Ident :: from( table. name. as_str( ) ) ] ) ) ;
448+ }
449+ DFStatement :: CopyTo ( CopyToStatement { source, .. } ) => match source {
450+ CopyToSource :: Relation ( table_name) => {
451+ visitor. insert_relation ( table_name) ;
452+ }
453+ CopyToSource :: Query ( query) => {
454+ query. visit ( visitor) ;
455+ }
456+ } ,
457+ DFStatement :: Explain ( explain) => visit_statement ( & explain. statement , visitor) ,
458+ }
459+ }
460+
461+ visit_statement ( statement, & mut visitor) ;
462+
463+ let table_refs = visitor
464+ . relations
465+ . into_iter ( )
466+ . map ( |x| object_name_to_table_reference ( x, enable_ident_normalization) )
467+ . collect :: < datafusion_common:: Result < _ > > ( ) ?;
468+ let ctes = visitor
469+ . all_ctes
470+ . into_iter ( )
471+ . map ( |x| object_name_to_table_reference ( x, enable_ident_normalization) )
472+ . collect :: < datafusion_common:: Result < _ > > ( ) ?;
473+ Ok ( ( table_refs, ctes) )
474+ }
475+
298476#[ cfg( test) ]
299477mod tests {
300478 use super :: * ;
@@ -363,4 +541,61 @@ mod tests {
363541 let cat = Arc :: new ( MemoryCatalogProvider :: new ( ) ) as Arc < dyn CatalogProvider > ;
364542 assert ! ( cat. deregister_schema( "foo" , false ) . unwrap( ) . is_none( ) ) ;
365543 }
544+
#[test]
fn resolve_table_references_shadowed_cte() {
    use datafusion_sql::parser::DFParser;

    // Parses `sql`, resolves its references, and compares the rendered names of the
    // table references and of the CTEs against the expected lists.
    let check = |sql: &str, expected_tables: &[&str], expected_ctes: &[&str]| {
        let mut parsed = DFParser::parse_sql(sql).unwrap();
        let statement = parsed.pop_back().unwrap();
        let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap();

        let table_names: Vec<String> =
            table_refs.iter().map(|r| r.to_string()).collect();
        let cte_names: Vec<String> = ctes.iter().map(|r| r.to_string()).collect();

        assert_eq!(table_names, expected_tables);
        assert_eq!(cte_names, expected_ctes);
    };

    // An interesting edge case where the `t` name is used both as an ordinary table reference
    // and as a CTE reference.
    check(
        "WITH t AS (SELECT * FROM t) SELECT * FROM t",
        &["t"],
        &["t"],
    );

    // UNION is a special case where the CTE is not in scope for the second branch.
    check(
        "(with t as (select 1) select * from t) union (select * from t)",
        &["t"],
        &["t"],
    );

    // Nested CTEs are also handled.
    // Here the first `u` is a CTE, but the second `u` is a table reference.
    // While `t` is always a CTE.
    check(
        "(with t as (with u as (select 1) select * from u) select * from u cross join t)",
        &["u"],
        &["t", "u"],
    );
}
580+
#[test]
fn resolve_table_references_recursive_cte() {
    use datafusion_sql::parser::DFParser;

    // The recursive member selects from `nodes` itself; the assertions below expect that
    // self-reference to be reported only as a CTE, never as a table reference.
    let sql = "
        WITH RECURSIVE nodes AS (
            SELECT 1 as id
            UNION ALL
            SELECT id + 1 as id
            FROM nodes
            WHERE id < 10
        )
        SELECT * FROM nodes
    ";
    let mut parsed = DFParser::parse_sql(sql).unwrap();
    let statement = parsed.pop_back().unwrap();
    let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap();

    assert!(table_refs.is_empty());
    let cte_names: Vec<String> = ctes.iter().map(|r| r.to_string()).collect();
    assert_eq!(cte_names, ["nodes"]);
}
366601}
0 commit comments