Skip to content

Commit 6b7d21b

Browse files
committed
feat: add DefaultTableSource::wrap and unwrap
These replace `provider_as_source` and `source_as_provider`, and are somewhat more explicit.
1 parent 0d66240 commit 6b7d21b

File tree

18 files changed

+167
-66
lines changed

18 files changed

+167
-66
lines changed

datafusion-examples/examples/custom_datasource.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ use async_trait::async_trait;
2525
use datafusion::arrow::array::{UInt64Builder, UInt8Builder};
2626
use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
2727
use datafusion::arrow::record_batch::RecordBatch;
28-
use datafusion::datasource::{provider_as_source, TableProvider, TableType};
28+
use datafusion::datasource::{DefaultTableSource, TableProvider, TableType};
2929
use datafusion::error::Result;
3030
use datafusion::execution::context::TaskContext;
3131
use datafusion::logical_expr::LogicalPlanBuilder;
@@ -66,7 +66,7 @@ async fn search_accounts(
6666
// create logical plan composed of a single TableScan
6767
let logical_plan = LogicalPlanBuilder::scan_with_filters(
6868
"accounts",
69-
provider_as_source(Arc::new(db)),
69+
DefaultTableSource::wrap(Arc::new(db)),
7070
None,
7171
vec![],
7272
)?

datafusion/catalog/src/default_table_source.rs

Lines changed: 86 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,14 +41,90 @@ pub struct DefaultTableSource {
4141
}
4242

4343
impl DefaultTableSource {
44-
/// Create a new DefaultTableSource to wrap a TableProvider
45-
pub fn new(table_provider: Arc<dyn TableProvider>) -> Self {
46-
Self { table_provider }
44+
/// Wraps a [TableProvider] as a [TableSource], to be used in planning.
45+
///
46+
/// # Example
47+
/// ```
48+
/// # use std::sync::Arc;
49+
/// # use std::any::Any;
50+
/// # use arrow::datatypes::{Schema, SchemaRef};
51+
/// # use datafusion_expr::{Expr, TableType, TableSource};
52+
/// # use datafusion_physical_plan::ExecutionPlan;
53+
/// # use datafusion_common::Result;
54+
/// # use datafusion_catalog::{TableProvider, default_table_source::DefaultTableSource};
55+
/// # use datafusion_session::Session;
56+
/// # use async_trait::async_trait;
57+
///
58+
/// # #[derive(Debug, Eq, PartialEq)]
59+
/// # struct MyTableProvider {}
60+
///
61+
/// # #[async_trait]
62+
/// # impl TableProvider for MyTableProvider {
63+
/// # fn as_any(&self) -> &dyn Any { self }
64+
/// # fn schema(&self) -> SchemaRef { Arc::new(Schema::empty()) }
65+
/// # fn table_type(&self) -> TableType { TableType::Base }
66+
/// # async fn scan(
67+
/// # &self,
68+
/// # _: &dyn Session,
69+
/// # _: Option<&Vec<usize>>,
70+
/// # _: &[Expr],
71+
/// # _: Option<usize>,
72+
/// # ) -> Result<Arc<dyn ExecutionPlan>> {
73+
/// # unimplemented!()
74+
/// # }
75+
/// # }
76+
///
77+
/// let provider = Arc::new(MyTableProvider {});
78+
/// let table_source = DefaultTableSource::wrap(provider);
79+
/// ```
80+
pub fn wrap(table_provider: Arc<dyn TableProvider>) -> Arc<Self> {
81+
Arc::new(Self { table_provider })
4782
}
4883

49-
/// Attempt to downcast a TableSource to DefaultTableSource and access the
50-
/// TableProvider. This will only work with a TableSource created by DataFusion.
51-
pub fn unwrap_provider<T: TableProvider + 'static>(
84+
/// Attempt to downcast a `TableSource` to `DefaultTableSource` and access
85+
/// the [TableProvider]. This will only work with a [TableSource] created
86+
/// by [`DefaultTableSource::wrap`].
87+
///
88+
/// # Example
89+
/// ```
90+
/// # use std::sync::Arc;
91+
/// # use std::any::Any;
92+
/// # use arrow::datatypes::{Schema, SchemaRef};
93+
/// # use datafusion_common::Result;
94+
/// # use datafusion_expr::{Expr, TableType, TableSource};
95+
/// # use datafusion_physical_plan::ExecutionPlan;
96+
/// # use datafusion_catalog::{TableProvider, default_table_source::DefaultTableSource};
97+
/// # use datafusion_session::Session;
98+
/// # use async_trait::async_trait;
99+
///
100+
/// # #[derive(Debug, Eq, PartialEq)]
101+
/// # struct MyTableProvider {}
102+
///
103+
/// # #[async_trait]
104+
/// # impl TableProvider for MyTableProvider {
105+
/// # fn as_any(&self) -> &dyn Any { self }
106+
/// # fn schema(&self) -> SchemaRef { Arc::new(Schema::empty()) }
107+
/// # fn table_type(&self) -> TableType { TableType::Base }
108+
/// # async fn scan(
109+
/// # &self,
110+
/// # _: &dyn Session,
111+
/// # _: Option<&Vec<usize>>,
112+
/// # _: &[Expr],
113+
/// # _: Option<usize>,
114+
/// # ) -> Result<Arc<dyn ExecutionPlan>> {
115+
/// # unimplemented!()
116+
/// # }
117+
/// # }
118+
///
119+
/// # fn example() -> Result<()> {
120+
/// let provider = Arc::new(MyTableProvider {});
121+
/// let table_source: Arc<dyn TableSource> = DefaultTableSource::wrap(provider.clone());
122+
/// let unwrapped = DefaultTableSource::unwrap::<MyTableProvider>(&table_source)?;
123+
/// assert_eq!(provider.as_ref(), unwrapped);
124+
/// # Ok(())
125+
/// # }
/// # example().unwrap();
126+
/// ```
127+
pub fn unwrap<T: TableProvider + 'static>(
52128
source: &Arc<dyn TableSource>,
53129
) -> datafusion_common::Result<&T> {
54130
if let Some(source) = source
@@ -108,14 +184,16 @@ impl TableSource for DefaultTableSource {
108184
}
109185

110186
/// Wrap a TableProvider as a TableSource.
187+
#[deprecated(note = "use DefaultTableSource::wrap instead")]
111188
pub fn provider_as_source(
112189
table_provider: Arc<dyn TableProvider>,
113190
) -> Arc<dyn TableSource> {
114-
Arc::new(DefaultTableSource::new(table_provider))
191+
DefaultTableSource::wrap(table_provider)
115192
}
116193

117194
/// Attempt to downcast a TableSource to DefaultTableSource and access the
118195
/// TableProvider. This will only work with a TableSource created by DataFusion.
196+
#[deprecated(note = "use DefaultTableSource::unwrap instead")]
119197
pub fn source_as_provider(
120198
source: &Arc<dyn TableSource>,
121199
) -> datafusion_common::Result<Arc<dyn TableProvider>> {
@@ -163,6 +241,6 @@ fn preserves_table_type() {
163241
}
164242
}
165243

166-
let table_source = DefaultTableSource::new(Arc::new(TestTempTable));
244+
let table_source = DefaultTableSource::wrap(Arc::new(TestTempTable));
167245
assert_eq!(table_source.table_type(), TableType::Temporary);
168246
}

datafusion/core/src/dataframe/mod.rs

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -25,9 +25,7 @@ use crate::arrow::util::pretty;
2525
use crate::datasource::file_format::csv::CsvFormatFactory;
2626
use crate::datasource::file_format::format_as_file_type;
2727
use crate::datasource::file_format::json::JsonFormatFactory;
28-
use crate::datasource::{
29-
provider_as_source, DefaultTableSource, MemTable, TableProvider,
30-
};
28+
use crate::datasource::{DefaultTableSource, MemTable, TableProvider};
3129
use crate::error::Result;
3230
use crate::execution::context::{SessionState, TaskContext};
3331
use crate::execution::FunctionRegistry;
@@ -1088,7 +1086,7 @@ impl DataFrame {
10881086

10891087
let plan = LogicalPlanBuilder::scan(
10901088
UNNAMED_TABLE,
1091-
provider_as_source(Arc::new(provider)),
1089+
DefaultTableSource::wrap(Arc::new(provider)),
10921090
None,
10931091
)?
10941092
.build()?;
@@ -1845,7 +1843,7 @@ impl DataFrame {
18451843
_ => plan_err!("No table named '{table_name}'"),
18461844
}?;
18471845

1848-
let target = Arc::new(DefaultTableSource::new(target));
1846+
let target = DefaultTableSource::wrap(target);
18491847

18501848
let plan = LogicalPlanBuilder::insert_into(
18511849
plan,

datafusion/core/src/datasource/listing/table.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1517,7 +1517,7 @@ mod tests {
15171517
use crate::{
15181518
datasource::{
15191519
file_format::csv::CsvFormat, file_format::json::JsonFormat,
1520-
provider_as_source, DefaultTableSource, MemTable,
1520+
DefaultTableSource, MemTable,
15211521
},
15221522
execution::options::ArrowReadOptions,
15231523
test::{
@@ -2222,9 +2222,9 @@ mod tests {
22222222
)?);
22232223
session_ctx.register_table("source", source_table.clone())?;
22242224
// Wrap the source table provider as a `TableSource` so that it can be used in a logical plan
2225-
let source = provider_as_source(source_table);
2225+
let source = DefaultTableSource::wrap(source_table);
22262226
let target = session_ctx.table_provider("t").await?;
2227-
let target = Arc::new(DefaultTableSource::new(target));
2227+
let target = DefaultTableSource::wrap(target);
22282228
// Create a table scan logical plan to read from the source table
22292229
let scan_plan = LogicalPlanBuilder::scan("source", source, None)?
22302230
.filter(filter_predicate)?

datafusion/core/src/datasource/memory_test.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@
1818
#[cfg(test)]
1919
mod tests {
2020

21+
use crate::datasource::DefaultTableSource;
2122
use crate::datasource::MemTable;
22-
use crate::datasource::{provider_as_source, DefaultTableSource};
2323
use crate::physical_plan::collect;
2424
use crate::prelude::SessionContext;
2525
use arrow::array::{AsArray, Int32Array};
@@ -278,12 +278,12 @@ mod tests {
278278
// Create and register the initial table with the provided schema and data
279279
let initial_table = Arc::new(MemTable::try_new(schema.clone(), initial_data)?);
280280
session_ctx.register_table("t", initial_table.clone())?;
281-
let target = Arc::new(DefaultTableSource::new(initial_table.clone()));
281+
let target = DefaultTableSource::wrap(initial_table.clone());
282282
// Create and register the source table with the provided schema and inserted data
283283
let source_table = Arc::new(MemTable::try_new(schema.clone(), inserted_data)?);
284284
session_ctx.register_table("source", source_table.clone())?;
285285
// Wrap the source table provider as a `TableSource` so that it can be used in a logical plan
286-
let source = provider_as_source(source_table);
286+
let source = DefaultTableSource::wrap(source_table);
287287
// Create a table scan logical plan to read from the source table
288288
let scan_plan = LogicalPlanBuilder::scan("source", source, None)?.build()?;
289289
// Create an insert plan to insert the source data into the initial table

datafusion/core/src/datasource/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ pub mod provider;
3030
mod view_test;
3131

3232
// backwards compatibility
33-
pub use self::default_table_source::{provider_as_source, DefaultTableSource};
33+
pub use self::default_table_source::DefaultTableSource;
3434
pub use self::memory::MemTable;
3535
pub use self::view::ViewTable;
3636
pub use crate::catalog::TableProvider;

datafusion/core/src/execution/context/mod.rs

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ use crate::{
3333
datasource::listing::{
3434
ListingOptions, ListingTable, ListingTableConfig, ListingTableUrl,
3535
},
36-
datasource::{provider_as_source, MemTable, ViewTable},
36+
datasource::{MemTable, ViewTable},
3737
error::{DataFusionError, Result},
3838
execution::{
3939
options::ArrowReadOptions,
@@ -58,6 +58,7 @@ pub use crate::execution::session_state::SessionState;
5858

5959
use arrow::datatypes::{Schema, SchemaRef};
6060
use arrow::record_batch::RecordBatch;
61+
use datafusion_catalog::default_table_source::DefaultTableSource;
6162
use datafusion_catalog::memory::MemorySchemaProvider;
6263
use datafusion_catalog::MemoryCatalogProvider;
6364
use datafusion_catalog::{
@@ -1387,8 +1388,12 @@ impl SessionContext {
13871388
pub fn read_table(&self, provider: Arc<dyn TableProvider>) -> Result<DataFrame> {
13881389
Ok(DataFrame::new(
13891390
self.state(),
1390-
LogicalPlanBuilder::scan(UNNAMED_TABLE, provider_as_source(provider), None)?
1391-
.build()?,
1391+
LogicalPlanBuilder::scan(
1392+
UNNAMED_TABLE,
1393+
DefaultTableSource::wrap(provider),
1394+
None,
1395+
)?
1396+
.build()?,
13921397
))
13931398
}
13941399

@@ -1399,7 +1404,7 @@ impl SessionContext {
13991404
self.state(),
14001405
LogicalPlanBuilder::scan(
14011406
UNNAMED_TABLE,
1402-
provider_as_source(Arc::new(provider)),
1407+
DefaultTableSource::wrap(Arc::new(provider)),
14031408
None,
14041409
)?
14051410
.build()?,
@@ -1422,7 +1427,7 @@ impl SessionContext {
14221427
self.state(),
14231428
LogicalPlanBuilder::scan(
14241429
UNNAMED_TABLE,
1425-
provider_as_source(Arc::new(provider)),
1430+
DefaultTableSource::wrap(Arc::new(provider)),
14261431
None,
14271432
)?
14281433
.build()?,
@@ -1586,7 +1591,7 @@ impl SessionContext {
15861591
let provider = self.table_provider(table_ref.clone()).await?;
15871592
let plan = LogicalPlanBuilder::scan(
15881593
table_ref,
1589-
provider_as_source(Arc::clone(&provider)),
1594+
DefaultTableSource::wrap(Arc::clone(&provider)),
15901595
None,
15911596
)?
15921597
.build()?;

datafusion/core/src/execution/session_state.rs

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@ use std::sync::Arc;
2626
use crate::catalog::{CatalogProviderList, SchemaProvider, TableProviderFactory};
2727
use crate::datasource::cte_worktable::CteWorkTable;
2828
use crate::datasource::file_format::{format_as_file_type, FileFormatFactory};
29-
use crate::datasource::provider_as_source;
3029
use crate::execution::context::{EmptySerializerRegistry, FunctionFactory, QueryPlanner};
3130
use crate::execution::SessionStateDefaults;
3231
use crate::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner};
32+
use datafusion_catalog::default_table_source::DefaultTableSource;
3333
use datafusion_catalog::information_schema::{
3434
InformationSchemaProvider, INFORMATION_SCHEMA,
3535
};
@@ -109,11 +109,11 @@ use uuid::Uuid;
109109
/// # #[tokio::main]
110110
/// # async fn main() -> Result<()> {
111111
/// let state = SessionStateBuilder::new()
112-
/// .with_config(SessionConfig::new())
112+
/// .with_config(SessionConfig::new())
113113
/// .with_runtime_env(Arc::new(RuntimeEnv::default()))
114114
/// .with_default_features()
115115
/// .build();
116-
/// Ok(())
116+
/// Ok(())
117117
/// # }
118118
/// ```
119119
///
@@ -475,7 +475,7 @@ impl SessionState {
475475
let resolved = v.key();
476476
if let Ok(schema) = self.schema_for_ref(resolved.clone()) {
477477
if let Some(table) = schema.table(&resolved.table).await? {
478-
v.insert(provider_as_source(table));
478+
v.insert(DefaultTableSource::wrap(table));
479479
}
480480
}
481481
}
@@ -1300,7 +1300,7 @@ impl SessionStateBuilder {
13001300
/// let url = Url::try_from("file://").unwrap();
13011301
/// let object_store = object_store::local::LocalFileSystem::new();
13021302
/// let state = SessionStateBuilder::new()
1303-
/// .with_config(SessionConfig::new())
1303+
/// .with_config(SessionConfig::new())
13041304
/// .with_object_store(&url, Arc::new(object_store))
13051305
/// .with_default_features()
13061306
/// .build();
@@ -1684,7 +1684,7 @@ impl ContextProvider for SessionContextProvider<'_> {
16841684
.collect::<datafusion_common::Result<Vec<_>>>()?;
16851685
let provider = tbl_func.create_table_provider(&args)?;
16861686

1687-
Ok(provider_as_source(provider))
1687+
Ok(DefaultTableSource::wrap(provider))
16881688
}
16891689

16901690
/// Create a new CTE work table for a recursive CTE logical plan
@@ -1696,7 +1696,7 @@ impl ContextProvider for SessionContextProvider<'_> {
16961696
schema: SchemaRef,
16971697
) -> datafusion_common::Result<Arc<dyn TableSource>> {
16981698
let table = Arc::new(CteWorkTable::new(name, schema));
1699-
Ok(provider_as_source(table))
1699+
Ok(DefaultTableSource::wrap(table))
17001700
}
17011701

17021702
fn get_function_meta(&self, name: &str) -> Option<Arc<ScalarUDF>> {
@@ -1974,7 +1974,6 @@ mod tests {
19741974
use crate::common::assert_contains;
19751975
use crate::config::ConfigOptions;
19761976
use crate::datasource::empty::EmptyTable;
1977-
use crate::datasource::provider_as_source;
19781977
use crate::datasource::MemTable;
19791978
use crate::execution::context::SessionState;
19801979
use crate::logical_expr::planner::ExprPlanner;
@@ -1984,6 +1983,7 @@ mod tests {
19841983
use crate::sql::{ResolvedTableReference, TableReference};
19851984
use arrow::array::{ArrayRef, Int32Array, RecordBatch, StringArray};
19861985
use arrow::datatypes::{DataType, Field, Schema};
1986+
use datafusion_catalog::default_table_source::DefaultTableSource;
19871987
use datafusion_catalog::MemoryCatalogProviderList;
19881988
use datafusion_common::DFSchema;
19891989
use datafusion_common::Result;
@@ -2160,7 +2160,9 @@ mod tests {
21602160
) -> Result<Arc<dyn ExecutionPlan>> {
21612161
let mut context_provider = MyContextProvider::new().with_table(
21622162
"t",
2163-
provider_as_source(Arc::new(EmptyTable::new(Schema::empty().into()))),
2163+
DefaultTableSource::wrap(Arc::new(EmptyTable::new(
2164+
Schema::empty().into(),
2165+
))),
21642166
);
21652167
if with_expr_planners {
21662168
context_provider = context_provider.with_expr_planners();

datafusion/core/src/physical_planner.rs

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,6 @@ use crate::schema_equivalence::schema_satisfied_by;
6060
use arrow::array::{builder::StringBuilder, RecordBatch};
6161
use arrow::compute::SortOptions;
6262
use arrow::datatypes::{Schema, SchemaRef};
63-
use datafusion_catalog::default_table_source::source_as_provider;
6463
use datafusion_common::display::ToStringifiedPlan;
6564
use datafusion_common::tree_node::{
6665
Transformed, TransformedResult, TreeNode, TreeNodeRecursion, TreeNodeVisitor,
@@ -455,7 +454,15 @@ impl DefaultPhysicalPlanner {
455454
fetch,
456455
..
457456
}) => {
458-
let source = source_as_provider(source)?;
457+
let Some(source) = source.as_any().downcast_ref::<DefaultTableSource>()
458+
else {
459+
return Err(DataFusionError::Plan(
460+
"TableSource can only be used for logical planning".to_string(),
461+
));
462+
};
463+
464+
let source = Arc::clone(&source.table_provider);
465+
459466
// Remove all qualifiers from the scan as the provider
460467
// doesn't know (nor should care) how the relation was
461468
// referred to in the query

0 commit comments

Comments
 (0)