Skip to content

Commit

Permalink
Pass SessionState to TableProvider::scan
Browse files (browse the repository at this point in the history)
  • Loading branch information
tustvold committed May 30, 2022
1 parent 3a313c9 commit 4ae3b42
Show file tree
Hide file tree
Showing 16 changed files with 69 additions and 33 deletions.
4 changes: 2 additions & 2 deletions benchmarks/src/bin/tpch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -178,9 +178,9 @@ async fn benchmark_datafusion(opt: DataFusionBenchmarkOpt) -> Result<Vec<RecordB
if opt.mem_table {
println!("Loading table '{}' into memory", table);
let start = Instant::now();
let task_ctx = ctx.task_ctx();
let memtable =
MemTable::load(table_provider, Some(opt.partitions), task_ctx).await?;
MemTable::load(table_provider, Some(opt.partitions), &ctx.state())
.await?;
println!(
"Loaded table '{}' into memory in {} ms",
table,
Expand Down
3 changes: 2 additions & 1 deletion datafusion-examples/examples/custom_datasource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use datafusion::arrow::record_batch::RecordBatch;
use datafusion::dataframe::DataFrame;
use datafusion::datasource::{TableProvider, TableType};
use datafusion::error::Result;
use datafusion::execution::context::TaskContext;
use datafusion::execution::context::{SessionState, TaskContext};
use datafusion::logical_plan::{provider_as_source, Expr, LogicalPlanBuilder};
use datafusion::physical_plan::expressions::PhysicalSortExpr;
use datafusion::physical_plan::memory::MemoryStream;
Expand Down Expand Up @@ -175,6 +175,7 @@ impl TableProvider for CustomDataSource {

async fn scan(
&self,
_state: &SessionState,
projection: &Option<Vec<usize>>,
// filters and limit can be used here to inject some push-down operations if needed
_filters: &[Expr],
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/benches/sort_limit_query_sql.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,8 +89,8 @@ fn create_context() -> Arc<Mutex<SessionContext>> {
let ctx = SessionContext::new();
ctx.state.write().config.target_partitions = 1;

let task_ctx = ctx.task_ctx();
let mem_table = MemTable::load(Arc::new(csv.await), Some(partitions), task_ctx)
let table_provider = Arc::new(csv.await);
let mem_table = MemTable::load(table_provider, Some(partitions), &ctx.state())
.await
.unwrap();
ctx.register_table("aggregate_test_100", Arc::new(mem_table))
Expand Down
2 changes: 2 additions & 0 deletions datafusion/core/src/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -615,6 +615,7 @@ impl DataFrame {
}
}

// TODO: This will introduce a ref cycle (#2659)
#[async_trait]
impl TableProvider for DataFrame {
fn as_any(&self) -> &dyn Any {
Expand All @@ -632,6 +633,7 @@ impl TableProvider for DataFrame {

async fn scan(
&self,
_ctx: &SessionState,
projection: &Option<Vec<usize>>,
filters: &[Expr],
limit: Option<usize>,
Expand Down
2 changes: 2 additions & 0 deletions datafusion/core/src/datasource/datasource.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub use datafusion_expr::{TableProviderFilterPushDown, TableType};

use crate::arrow::datatypes::SchemaRef;
use crate::error::Result;
use crate::execution::context::SessionState;
use crate::logical_plan::Expr;
use crate::physical_plan::ExecutionPlan;

Expand All @@ -47,6 +48,7 @@ pub trait TableProvider: Sync + Send {
/// parallelized or distributed.
async fn scan(
&self,
ctx: &SessionState,
projection: &Option<Vec<usize>>,
filters: &[Expr],
// limit can be used to reduce the amount scanned
Expand Down
2 changes: 2 additions & 0 deletions datafusion/core/src/datasource/empty.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ use async_trait::async_trait;

use crate::datasource::{TableProvider, TableType};
use crate::error::Result;
use crate::execution::context::SessionState;
use crate::logical_plan::Expr;
use crate::physical_plan::project_schema;
use crate::physical_plan::{empty::EmptyExec, ExecutionPlan};
Expand Down Expand Up @@ -57,6 +58,7 @@ impl TableProvider for EmptyTable {

async fn scan(
&self,
_ctx: &SessionState,
projection: &Option<Vec<usize>>,
_filters: &[Expr],
_limit: Option<usize>,
Expand Down
14 changes: 11 additions & 3 deletions datafusion/core/src/datasource/listing/table.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ use crate::datasource::{
use crate::logical_expr::TableProviderFilterPushDown;
use crate::{
error::{DataFusionError, Result},
execution::context::SessionState,
logical_plan::Expr,
physical_plan::{
empty::EmptyExec,
Expand Down Expand Up @@ -302,6 +303,7 @@ impl TableProvider for ListingTable {

async fn scan(
&self,
_ctx: &SessionState,
projection: &Option<Vec<usize>>,
filters: &[Expr],
limit: Option<usize>,
Expand Down Expand Up @@ -405,6 +407,7 @@ impl ListingTable {
#[cfg(test)]
mod tests {
use crate::datasource::file_format::avro::DEFAULT_AVRO_EXTENSION;
use crate::prelude::SessionContext;
use crate::{
datafusion_data_access::object_store::local::LocalFileSystem,
datasource::file_format::{avro::AvroFormat, parquet::ParquetFormat},
Expand All @@ -417,10 +420,12 @@ mod tests {

#[tokio::test]
async fn read_single_file() -> Result<()> {
let ctx = SessionContext::new();

let table = load_table("alltypes_plain.parquet").await?;
let projection = None;
let exec = table
.scan(&projection, &[], None)
.scan(&ctx.state(), &projection, &[], None)
.await
.expect("Scan table");

Expand All @@ -447,7 +452,9 @@ mod tests {
.with_listing_options(opt)
.with_schema(schema);
let table = ListingTable::try_new(config)?;
let exec = table.scan(&None, &[], None).await?;

let ctx = SessionContext::new();
let exec = table.scan(&ctx.state(), &None, &[], None).await?;
assert_eq!(exec.statistics().num_rows, Some(8));
assert_eq!(exec.statistics().total_byte_size, Some(671));

Expand Down Expand Up @@ -483,8 +490,9 @@ mod tests {
// this will filter out the only file in the store
let filter = Expr::not_eq(col("p1"), lit("v1"));

let ctx = SessionContext::new();
let scan = table
.scan(&None, &[filter], None)
.scan(&ctx.state(), &None, &[filter], None)
.await
.expect("Empty execution plan");

Expand Down
34 changes: 24 additions & 10 deletions datafusion/core/src/datasource/memory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ use async_trait::async_trait;

use crate::datasource::{TableProvider, TableType};
use crate::error::{DataFusionError, Result};
use crate::execution::context::TaskContext;
use crate::execution::context::{SessionState, TaskContext};
use crate::logical_plan::Expr;
use crate::physical_plan::common;
use crate::physical_plan::memory::MemoryExec;
Expand Down Expand Up @@ -65,18 +65,18 @@ impl MemTable {
pub async fn load(
t: Arc<dyn TableProvider>,
output_partitions: Option<usize>,
context: Arc<TaskContext>,
ctx: &SessionState,
) -> Result<Self> {
let schema = t.schema();
let exec = t.scan(&None, &[], None).await?;
let exec = t.scan(ctx, &None, &[], None).await?;
let partition_count = exec.output_partitioning().partition_count();

let tasks = (0..partition_count)
.map(|part_i| {
let context1 = context.clone();
let task = Arc::new(TaskContext::from(ctx));
let exec = exec.clone();
tokio::spawn(async move {
let stream = exec.execute(part_i, context1.clone())?;
let stream = exec.execute(part_i, task)?;
common::collect(stream).await
})
})
Expand All @@ -103,7 +103,8 @@ impl MemTable {
let mut output_partitions = vec![];
for i in 0..exec.output_partitioning().partition_count() {
// execute this *output* partition and collect all batches
let mut stream = exec.execute(i, context.clone())?;
let task_ctx = Arc::new(TaskContext::from(ctx));
let mut stream = exec.execute(i, task_ctx)?;
let mut batches = vec![];
while let Some(result) = stream.next().await {
batches.push(result?);
Expand Down Expand Up @@ -133,6 +134,7 @@ impl TableProvider for MemTable {

async fn scan(
&self,
_ctx: &SessionState,
projection: &Option<Vec<usize>>,
_filters: &[Expr],
_limit: Option<usize>,
Expand Down Expand Up @@ -180,7 +182,10 @@ mod tests {
let provider = MemTable::try_new(schema, vec![vec![batch]])?;

// scan with projection
let exec = provider.scan(&Some(vec![2, 1]), &[], None).await?;
let exec = provider
.scan(&session_ctx.state(), &Some(vec![2, 1]), &[], None)
.await?;

let mut it = exec.execute(0, task_ctx)?;
let batch2 = it.next().await.unwrap()?;
assert_eq!(2, batch2.schema().fields().len());
Expand Down Expand Up @@ -212,7 +217,9 @@ mod tests {

let provider = MemTable::try_new(schema, vec![vec![batch]])?;

let exec = provider.scan(&None, &[], None).await?;
let exec = provider
.scan(&session_ctx.state(), &None, &[], None)
.await?;
let mut it = exec.execute(0, task_ctx)?;
let batch1 = it.next().await.unwrap()?;
assert_eq!(3, batch1.schema().fields().len());
Expand All @@ -223,6 +230,8 @@ mod tests {

#[tokio::test]
async fn test_invalid_projection() -> Result<()> {
let session_ctx = SessionContext::new();

let schema = Arc::new(Schema::new(vec![
Field::new("a", DataType::Int32, false),
Field::new("b", DataType::Int32, false),
Expand All @@ -242,7 +251,10 @@ mod tests {

let projection: Vec<usize> = vec![0, 4];

match provider.scan(&Some(projection), &[], None).await {
match provider
.scan(&session_ctx.state(), &Some(projection), &[], None)
.await
{
Err(DataFusionError::ArrowError(ArrowError::SchemaError(e))) => {
assert_eq!(
"\"project index 4 out of bounds, max field 3\"",
Expand Down Expand Up @@ -368,7 +380,9 @@ mod tests {
let provider =
MemTable::try_new(Arc::new(merged_schema), vec![vec![batch1, batch2]])?;

let exec = provider.scan(&None, &[], None).await?;
let exec = provider
.scan(&session_ctx.state(), &None, &[], None)
.await?;
let mut it = exec.execute(0, task_ctx)?;
let batch1 = it.next().await.unwrap()?;
assert_eq!(3, batch1.schema().fields().len());
Expand Down
11 changes: 5 additions & 6 deletions datafusion/core/src/datasource/view.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,17 +24,15 @@ use async_trait::async_trait;

use crate::{
error::Result,
execution::context::SessionContext,
logical_plan::{Expr, LogicalPlan},
physical_plan::ExecutionPlan,
};

use crate::datasource::{TableProvider, TableType};
use crate::execution::context::SessionState;

/// An implementation of `TableProvider` that uses another logical plan.
pub struct ViewTable {
/// To create ExecutionPlan
context: SessionContext,
/// LogicalPlan of the view
logical_plan: LogicalPlan,
/// File fields + partition columns
Expand All @@ -44,11 +42,10 @@ pub struct ViewTable {
impl ViewTable {
/// Create new view that is executed at query runtime.
/// Takes a `LogicalPlan` as input.
pub fn try_new(context: SessionContext, logical_plan: LogicalPlan) -> Result<Self> {
pub fn try_new(logical_plan: LogicalPlan) -> Result<Self> {
let table_schema = logical_plan.schema().as_ref().to_owned().into();

let view = Self {
context,
logical_plan,
table_schema,
};
Expand All @@ -73,16 +70,18 @@ impl TableProvider for ViewTable {

async fn scan(
&self,
ctx: &SessionState,
_projection: &Option<Vec<usize>>,
_filters: &[Expr],
_limit: Option<usize>,
) -> Result<Arc<dyn ExecutionPlan>> {
self.context.create_physical_plan(&self.logical_plan).await
ctx.create_physical_plan(&self.logical_plan).await
}
}

#[cfg(test)]
mod tests {
use crate::prelude::SessionContext;
use crate::{assert_batches_eq, execution::context::SessionConfig};

use super::*;
Expand Down
11 changes: 7 additions & 4 deletions datafusion/core/src/execution/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -349,16 +349,14 @@ impl SessionContext {
(true, Ok(_)) => {
self.deregister_table(name.as_str())?;
let plan = self.optimize(&input)?;
let table =
Arc::new(ViewTable::try_new(self.clone(), plan.clone())?);
let table = Arc::new(ViewTable::try_new(plan.clone())?);

self.register_table(name.as_str(), table)?;
Ok(Arc::new(DataFrame::new(self.state.clone(), &plan)))
}
(_, Err(_)) => {
let plan = self.optimize(&input)?;
let table =
Arc::new(ViewTable::try_new(self.clone(), plan.clone())?);
let table = Arc::new(ViewTable::try_new(plan.clone())?);

self.register_table(name.as_str(), table)?;
Ok(Arc::new(DataFrame::new(self.state.clone(), &plan)))
Expand Down Expand Up @@ -931,6 +929,11 @@ impl SessionContext {
pub fn task_ctx(&self) -> Arc<TaskContext> {
// Every call converts this context into a new `TaskContext` via `From`
// and hands it back behind a freshly allocated `Arc`.
Arc::new(TaskContext::from(self))
}

/// Get a copy of the [`SessionState`] of this [`SessionContext`]
///
/// The state is cloned through the guard returned by `self.state.read()`, and
/// the clone is returned by value, so the caller owns it and no guard outlives
/// this call.
///
/// NOTE(review): whether the returned copy still shares any inner data with
/// this context depends on `SessionState`'s `Clone` implementation — confirm
/// before relying on it as a fully independent snapshot.
pub fn state(&self) -> SessionState {
self.state.read().clone()
}
}

impl FunctionRegistry for SessionContext {
Expand Down
2 changes: 1 addition & 1 deletion datafusion/core/src/physical_plan/planner.rs
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,7 @@ impl DefaultPhysicalPlanner {
// referred to in the query
let filters = unnormalize_cols(filters.iter().cloned());
let unaliased: Vec<Expr> = filters.into_iter().map(unalias).collect();
source.scan(projection, &unaliased, *limit).await
source.scan(session_state, projection, &unaliased, *limit).await
}
LogicalPlan::Values(Values {
values,
Expand Down
3 changes: 2 additions & 1 deletion datafusion/core/tests/custom_sources.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use datafusion::{
};
use datafusion::{error::Result, physical_plan::DisplayFormatType};

use datafusion::execution::context::{SessionContext, TaskContext};
use datafusion::execution::context::{SessionContext, SessionState, TaskContext};
use datafusion::logical_plan::{
col, Expr, LogicalPlan, LogicalPlanBuilder, TableScan, UNNAMED_TABLE,
};
Expand Down Expand Up @@ -201,6 +201,7 @@ impl TableProvider for CustomTableProvider {

async fn scan(
&self,
_state: &SessionState,
projection: &Option<Vec<usize>>,
_filters: &[Expr],
_limit: Option<usize>,
Expand Down
3 changes: 2 additions & 1 deletion datafusion/core/tests/provider_filter_pushdown.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use arrow::record_batch::RecordBatch;
use async_trait::async_trait;
use datafusion::datasource::datasource::{TableProvider, TableType};
use datafusion::error::Result;
use datafusion::execution::context::{SessionContext, TaskContext};
use datafusion::execution::context::{SessionContext, SessionState, TaskContext};
use datafusion::logical_expr::{Expr, TableProviderFilterPushDown};
use datafusion::physical_plan::common::SizedRecordBatchStream;
use datafusion::physical_plan::expressions::PhysicalSortExpr;
Expand Down Expand Up @@ -138,6 +138,7 @@ impl TableProvider for CustomProvider {

async fn scan(
&self,
_state: &SessionState,
_: &Option<Vec<usize>>,
filters: &[Expr],
_: Option<usize>,
Expand Down
2 changes: 2 additions & 0 deletions datafusion/core/tests/sql/information_schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
// under the License.

use async_trait::async_trait;
use datafusion::execution::context::SessionState;
use datafusion::{
catalog::{
catalog::{CatalogProvider, MemoryCatalogProvider},
Expand Down Expand Up @@ -175,6 +176,7 @@ async fn information_schema_tables_table_types() {

async fn scan(
&self,
_ctx: &SessionState,
_: &Option<Vec<usize>>,
_: &[Expr],
_: Option<usize>,
Expand Down
Loading

0 comments on commit 4ae3b42

Please sign in to comment.