@@ -27,6 +27,7 @@ use crate::catalog::schema::SchemaProvider;
2727use dashmap:: DashMap ;
2828use datafusion_common:: { exec_err, not_impl_err, Result } ;
2929use std:: any:: Any ;
30+ use std:: collections:: BTreeSet ;
3031use std:: ops:: ControlFlow ;
3132use std:: sync:: Arc ;
3233
@@ -296,25 +297,44 @@ impl CatalogProvider for MemoryCatalogProvider {
296297 }
297298}
298299
299- /// Resolve all table references in the SQL statement.
300+ /// Collects all tables and views referenced in the SQL statement. CTEs are collected separately.
301+ /// This can be used to determine which tables need to be in the catalog for a query to be planned.
302+ ///
303+ /// # Returns
304+ ///
305+ /// A `(relations, ctes)` tuple: the first element contains table and view references, and the second
306+ /// element contains the aliases of any CTEs that were defined (whether or not they are referenced).
300307///
301308/// ## Example
302309///
303310/// ```
304- /// use datafusion_sql::parser::DFParser;
305- /// use datafusion::catalog::resolve_table_references;
306- ///
311+ /// # use datafusion_sql::parser::DFParser;
312+ /// # use datafusion::catalog::resolve_table_references;
307313/// let query = "SELECT a FROM foo where x IN (SELECT y FROM bar)";
308314/// let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap();
309- /// let table_refs = resolve_table_references(&statement, true).unwrap();
315+ /// let ( table_refs, ctes) = resolve_table_references(&statement, true).unwrap();
310316/// assert_eq!(table_refs.len(), 2);
311- /// assert_eq!(table_refs[0].to_string(), "foo");
312- /// assert_eq!(table_refs[1].to_string(), "bar");
317+ /// assert_eq!(table_refs[0].to_string(), "bar");
318+ /// assert_eq!(table_refs[1].to_string(), "foo");
319+ /// assert_eq!(ctes.len(), 0);
320+ /// ```
321+ ///
322+ /// ## Example with CTEs
323+ ///
324+ /// ```
325+ /// # use datafusion_sql::parser::DFParser;
326+ /// # use datafusion::catalog::resolve_table_references;
327+ /// let query = "with my_cte as (values (1), (2)) SELECT * from my_cte;";
328+ /// let statement = DFParser::parse_sql(query).unwrap().pop_back().unwrap();
329+ /// let (table_refs, ctes) = resolve_table_references(&statement, true).unwrap();
330+ /// assert_eq!(table_refs.len(), 0);
331+ /// assert_eq!(ctes.len(), 1);
332+ /// assert_eq!(ctes[0].to_string(), "my_cte");
313333/// ```
314334pub fn resolve_table_references (
315335 statement : & datafusion_sql:: parser:: Statement ,
316336 enable_ident_normalization : bool ,
317- ) -> datafusion_common:: Result < Vec < TableReference > > {
337+ ) -> datafusion_common:: Result < ( Vec < TableReference > , Vec < TableReference > ) > {
318338 use crate :: sql:: planner:: object_name_to_table_reference;
319339 use datafusion_sql:: parser:: {
320340 CopyToSource , CopyToStatement , Statement as DFStatement ,
@@ -323,24 +343,25 @@ pub fn resolve_table_references(
323343 use information_schema:: INFORMATION_SCHEMA_TABLES ;
324344 use sqlparser:: ast:: * ;
325345
326- // Getting `TableProviders` is async but planing is not -- thus pre-fetch
327- // table providers for all relations referenced in this query
328- let mut relations = hashbrown:: HashSet :: with_capacity ( 10 ) ;
329-
330- struct RelationVisitor < ' a > ( & ' a mut hashbrown:: HashSet < ObjectName > ) ;
346+ struct RelationVisitor {
347+ relations : BTreeSet < ObjectName > ,
348+ ctes : BTreeSet < ObjectName > ,
349+ }
331350
332- impl < ' a > RelationVisitor < ' a > {
333- /// Record that `relation` was used in this statement
334- fn insert ( & mut self , relation : & ObjectName ) {
335- self . 0 . get_or_insert_with ( relation, |_| relation. clone ( ) ) ;
351+ impl RelationVisitor {
352+ /// Record the reference to `relation`, if it's not a CTE reference.
353+ fn insert_relation ( & mut self , relation : & ObjectName ) {
354+ if !self . relations . contains ( & relation) && !self . ctes . contains ( & relation) {
355+ self . relations . insert ( relation. clone ( ) ) ;
356+ }
336357 }
337358 }
338359
339- impl < ' a > Visitor for RelationVisitor < ' a > {
360+ impl Visitor for RelationVisitor {
340361 type Break = ( ) ;
341362
342363 fn pre_visit_relation ( & mut self , relation : & ObjectName ) -> ControlFlow < ( ) > {
343- self . insert ( relation) ;
364+ self . insert_relation ( relation) ;
344365 ControlFlow :: Continue ( ( ) )
345366 }
346367
@@ -350,7 +371,13 @@ pub fn resolve_table_references(
350371 obj_name,
351372 } = statement
352373 {
353- self . insert ( obj_name)
374+ self . insert_relation ( obj_name)
375+ }
376+
377+ if let Statement :: Query ( q) = statement {
378+ for cte in q. with . as_ref ( ) . map ( |w| & w. cte_tables ) . into_iter ( ) . flatten ( ) {
379+ self . ctes . insert ( ObjectName ( vec ! [ cte. alias. name. clone( ) ] ) ) ;
380+ }
354381 }
355382
356383 // SHOW statements will later be rewritten into a SELECT from the information_schema
@@ -367,7 +394,7 @@ pub fn resolve_table_references(
367394 ) ;
368395 if requires_information_schema {
369396 for s in INFORMATION_SCHEMA_TABLES {
370- self . 0 . insert ( ObjectName ( vec ! [
397+ self . relations . insert ( ObjectName ( vec ! [
371398 Ident :: new( INFORMATION_SCHEMA ) ,
372399 Ident :: new( * s) ,
373400 ] ) ) ;
@@ -377,20 +404,24 @@ pub fn resolve_table_references(
377404 }
378405 }
379406
380- let mut visitor = RelationVisitor ( & mut relations) ;
381- fn visit_statement ( statement : & DFStatement , visitor : & mut RelationVisitor < ' _ > ) {
407+ let mut visitor = RelationVisitor {
408+ relations : BTreeSet :: new ( ) ,
409+ ctes : BTreeSet :: new ( ) ,
410+ } ;
411+
412+ fn visit_statement ( statement : & DFStatement , visitor : & mut RelationVisitor ) {
382413 match statement {
383414 DFStatement :: Statement ( s) => {
384415 let _ = s. as_ref ( ) . visit ( visitor) ;
385416 }
386417 DFStatement :: CreateExternalTable ( table) => {
387418 visitor
388- . 0
419+ . relations
389420 . insert ( ObjectName ( vec ! [ Ident :: from( table. name. as_str( ) ) ] ) ) ;
390421 }
391422 DFStatement :: CopyTo ( CopyToStatement { source, .. } ) => match source {
392423 CopyToSource :: Relation ( table_name) => {
393- visitor. insert ( table_name) ;
424+ visitor. insert_relation ( table_name) ;
394425 }
395426 CopyToSource :: Query ( query) => {
396427 query. visit ( visitor) ;
@@ -402,10 +433,17 @@ pub fn resolve_table_references(
402433
403434 visit_statement ( statement, & mut visitor) ;
404435
405- relations
436+ let relations = visitor
437+ . relations
438+ . into_iter ( )
439+ . map ( |x| object_name_to_table_reference ( x, enable_ident_normalization) )
440+ . collect :: < datafusion_common:: Result < _ > > ( ) ?;
441+ let ctes = visitor
442+ . ctes
406443 . into_iter ( )
407444 . map ( |x| object_name_to_table_reference ( x, enable_ident_normalization) )
408- . collect :: < datafusion_common:: Result < _ > > ( )
445+ . collect :: < datafusion_common:: Result < _ > > ( ) ?;
446+ Ok ( ( relations, ctes) )
409447}
410448
411449#[ cfg( test) ]
0 commit comments