Skip to content

Commit

Permalink
fix: projection_push_down don't consider VarProvider in columns. (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
jackwener authored May 8, 2023
1 parent fc5d67a commit 7760191
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 5 deletions.
41 changes: 40 additions & 1 deletion datafusion/core/tests/dataframe.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,15 @@ use datafusion::from_slice::FromSlice;
use std::sync::Arc;

use datafusion::dataframe::DataFrame;
use datafusion::datasource::MemTable;
use datafusion::error::Result;
use datafusion::execution::context::SessionContext;
use datafusion::prelude::JoinType;
use datafusion::prelude::{CsvReadOptions, ParquetReadOptions};
use datafusion::test_util::parquet_test_data;
use datafusion::{assert_batches_eq, assert_batches_sorted_eq};
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, ScalarValue};
use datafusion_execution::config::SessionConfig;
use datafusion_expr::expr::{GroupingSet, Sort};
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_expr::Expr::Wildcard;
Expand All @@ -43,6 +45,7 @@ use datafusion_expr::{
sum, AggregateFunction, Expr, ExprSchemable, WindowFrame, WindowFrameBound,
WindowFrameUnits, WindowFunction,
};
use datafusion_physical_expr::var_provider::{VarProvider, VarType};

#[tokio::test]
async fn test_count_wildcard_on_sort() -> Result<()> {
Expand Down Expand Up @@ -1230,3 +1233,39 @@ pub async fn register_alltypes_tiny_pages_parquet(ctx: &SessionContext) -> Resul
.await?;
Ok(())
}
#[derive(Debug)]
struct HardcodedIntProvider {}

impl VarProvider for HardcodedIntProvider {
fn get_value(&self, _var_names: Vec<String>) -> Result<ScalarValue, DataFusionError> {
Ok(ScalarValue::Int64(Some(1234)))
}

fn get_type(&self, _: &[String]) -> Option<DataType> {
Some(DataType::Int64)
}
}

#[tokio::test]
async fn use_var_provider() -> Result<()> {
let schema = Arc::new(Schema::new(vec![
Field::new("foo", DataType::Int64, false),
Field::new("bar", DataType::Int64, false),
]));

let mem_table = Arc::new(MemTable::try_new(schema, vec![])?);

let config = SessionConfig::new()
.with_target_partitions(4)
.set_bool("datafusion.optimizer.skip_failed_rules", false);
let ctx = SessionContext::with_config(config);

ctx.register_table("csv_table", mem_table)?;
ctx.register_variable(VarType::UserDefined, Arc::new(HardcodedIntProvider {}));

let dataframe = ctx
.sql("SELECT foo FROM csv_table WHERE bar > @var")
.await?;
dataframe.collect().await?;
Ok(())
}
6 changes: 2 additions & 4 deletions datafusion/expr/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,13 +270,11 @@ pub fn expr_to_columns(expr: &Expr, accum: &mut HashSet<Column>) -> Result<()> {
Expr::Column(qc) => {
accum.insert(qc.clone());
}
Expr::ScalarVariable(_, var_names) => {
accum.insert(Column::from_name(var_names.join(".")));
}
// Use explicit pattern match instead of a default
// implementation, so that in the future if someone adds
// new Expr types, they will check here as well
Expr::Alias(_, _)
Expr::ScalarVariable(_, _)
| Expr::Alias(_, _)
| Expr::Literal(_)
| Expr::BinaryExpr { .. }
| Expr::Like { .. }
Expand Down

0 comments on commit 7760191

Please sign in to comment.