72 changes: 68 additions & 4 deletions datafusion/src/dataframe.rs
@@ -17,15 +17,22 @@

//! DataFrame API for building and executing query plans.

use crate::arrow::datatypes::Schema;
use crate::arrow::datatypes::SchemaRef;
use crate::arrow::record_batch::RecordBatch;
use crate::datasource::TableProvider;
use crate::datasource::TableType;
use crate::error::Result;
use crate::execution::dataframe_impl::DataFrameImpl;
use crate::logical_plan::{
DFSchema, Expr, FunctionRegistry, JoinType, LogicalPlan, Partitioning,
};
use std::sync::Arc;

use crate::physical_plan::ExecutionPlan;
use crate::physical_plan::SendableRecordBatchStream;
use crate::scalar::ScalarValue;
use async_trait::async_trait;
use std::any::Any;
use std::sync::Arc;

/// DataFrame represents a logical set of rows with the same named columns.
/// Similar to a [Pandas DataFrame](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html) or
@@ -53,7 +60,7 @@ use async_trait::async_trait;
/// # }
/// ```
#[async_trait]
pub trait DataFrame: Send + Sync {
pub trait DataFrame: TableProvider + Send + Sync {
/// Filter the DataFrame by column. Returns a new DataFrame only containing the
/// specified columns.
///
@@ -328,7 +335,7 @@ pub trait DataFrame: Send + Sync {
/// where each column has a name, data type, and nullability attribute.

/// ```
/// # use datafusion::prelude::*;
/// # use datafusion::prelude::{CsvReadOptions, ExecutionContext, DataFrame};
/// # use datafusion::error::Result;
/// # #[tokio::main]
/// # async fn main() -> Result<()> {
@@ -406,3 +413,60 @@ pub trait DataFrame: Send + Sync {
/// ```
fn except(&self, dataframe: Arc<dyn DataFrame>) -> Result<Arc<dyn DataFrame>>;
}

#[async_trait]
impl<D> TableProvider for D
where
D: DataFrame + 'static,
{
fn as_any(&self) -> &dyn Any {
self
}

fn schema(&self) -> SchemaRef {
let schema: Schema = self.to_logical_plan().schema().as_ref().into();
Arc::new(schema)
}

fn table_type(&self) -> TableType {
TableType::View
}

async fn scan(
&self,
projection: &Option<Vec<usize>>,
filters: &[Expr],
limit: Option<usize>,
) -> Result<Arc<dyn ExecutionPlan>> {
let plan = self.to_logical_plan();
let expr = projection
.as_ref()
// construct projections
.map_or_else(
|| Ok(Arc::new(DataFrameImpl::new(Default::default(), &plan)) as Arc<_>),
|projection| {
let schema = TableProvider::schema(self).project(projection)?;
let names = schema
.fields()
.iter()
.map(|field| field.name().as_str())
.collect::<Vec<_>>();
self.select_columns(names.as_slice())
},
)?
// add predicates, otherwise use `true` as the predicate
.filter(filters.iter().cloned().fold(
Expr::Literal(ScalarValue::Boolean(Some(true))),
|acc, new| acc.and(new),
))?;
// add a limit if given
DataFrameImpl::new(
Default::default(),
&limit
.map_or_else(|| Ok(expr.clone()), |n| expr.limit(n))?
.to_logical_plan(),
)
.create_physical_plan()
.await
}
}
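
With this blanket impl in place, any `DataFrame` doubles as a `TableProvider` (`TableType::View`), so it can be registered in the catalog and queried through SQL; `scan` then re-plans the DataFrame with the query's projection, pushed-down filters, and limit applied. A minimal usage sketch — the table and column names are placeholders, and exact signatures (e.g. whether `sql` and `collect` are async) vary across DataFusion versions:

use datafusion::error::Result;
use datafusion::prelude::*;

async fn register_and_query(mut ctx: ExecutionContext) -> Result<()> {
    // Assumes a table named "example" with columns c1/c12 was registered earlier.
    let df = ctx.table("example")?.select_columns(&["c1", "c12"])?;

    // The DataFrame itself now acts as a TableProvider, so it can be registered directly.
    ctx.register_table("df_view", df)?;

    // SQL against the view goes through `scan`, which folds the query's
    // projection, filters, and limit back into the DataFrame's logical plan.
    let batches = ctx.sql("SELECT c1 FROM df_view LIMIT 10").await?.collect().await?;
    println!("got {} record batches", batches.len());
    Ok(())
}
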
4 changes: 3 additions & 1 deletion datafusion/src/execution/context.rs
@@ -1573,7 +1573,9 @@ mod tests {
let tmp_dir = TempDir::new()?;
let ctx = create_ctx(&tmp_dir, 1).await?;

let schema: Schema = ctx.table("test").unwrap().schema().clone().into();
let schema: Schema = DataFrame::schema(&*ctx.table("test").unwrap())
.clone()
.into();
assert!(!schema.field_with_name("c1")?.is_nullable());

let plan = LogicalPlanBuilder::scan_empty(None, &schema, None)?
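The fully qualified `DataFrame::schema(..)` call above is needed because `TableProvider` is now a supertrait of `DataFrame`, and both traits expose a `schema` method with different return types: `DataFrame::schema` yields the logical `&DFSchema`, while `TableProvider::schema` yields an Arrow `SchemaRef`, so a bare `.schema()` becomes ambiguous. A sketch of the two call shapes, assuming a hypothetical helper taking `&dyn DataFrame`:

use datafusion::arrow::datatypes::{Schema, SchemaRef};
use datafusion::dataframe::DataFrame;
use datafusion::datasource::TableProvider;
use datafusion::logical_plan::DFSchema;

// Hypothetical helper, just to show the two disambiguated forms.
fn schemas_of(df: &dyn DataFrame) -> (Schema, SchemaRef) {
    // Logical, qualifier-aware schema via the DataFrame trait.
    let logical: &DFSchema = DataFrame::schema(df);
    // Arrow schema via the TableProvider supertrait.
    let arrow: SchemaRef = TableProvider::schema(df);
    // The test above clones the logical schema and converts it into an Arrow Schema.
    (logical.clone().into(), arrow)
}
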
72 changes: 12 additions & 60 deletions datafusion/src/execution/dataframe_impl.rs
@@ -17,27 +17,21 @@

//! Implementation of DataFrame API.

use std::any::Any;
use std::sync::{Arc, Mutex};

use crate::arrow::datatypes::Schema;
use crate::arrow::datatypes::SchemaRef;
use crate::arrow::record_batch::RecordBatch;
use crate::error::Result;
use crate::execution::context::{ExecutionContext, ExecutionContextState};
use crate::logical_plan::{
col, DFSchema, Expr, FunctionRegistry, JoinType, LogicalPlan, LogicalPlanBuilder,
Partitioning,
};
use crate::scalar::ScalarValue;
use crate::{
dataframe::*,
physical_plan::{collect, collect_partitioned},
};

use crate::arrow::util::pretty;
use crate::datasource::TableProvider;
use crate::datasource::TableType;
use crate::physical_plan::{
execute_stream, execute_stream_partitioned, ExecutionPlan, SendableRecordBatchStream,
};
@@ -60,67 +54,14 @@ impl DataFrameImpl {
}

/// Create a physical plan
async fn create_physical_plan(&self) -> Result<Arc<dyn ExecutionPlan>> {
pub(crate) async fn create_physical_plan(&self) -> Result<Arc<dyn ExecutionPlan>> {
let state = self.ctx_state.lock().unwrap().clone();
let ctx = ExecutionContext::from(Arc::new(Mutex::new(state)));
let plan = ctx.optimize(&self.plan)?;
ctx.create_physical_plan(&plan).await
}
}

#[async_trait]
impl TableProvider for DataFrameImpl {
fn as_any(&self) -> &dyn Any {
self
}

fn schema(&self) -> SchemaRef {
let schema: Schema = self.plan.schema().as_ref().into();
Arc::new(schema)
}

fn table_type(&self) -> TableType {
TableType::View
}

async fn scan(
&self,
projection: &Option<Vec<usize>>,
filters: &[Expr],
limit: Option<usize>,
) -> Result<Arc<dyn ExecutionPlan>> {
let expr = projection
.as_ref()
// construct projections
.map_or_else(
|| Ok(Arc::new(Self::new(self.ctx_state.clone(), &self.plan)) as Arc<_>),
|projection| {
let schema = TableProvider::schema(self).project(projection)?;
let names = schema
.fields()
.iter()
.map(|field| field.name().as_str())
.collect::<Vec<_>>();
self.select_columns(names.as_slice())
},
)?
// add predicates, otherwise use `true` as the predicate
.filter(filters.iter().cloned().fold(
Expr::Literal(ScalarValue::Boolean(Some(true))),
|acc, new| acc.and(new),
))?;
// add a limit if given
Self::new(
self.ctx_state.clone(),
&limit
.map_or_else(|| Ok(expr.clone()), |n| expr.limit(n))?
.to_logical_plan(),
)
.create_physical_plan()
.await
}
}

#[async_trait]
impl DataFrame for DataFrameImpl {
/// Apply a projection based on a list of column names
@@ -602,6 +543,17 @@ mod tests {
);
Ok(())
}

#[tokio::test]
async fn register_dataframe() -> Result<()> {
let df = test_table().await?.select_columns(&["c1", "c12"])?;
let mut ctx = ExecutionContext::new();

// register a dataframe as a table
ctx.register_table("test_table", df)?;
Ok(())
}

/// Compare the formatted string representation of two plans for equality
fn assert_same_plan(plan1: &LogicalPlan, plan2: &LogicalPlan) {
assert_eq!(format!("{:?}", plan1), format!("{:?}", plan2));