Pass SessionState to TableProvider::scan #2660

Merged · 3 commits · May 31, 2022
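In one line: `TableProvider::scan` (and `MemTable::load`) now receive a `&SessionState`, so a provider can plan against the current session instead of capturing a whole `SessionContext` when it is constructed (see the `ViewTable` change below). The trait method after this PR, as defined in datafusion/core/src/datasource/datasource.rs:

    async fn scan(
        &self,
        ctx: &SessionState,
        projection: &Option<Vec<usize>>,
        filters: &[Expr],
        limit: Option<usize>,
    ) -> Result<Arc<dyn ExecutionPlan>>;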
4 changes: 2 additions & 2 deletions benchmarks/src/bin/tpch.rs
@@ -178,9 +178,9 @@ async fn benchmark_datafusion(opt: DataFusionBenchmarkOpt) -> Result<Vec<RecordB
     if opt.mem_table {
         println!("Loading table '{}' into memory", table);
         let start = Instant::now();
-        let task_ctx = ctx.task_ctx();
         let memtable =
-            MemTable::load(table_provider, Some(opt.partitions), task_ctx).await?;
+            MemTable::load(table_provider, Some(opt.partitions), &ctx.state())
+                .await?;
         println!(
             "Loaded table '{}' into memory in {} ms",
             table,
3 changes: 2 additions & 1 deletion datafusion-examples/examples/custom_datasource.rs
@@ -22,7 +22,7 @@ use datafusion::arrow::record_batch::RecordBatch;
 use datafusion::dataframe::DataFrame;
 use datafusion::datasource::{TableProvider, TableType};
 use datafusion::error::Result;
-use datafusion::execution::context::TaskContext;
+use datafusion::execution::context::{SessionState, TaskContext};
 use datafusion::logical_plan::{provider_as_source, Expr, LogicalPlanBuilder};
 use datafusion::physical_plan::expressions::PhysicalSortExpr;
 use datafusion::physical_plan::memory::MemoryStream;
@@ -175,6 +175,7 @@ impl TableProvider for CustomDataSource {

     async fn scan(
         &self,
+        _state: &SessionState,
         projection: &Option<Vec<usize>>,
         // filters and limit can be used here to inject some push-down operations if needed
         _filters: &[Expr],
4 changes: 2 additions & 2 deletions datafusion/core/benches/sort_limit_query_sql.rs
@@ -89,8 +89,8 @@ fn create_context() -> Arc<Mutex<SessionContext>> {
     let ctx = SessionContext::new();
     ctx.state.write().config.target_partitions = 1;

-    let task_ctx = ctx.task_ctx();
-    let mem_table = MemTable::load(Arc::new(csv.await), Some(partitions), task_ctx)
+    let table_provider = Arc::new(csv.await);
+    let mem_table = MemTable::load(table_provider, Some(partitions), &ctx.state())
         .await
         .unwrap();
     ctx.register_table("aggregate_test_100", Arc::new(mem_table))
2 changes: 2 additions & 0 deletions datafusion/core/src/dataframe.rs
@@ -615,6 +615,7 @@ impl DataFrame {
     }
 }

+// TODO: This will introduce a ref cycle (#2659)
 #[async_trait]
 impl TableProvider for DataFrame {
     fn as_any(&self) -> &dyn Any {
@@ -632,6 +633,7 @@ impl TableProvider for DataFrame {

     async fn scan(
         &self,
+        _ctx: &SessionState,
         projection: &Option<Vec<usize>>,
         filters: &[Expr],
         limit: Option<usize>,
2 changes: 2 additions & 0 deletions datafusion/core/src/datasource/datasource.rs
@@ -25,6 +25,7 @@ pub use datafusion_expr::{TableProviderFilterPushDown, TableType};

 use crate::arrow::datatypes::SchemaRef;
 use crate::error::Result;
+use crate::execution::context::SessionState;
 use crate::logical_plan::Expr;
 use crate::physical_plan::ExecutionPlan;

@@ -47,6 +48,7 @@ pub trait TableProvider: Sync + Send {
     /// parallelized or distributed.
     async fn scan(
         &self,
+        ctx: &SessionState,
         projection: &Option<Vec<usize>>,
         filters: &[Expr],
         // limit can be used to reduce the amount scanned
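For downstream implementors the migration is mechanical: add the state parameter to `scan`. Below is a minimal sketch of a provider against the post-PR trait; the `NoopTable` type and its empty plan are invented for illustration, and the import paths assume the DataFusion 8.x layout shown in this diff.

    use std::any::Any;
    use std::sync::Arc;

    use async_trait::async_trait;
    use datafusion::arrow::datatypes::SchemaRef;
    use datafusion::datasource::{TableProvider, TableType};
    use datafusion::error::Result;
    use datafusion::execution::context::SessionState;
    use datafusion::logical_plan::Expr;
    use datafusion::physical_plan::{empty::EmptyExec, ExecutionPlan};

    /// Invented example type: a provider that yields no rows.
    struct NoopTable {
        schema: SchemaRef,
    }

    #[async_trait]
    impl TableProvider for NoopTable {
        fn as_any(&self) -> &dyn Any {
            self
        }

        fn schema(&self) -> SchemaRef {
            self.schema.clone()
        }

        fn table_type(&self) -> TableType {
            TableType::Base
        }

        async fn scan(
            &self,
            _state: &SessionState, // new in this PR: session state at planning time
            _projection: &Option<Vec<usize>>,
            _filters: &[Expr],
            _limit: Option<usize>,
        ) -> Result<Arc<dyn ExecutionPlan>> {
            // A real provider would honor projection, filters, and limit; an
            // EmptyExec keeps the sketch self-contained.
            Ok(Arc::new(EmptyExec::new(false, self.schema.clone())))
        }
    }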
2 changes: 2 additions & 0 deletions datafusion/core/src/datasource/empty.rs
@@ -25,6 +25,7 @@ use async_trait::async_trait;

 use crate::datasource::{TableProvider, TableType};
 use crate::error::Result;
+use crate::execution::context::SessionState;
 use crate::logical_plan::Expr;
 use crate::physical_plan::project_schema;
 use crate::physical_plan::{empty::EmptyExec, ExecutionPlan};
@@ -57,6 +58,7 @@ impl TableProvider for EmptyTable {

     async fn scan(
         &self,
+        _ctx: &SessionState,
         projection: &Option<Vec<usize>>,
         _filters: &[Expr],
         _limit: Option<usize>,
14 changes: 11 additions & 3 deletions datafusion/core/src/datasource/listing/table.rs
@@ -35,6 +35,7 @@ use crate::datasource::{
 use crate::logical_expr::TableProviderFilterPushDown;
 use crate::{
     error::{DataFusionError, Result},
+    execution::context::SessionState,
     logical_plan::Expr,
     physical_plan::{
         empty::EmptyExec,
@@ -302,6 +303,7 @@ impl TableProvider for ListingTable {

     async fn scan(
         &self,
+        _ctx: &SessionState,
         projection: &Option<Vec<usize>>,
         filters: &[Expr],
         limit: Option<usize>,
@@ -405,6 +407,7 @@ impl ListingTable {
 #[cfg(test)]
 mod tests {
     use crate::datasource::file_format::avro::DEFAULT_AVRO_EXTENSION;
+    use crate::prelude::SessionContext;
     use crate::{
         datafusion_data_access::object_store::local::LocalFileSystem,
         datasource::file_format::{avro::AvroFormat, parquet::ParquetFormat},
@@ -417,10 +420,12 @@ mod tests {

     #[tokio::test]
     async fn read_single_file() -> Result<()> {
+        let ctx = SessionContext::new();
+
         let table = load_table("alltypes_plain.parquet").await?;
         let projection = None;
         let exec = table
-            .scan(&projection, &[], None)
+            .scan(&ctx.state(), &projection, &[], None)
             .await
             .expect("Scan table");

@@ -447,7 +452,9 @@ mod tests {
             .with_listing_options(opt)
             .with_schema(schema);
         let table = ListingTable::try_new(config)?;
-        let exec = table.scan(&None, &[], None).await?;
+
+        let ctx = SessionContext::new();
+        let exec = table.scan(&ctx.state(), &None, &[], None).await?;
         assert_eq!(exec.statistics().num_rows, Some(8));
         assert_eq!(exec.statistics().total_byte_size, Some(671));

@@ -483,8 +490,9 @@ mod tests {
         // this will filter out the only file in the store
        let filter = Expr::not_eq(col("p1"), lit("v1"));

+        let ctx = SessionContext::new();
        let scan = table
-            .scan(&None, &[filter], None)
+            .scan(&ctx.state(), &None, &[filter], None)
            .await
            .expect("Empty execution plan");
34 changes: 24 additions & 10 deletions datafusion/core/src/datasource/memory.rs
@@ -29,7 +29,7 @@ use async_trait::async_trait;

 use crate::datasource::{TableProvider, TableType};
 use crate::error::{DataFusionError, Result};
-use crate::execution::context::TaskContext;
+use crate::execution::context::{SessionState, TaskContext};
 use crate::logical_plan::Expr;
 use crate::physical_plan::common;
 use crate::physical_plan::memory::MemoryExec;
@@ -65,18 +65,18 @@ impl MemTable {
     pub async fn load(
         t: Arc<dyn TableProvider>,
         output_partitions: Option<usize>,
-        context: Arc<TaskContext>,
+        ctx: &SessionState,
     ) -> Result<Self> {
         let schema = t.schema();
-        let exec = t.scan(&None, &[], None).await?;
+        let exec = t.scan(ctx, &None, &[], None).await?;
         let partition_count = exec.output_partitioning().partition_count();

         let tasks = (0..partition_count)
             .map(|part_i| {
-                let context1 = context.clone();
+                let task = Arc::new(TaskContext::from(ctx));
                 let exec = exec.clone();
                 tokio::spawn(async move {
-                    let stream = exec.execute(part_i, context1.clone())?;
+                    let stream = exec.execute(part_i, task)?;
                     common::collect(stream).await
                 })
             })
@@ -103,7 +103,8 @@ impl MemTable {
         let mut output_partitions = vec![];
         for i in 0..exec.output_partitioning().partition_count() {
             // execute this *output* partition and collect all batches
-            let mut stream = exec.execute(i, context.clone())?;
+            let task_ctx = Arc::new(TaskContext::from(ctx));
+            let mut stream = exec.execute(i, task_ctx)?;
             let mut batches = vec![];
             while let Some(result) = stream.next().await {
                 batches.push(result?);
@@ -133,6 +134,7 @@ impl TableProvider for MemTable {

     async fn scan(
         &self,
+        _ctx: &SessionState,
         projection: &Option<Vec<usize>>,
         _filters: &[Expr],
         _limit: Option<usize>,
@@ -180,7 +182,10 @@ mod tests {
         let provider = MemTable::try_new(schema, vec![vec![batch]])?;

         // scan with projection
-        let exec = provider.scan(&Some(vec![2, 1]), &[], None).await?;
+        let exec = provider
+            .scan(&session_ctx.state(), &Some(vec![2, 1]), &[], None)
+            .await?;
+
         let mut it = exec.execute(0, task_ctx)?;
         let batch2 = it.next().await.unwrap()?;
         assert_eq!(2, batch2.schema().fields().len());
@@ -212,7 +217,9 @@ mod tests {

         let provider = MemTable::try_new(schema, vec![vec![batch]])?;

-        let exec = provider.scan(&None, &[], None).await?;
+        let exec = provider
+            .scan(&session_ctx.state(), &None, &[], None)
+            .await?;
         let mut it = exec.execute(0, task_ctx)?;
         let batch1 = it.next().await.unwrap()?;
         assert_eq!(3, batch1.schema().fields().len());
@@ -223,6 +230,8 @@ mod tests {

     #[tokio::test]
     async fn test_invalid_projection() -> Result<()> {
+        let session_ctx = SessionContext::new();
+
         let schema = Arc::new(Schema::new(vec![
             Field::new("a", DataType::Int32, false),
             Field::new("b", DataType::Int32, false),
@@ -242,7 +251,10 @@ mod tests {

         let projection: Vec<usize> = vec![0, 4];

-        match provider.scan(&Some(projection), &[], None).await {
+        match provider
+            .scan(&session_ctx.state(), &Some(projection), &[], None)
+            .await
+        {
             Err(DataFusionError::ArrowError(ArrowError::SchemaError(e))) => {
                 assert_eq!(
                     "\"project index 4 out of bounds, max field 3\"",
@@ -368,7 +380,9 @@ mod tests {
         let provider =
             MemTable::try_new(Arc::new(merged_schema), vec![vec![batch1, batch2]])?;

-        let exec = provider.scan(&None, &[], None).await?;
+        let exec = provider
+            .scan(&session_ctx.state(), &None, &[], None)
+            .await?;
         let mut it = exec.execute(0, task_ctx)?;
         let batch1 = it.next().await.unwrap()?;
         assert_eq!(3, batch1.schema().fields().len());
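For callers of `MemTable::load`, the new call shape looks roughly like the following sketch; `table_provider` and the partition count are hypothetical, and per-partition `TaskContext`s are now derived inside `load` via `TaskContext::from(ctx)` rather than passed in.

    let ctx = SessionContext::new();
    let loaded = MemTable::load(
        table_provider,  // Arc<dyn TableProvider>, assumed to already exist
        Some(4),         // optionally repartition; 4 is an arbitrary choice
        &ctx.state(),    // snapshot of the session state replaces Arc<TaskContext>
    )
    .await?;
    ctx.register_table("t_in_mem", Arc::new(loaded))?;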
11 changes: 5 additions & 6 deletions datafusion/core/src/datasource/view.rs
@@ -24,17 +24,15 @@ use async_trait::async_trait;

 use crate::{
     error::Result,
-    execution::context::SessionContext,
     logical_plan::{Expr, LogicalPlan},
     physical_plan::ExecutionPlan,
 };

 use crate::datasource::{TableProvider, TableType};
+use crate::execution::context::SessionState;

 /// An implementation of `TableProvider` that uses another logical plan.
 pub struct ViewTable {
-    /// To create ExecutionPlan
-    context: SessionContext,
     /// LogicalPlan of the view
     logical_plan: LogicalPlan,
     /// File fields + partition columns
@@ -44,11 +42,10 @@ pub struct ViewTable {
 impl ViewTable {
     /// Create new view that is executed at query runtime.
     /// Takes a `LogicalPlan` as input.
-    pub fn try_new(context: SessionContext, logical_plan: LogicalPlan) -> Result<Self> {
+    pub fn try_new(logical_plan: LogicalPlan) -> Result<Self> {
         let table_schema = logical_plan.schema().as_ref().to_owned().into();

         let view = Self {
-            context,
             logical_plan,
             table_schema,
         };
@@ -73,16 +70,18 @@ impl TableProvider for ViewTable {

     async fn scan(
         &self,
+        ctx: &SessionState,
         _projection: &Option<Vec<usize>>,
         _filters: &[Expr],
         _limit: Option<usize>,
     ) -> Result<Arc<dyn ExecutionPlan>> {
-        self.context.create_physical_plan(&self.logical_plan).await
+        ctx.create_physical_plan(&self.logical_plan).await
     }
 }

 #[cfg(test)]
 mod tests {
+    use crate::prelude::SessionContext;
     use crate::{assert_batches_eq, execution::context::SessionConfig};

     use super::*;
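Constructing a view is correspondingly simpler, since no context is captured at creation time. A sketch, under the assumptions that `ViewTable` is reachable from user code and that `LogicalPlanBuilder::empty` is available as a placeholder plan source:

    use datafusion::logical_plan::LogicalPlanBuilder;

    // Any LogicalPlan works here; an empty relation keeps the sketch small.
    let plan = LogicalPlanBuilder::empty(false).build()?;
    let view = ViewTable::try_new(plan)?;
    // At query time the planner calls view.scan(&state, ...), and the view
    // builds its physical plan from the state it is handed:
    // ctx.create_physical_plan(&self.logical_plan).await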
11 changes: 7 additions & 4 deletions datafusion/core/src/execution/context.rs
@@ -349,16 +349,14 @@ impl SessionContext {
                 (true, Ok(_)) => {
                     self.deregister_table(name.as_str())?;
                     let plan = self.optimize(&input)?;
-                    let table =
-                        Arc::new(ViewTable::try_new(self.clone(), plan.clone())?);
+                    let table = Arc::new(ViewTable::try_new(plan.clone())?);

                     self.register_table(name.as_str(), table)?;
                     Ok(Arc::new(DataFrame::new(self.state.clone(), &plan)))
                 }
                 (_, Err(_)) => {
                     let plan = self.optimize(&input)?;
-                    let table =
-                        Arc::new(ViewTable::try_new(self.clone(), plan.clone())?);
+                    let table = Arc::new(ViewTable::try_new(plan.clone())?);

                     self.register_table(name.as_str(), table)?;
                     Ok(Arc::new(DataFrame::new(self.state.clone(), &plan)))
@@ -931,6 +929,11 @@ impl SessionContext {
     pub fn task_ctx(&self) -> Arc<TaskContext> {
         Arc::new(TaskContext::from(self))
     }
+
+    /// Get a copy of the [`SessionState`] of this [`SessionContext`]
+    pub fn state(&self) -> SessionState {
+        self.state.read().clone()
+    }
 }

 impl FunctionRegistry for SessionContext {

Review comment from the PR author on the new state() method: "I'm not a massive fan of this, but it appears to be what we do already for every query and when constructing a TaskContext, so it can't be that bad.........."
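With the new accessor, tests and tools can drive a provider directly, without going through SQL planning. A small sketch, where `provider` stands in for any `TableProvider` implementation (hypothetical here):

    let ctx = SessionContext::new();
    let state = ctx.state(); // clones the SessionState out of the RwLock
    let exec = provider.scan(&state, &None, &[], None).await?;
    let stream = exec.execute(0, ctx.task_ctx())?;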
2 changes: 1 addition & 1 deletion datafusion/core/src/physical_plan/planner.rs
@@ -389,7 +389,7 @@ impl DefaultPhysicalPlanner {
                     // referred to in the query
                     let filters = unnormalize_cols(filters.iter().cloned());
                     let unaliased: Vec<Expr> = filters.into_iter().map(unalias).collect();
-                    source.scan(projection, &unaliased, *limit).await
+                    source.scan(session_state, projection, &unaliased, *limit).await
                 }
                 LogicalPlan::Values(Values {
                     values,
3 changes: 2 additions & 1 deletion datafusion/core/tests/custom_sources.rs
@@ -30,7 +30,7 @@ use datafusion::{
 };
 use datafusion::{error::Result, physical_plan::DisplayFormatType};

-use datafusion::execution::context::{SessionContext, TaskContext};
+use datafusion::execution::context::{SessionContext, SessionState, TaskContext};
 use datafusion::logical_plan::{
     col, Expr, LogicalPlan, LogicalPlanBuilder, TableScan, UNNAMED_TABLE,
 };
@@ -201,6 +201,7 @@ impl TableProvider for CustomTableProvider {

     async fn scan(
         &self,
+        _state: &SessionState,
         projection: &Option<Vec<usize>>,
         _filters: &[Expr],
         _limit: Option<usize>,
3 changes: 2 additions & 1 deletion datafusion/core/tests/provider_filter_pushdown.rs
@@ -21,7 +21,7 @@ use arrow::record_batch::RecordBatch;
 use async_trait::async_trait;
 use datafusion::datasource::datasource::{TableProvider, TableType};
 use datafusion::error::Result;
-use datafusion::execution::context::{SessionContext, TaskContext};
+use datafusion::execution::context::{SessionContext, SessionState, TaskContext};
 use datafusion::logical_expr::{Expr, TableProviderFilterPushDown};
 use datafusion::physical_plan::common::SizedRecordBatchStream;
 use datafusion::physical_plan::expressions::PhysicalSortExpr;
@@ -138,6 +138,7 @@ impl TableProvider for CustomProvider {

     async fn scan(
         &self,
+        _state: &SessionState,
         _: &Option<Vec<usize>>,
         filters: &[Expr],
         _: Option<usize>,
2 changes: 2 additions & 0 deletions datafusion/core/tests/sql/information_schema.rs
@@ -16,6 +16,7 @@
 // under the License.

 use async_trait::async_trait;
+use datafusion::execution::context::SessionState;
 use datafusion::{
     catalog::{
         catalog::{CatalogProvider, MemoryCatalogProvider},
@@ -175,6 +176,7 @@ async fn information_schema_tables_table_types() {

     async fn scan(
         &self,
+        _ctx: &SessionState,
         _: &Option<Vec<usize>>,
         _: &[Expr],
         _: Option<usize>,