From 03fbf9fecad00f8d6eb3e72e72ba16252b28b1d6 Mon Sep 17 00:00:00 2001 From: Marco Neumann Date: Wed, 1 Mar 2023 17:24:55 +0100 Subject: [PATCH] refactor: ParquetExec logical expr. => phys. expr. (#5419) * feat: `get_phys_expr_columns` util * feat: add `reassign_predicate_columns` util * feat: `PhysicalExprRef` type alias * refactor: `ParquetExec` logical expr. => phys. expr. Use `Arc` instead of `Expr` within `ParquetExec` and move lowering from logical to physical expression into plan lowering (e.g. `ListingTable`). This is in line w/ all other physical plan nodes (e.g. `FilterExpr`) and simplifies reasoning within physical optimizer but also allows correct passing of `ExecutionProps` into the conversion. Closes #4695. --- .../core/src/datasource/file_format/avro.rs | 4 +- .../core/src/datasource/file_format/csv.rs | 4 +- .../core/src/datasource/file_format/json.rs | 4 +- .../core/src/datasource/file_format/mod.rs | 6 +- .../src/datasource/file_format/parquet.rs | 14 +- .../core/src/datasource/listing/table.rs | 20 +- .../core/src/physical_optimizer/pruning.rs | 955 +++++++++++------- .../src/physical_plan/file_format/parquet.rs | 34 +- .../file_format/parquet/page_filter.rs | 21 +- .../file_format/parquet/row_filter.rs | 101 +- .../file_format/parquet/row_groups.rs | 68 +- datafusion/core/tests/parquet/page_pruning.rs | 11 +- datafusion/core/tests/row.rs | 2 +- datafusion/physical-expr/src/lib.rs | 2 +- datafusion/physical-expr/src/physical_expr.rs | 3 + datafusion/physical-expr/src/utils.rs | 78 ++ datafusion/proto/proto/datafusion.proto | 6 +- datafusion/proto/src/generated/pbjson.rs | 25 +- datafusion/proto/src/generated/prost.rs | 4 +- datafusion/proto/src/physical_plan/mod.rs | 36 +- parquet-test-utils/src/lib.rs | 6 +- 21 files changed, 907 insertions(+), 497 deletions(-) diff --git a/datafusion/core/src/datasource/file_format/avro.rs b/datafusion/core/src/datasource/file_format/avro.rs index 1b6d2b3bc6f2..422733308996 100644 --- a/datafusion/core/src/datasource/file_format/avro.rs +++ b/datafusion/core/src/datasource/file_format/avro.rs @@ -23,13 +23,13 @@ use std::sync::Arc; use arrow::datatypes::Schema; use arrow::{self, datatypes::SchemaRef}; use async_trait::async_trait; +use datafusion_physical_expr::PhysicalExpr; use object_store::{GetResult, ObjectMeta, ObjectStore}; use super::FileFormat; use crate::avro_to_arrow::read_avro_schema_from_reader; use crate::error::Result; use crate::execution::context::SessionState; -use crate::logical_expr::Expr; use crate::physical_plan::file_format::{AvroExec, FileScanConfig}; use crate::physical_plan::ExecutionPlan; use crate::physical_plan::Statistics; @@ -82,7 +82,7 @@ impl FileFormat for AvroFormat { &self, _state: &SessionState, conf: FileScanConfig, - _filters: &[Expr], + _filters: Option<&Arc>, ) -> Result> { let exec = AvroExec::new(conf); Ok(Arc::new(exec)) diff --git a/datafusion/core/src/datasource/file_format/csv.rs b/datafusion/core/src/datasource/file_format/csv.rs index 85a9d186a7e6..fcab651e3514 100644 --- a/datafusion/core/src/datasource/file_format/csv.rs +++ b/datafusion/core/src/datasource/file_format/csv.rs @@ -29,6 +29,7 @@ use bytes::{Buf, Bytes}; use datafusion_common::DataFusionError; +use datafusion_physical_expr::PhysicalExpr; use futures::{pin_mut, Stream, StreamExt, TryStreamExt}; use object_store::{delimited::newline_delimited_stream, ObjectMeta, ObjectStore}; @@ -37,7 +38,6 @@ use crate::datasource::file_format::file_type::FileCompressionType; use crate::datasource::file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD; use crate::error::Result; use crate::execution::context::SessionState; -use crate::logical_expr::Expr; use crate::physical_plan::file_format::{CsvExec, FileScanConfig}; use crate::physical_plan::ExecutionPlan; use crate::physical_plan::Statistics; @@ -154,7 +154,7 @@ impl FileFormat for CsvFormat { &self, _state: &SessionState, conf: FileScanConfig, - _filters: &[Expr], + _filters: Option<&Arc>, ) -> Result> { let exec = CsvExec::new( conf, diff --git a/datafusion/core/src/datasource/file_format/json.rs b/datafusion/core/src/datasource/file_format/json.rs index 7b0b0e18db22..a66edab888bf 100644 --- a/datafusion/core/src/datasource/file_format/json.rs +++ b/datafusion/core/src/datasource/file_format/json.rs @@ -29,6 +29,7 @@ use arrow::json::reader::ValueIter; use async_trait::async_trait; use bytes::Buf; +use datafusion_physical_expr::PhysicalExpr; use object_store::{GetResult, ObjectMeta, ObjectStore}; use super::FileFormat; @@ -37,7 +38,6 @@ use crate::datasource::file_format::file_type::FileCompressionType; use crate::datasource::file_format::DEFAULT_SCHEMA_INFER_MAX_RECORD; use crate::error::Result; use crate::execution::context::SessionState; -use crate::logical_expr::Expr; use crate::physical_plan::file_format::NdJsonExec; use crate::physical_plan::ExecutionPlan; use crate::physical_plan::Statistics; @@ -143,7 +143,7 @@ impl FileFormat for JsonFormat { &self, _state: &SessionState, conf: FileScanConfig, - _filters: &[Expr], + _filters: Option<&Arc>, ) -> Result> { let exec = NdJsonExec::new(conf, self.file_compression_type.to_owned()); Ok(Arc::new(exec)) diff --git a/datafusion/core/src/datasource/file_format/mod.rs b/datafusion/core/src/datasource/file_format/mod.rs index 947327630d9c..52da7285e373 100644 --- a/datafusion/core/src/datasource/file_format/mod.rs +++ b/datafusion/core/src/datasource/file_format/mod.rs @@ -33,12 +33,12 @@ use std::sync::Arc; use crate::arrow::datatypes::SchemaRef; use crate::error::Result; -use crate::logical_expr::Expr; use crate::physical_plan::file_format::FileScanConfig; use crate::physical_plan::{ExecutionPlan, Statistics}; use crate::execution::context::SessionState; use async_trait::async_trait; +use datafusion_physical_expr::PhysicalExpr; use object_store::{ObjectMeta, ObjectStore}; /// This trait abstracts all the file format specific implementations @@ -84,7 +84,7 @@ pub trait FileFormat: Send + Sync + fmt::Debug { &self, state: &SessionState, conf: FileScanConfig, - filters: &[Expr], + filters: Option<&Arc>, ) -> Result>; } @@ -143,7 +143,7 @@ pub(crate) mod test_util { output_ordering: None, infinite_source: false, }, - &[], + None, ) .await?; Ok(exec) diff --git a/datafusion/core/src/datasource/file_format/parquet.rs b/datafusion/core/src/datasource/file_format/parquet.rs index 0a7a7cadc90a..53e94167d845 100644 --- a/datafusion/core/src/datasource/file_format/parquet.rs +++ b/datafusion/core/src/datasource/file_format/parquet.rs @@ -25,7 +25,7 @@ use arrow::datatypes::SchemaRef; use async_trait::async_trait; use bytes::{BufMut, BytesMut}; use datafusion_common::DataFusionError; -use datafusion_optimizer::utils::conjunction; +use datafusion_physical_expr::PhysicalExpr; use hashbrown::HashMap; use object_store::{ObjectMeta, ObjectStore}; use parquet::arrow::parquet_to_arrow_schema; @@ -44,7 +44,6 @@ use crate::config::ConfigOptions; use crate::datasource::{create_max_min_accs, get_col_stats}; use crate::error::Result; use crate::execution::context::SessionState; -use crate::logical_expr::Expr; use crate::physical_plan::expressions::{MaxAccumulator, MinAccumulator}; use crate::physical_plan::file_format::{ParquetExec, SchemaAdapter}; use crate::physical_plan::{Accumulator, ExecutionPlan, Statistics}; @@ -189,16 +188,15 @@ impl FileFormat for ParquetFormat { &self, state: &SessionState, conf: FileScanConfig, - filters: &[Expr], + filters: Option<&Arc>, ) -> Result> { // If enable pruning then combine the filters to build the predicate. // If disable pruning then set the predicate to None, thus readers // will not prune data based on the statistics. - let predicate = if self.enable_pruning(state.config_options()) { - conjunction(filters.to_vec()) - } else { - None - }; + let predicate = self + .enable_pruning(state.config_options()) + .then(|| filters.cloned()) + .flatten(); Ok(Arc::new(ParquetExec::new( conf, diff --git a/datafusion/core/src/datasource/listing/table.rs b/datafusion/core/src/datasource/listing/table.rs index 29e2259e44c5..f6d9c959eb1a 100644 --- a/datafusion/core/src/datasource/listing/table.rs +++ b/datafusion/core/src/datasource/listing/table.rs @@ -24,8 +24,10 @@ use arrow::compute::SortOptions; use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; use async_trait::async_trait; use dashmap::DashMap; +use datafusion_common::ToDFSchema; use datafusion_expr::expr::Sort; -use datafusion_physical_expr::PhysicalSortExpr; +use datafusion_optimizer::utils::conjunction; +use datafusion_physical_expr::{create_physical_expr, PhysicalSortExpr}; use futures::{future, stream, StreamExt, TryStreamExt}; use object_store::path::Path; use object_store::ObjectMeta; @@ -661,6 +663,20 @@ impl TableProvider for ListingTable { }) .collect::>>()?; + let filters = if let Some(expr) = conjunction(filters.to_vec()) { + // NOTE: Use the table schema (NOT file schema) here because `expr` may contain references to partition columns. + let table_df_schema = self.table_schema.as_ref().clone().to_dfschema()?; + let filters = create_physical_expr( + &expr, + &table_df_schema, + &self.table_schema, + state.execution_props(), + )?; + Some(filters) + } else { + None + }; + // create the execution plan self.options .format @@ -677,7 +693,7 @@ impl TableProvider for ListingTable { table_partition_cols, infinite_source: self.infinite_source, }, - filters, + filters.as_ref(), ) .await } diff --git a/datafusion/core/src/physical_optimizer/pruning.rs b/datafusion/core/src/physical_optimizer/pruning.rs index 03e376878740..fbf000148dcc 100644 --- a/datafusion/core/src/physical_optimizer/pruning.rs +++ b/datafusion/core/src/physical_optimizer/pruning.rs @@ -32,14 +32,13 @@ use std::collections::HashSet; use std::convert::TryFrom; use std::sync::Arc; -use crate::execution::context::ExecutionProps; -use crate::prelude::lit; use crate::{ common::{Column, DFSchema}, error::{DataFusionError, Result}, - logical_expr::{Expr, Operator}, + logical_expr::Operator, physical_plan::{ColumnarValue, PhysicalExpr}, }; +use arrow::compute::DEFAULT_CAST_OPTIONS; use arrow::record_batch::RecordBatchOptions; use arrow::{ array::{new_null_array, ArrayRef, BooleanArray}, @@ -47,11 +46,10 @@ use arrow::{ record_batch::RecordBatch, }; use datafusion_common::{downcast_value, ScalarValue}; -use datafusion_expr::expr::{BinaryExpr, Cast, TryCast}; -use datafusion_expr::expr_rewriter::rewrite_expr; -use datafusion_expr::{binary_expr, cast, try_cast, ExprSchemable}; -use datafusion_physical_expr::create_physical_expr; -use datafusion_physical_expr::expressions::Literal; +use datafusion_physical_expr::{expressions as phys_expr, PhysicalExprRef}; + +use datafusion_physical_expr::rewrite::{TreeNodeRewritable, TreeNodeRewriter}; +use datafusion_physical_expr::utils::get_phys_expr_columns; use log::trace; /// Interface to pass statistics information to [`PruningPredicate`] @@ -104,8 +102,8 @@ pub struct PruningPredicate { predicate_expr: Arc, /// The statistics required to evaluate this predicate required_columns: RequiredStatColumns, - /// Logical predicate from which this predicate expr is derived (required for serialization) - logical_expr: Expr, + /// Original physical predicate from which this predicate expr is derived (required for serialization) + orig_expr: Arc, } impl PruningPredicate { @@ -128,31 +126,16 @@ impl PruningPredicate { /// For example, the filter expression `(column / 2) = 4` becomes /// the pruning predicate /// `(column_min / 2) <= 4 && 4 <= (column_max / 2))` - pub fn try_new(expr: Expr, schema: SchemaRef) -> Result { + pub fn try_new(expr: Arc, schema: SchemaRef) -> Result { // build predicate expression once let mut required_columns = RequiredStatColumns::new(); - let logical_predicate_expr = + let predicate_expr = build_predicate_expression(&expr, schema.as_ref(), &mut required_columns); - let stat_fields = required_columns - .iter() - .map(|(_, _, f)| f.clone()) - .collect::>(); - let stat_schema = Schema::new(stat_fields); - let stat_dfschema = DFSchema::try_from(stat_schema.clone())?; - - // TODO allow these properties to be passed in - let execution_props = ExecutionProps::new(); - let predicate_expr = create_physical_expr( - &logical_predicate_expr, - &stat_dfschema, - &stat_schema, - &execution_props, - )?; Ok(Self { schema, predicate_expr, required_columns, - logical_expr: expr, + orig_expr: expr, }) } @@ -215,9 +198,9 @@ impl PruningPredicate { &self.schema } - /// Returns a reference to the logical expr used to construct this pruning predicate - pub fn logical_expr(&self) -> &Expr { - &self.logical_expr + /// Returns a reference to the physical expr used to construct this pruning predicate + pub fn orig_expr(&self) -> &Arc { + &self.orig_expr } /// Returns a reference to the predicate expr @@ -227,11 +210,7 @@ impl PruningPredicate { /// Returns true if this pruning predicate is "always true" (aka will not prune anything) pub fn allways_true(&self) -> bool { - self.predicate_expr - .as_any() - .downcast_ref::() - .map(|l| matches!(l.value(), ScalarValue::Boolean(Some(true)))) - .unwrap_or_default() + is_always_true(&self.predicate_expr) } pub(crate) fn required_columns(&self) -> &RequiredStatColumns { @@ -239,6 +218,13 @@ impl PruningPredicate { } } +fn is_always_true(expr: &Arc) -> bool { + expr.as_any() + .downcast_ref::() + .map(|l| matches!(l.value(), ScalarValue::Boolean(Some(true)))) + .unwrap_or_default() +} + /// Records for which columns statistics are necessary to evaluate a /// pruning predicate. /// @@ -251,7 +237,7 @@ pub(crate) struct RequiredStatColumns { /// * Statistics type (e.g. Min or Max or Null_Count) /// * The field the statistics value should be placed in for /// pruning predicate evaluation - columns: Vec<(Column, StatisticsType, Field)>, + columns: Vec<(phys_expr::Column, StatisticsType, Field)>, } impl RequiredStatColumns { @@ -269,19 +255,22 @@ impl RequiredStatColumns { /// Returns an iterator over items in columns (see doc on /// `self.columns` for details) - pub(crate) fn iter(&self) -> impl Iterator { + pub(crate) fn iter( + &self, + ) -> impl Iterator { self.columns.iter() } - fn is_stat_column_missing( + fn find_stat_column( &self, - column: &Column, + column: &phys_expr::Column, statistics_type: StatisticsType, - ) -> bool { - !self - .columns + ) -> Option { + self.columns .iter() - .any(|(c, t, _f)| c == column && t == &statistics_type) + .enumerate() + .find(|(_i, (c, t, _f))| c == column && t == &statistics_type) + .map(|(i, (_c, _t, _f))| i) } /// Rewrites column_expr so that all appearances of column @@ -294,25 +283,27 @@ impl RequiredStatColumns { /// 5` with the appropriate entry noted in self.columns fn stat_column_expr( &mut self, - column: &Column, - column_expr: &Expr, + column: &phys_expr::Column, + column_expr: &Arc, field: &Field, stat_type: StatisticsType, suffix: &str, - ) -> Result { - let stat_column = Column { - relation: column.relation.clone(), - name: format!("{}_{}", column.flat_name(), suffix), + ) -> Result> { + let (idx, need_to_insert) = match self.find_stat_column(column, stat_type) { + Some(idx) => (idx, false), + None => (self.columns.len(), true), }; - let stat_field = Field::new( - stat_column.flat_name().as_str(), - field.data_type().clone(), - field.is_nullable(), - ); + let stat_column = + phys_expr::Column::new(&format!("{}_{}", column.name(), suffix), idx); - if self.is_stat_column_missing(column, stat_type) { - // only add statistics column if not previously added + // only add statistics column if not previously added + if need_to_insert { + let stat_field = Field::new( + stat_column.name(), + field.data_type().clone(), + field.is_nullable(), + ); self.columns.push((column.clone(), stat_type, stat_field)); } rewrite_column_expr(column_expr.clone(), column, &stat_column) @@ -321,30 +312,30 @@ impl RequiredStatColumns { /// rewrite col --> col_min fn min_column_expr( &mut self, - column: &Column, - column_expr: &Expr, + column: &phys_expr::Column, + column_expr: &Arc, field: &Field, - ) -> Result { + ) -> Result> { self.stat_column_expr(column, column_expr, field, StatisticsType::Min, "min") } /// rewrite col --> col_max fn max_column_expr( &mut self, - column: &Column, - column_expr: &Expr, + column: &phys_expr::Column, + column_expr: &Arc, field: &Field, - ) -> Result { + ) -> Result> { self.stat_column_expr(column, column_expr, field, StatisticsType::Max, "max") } /// rewrite col --> col_null_count fn null_count_column_expr( &mut self, - column: &Column, - column_expr: &Expr, + column: &phys_expr::Column, + column_expr: &Arc, field: &Field, - ) -> Result { + ) -> Result> { self.stat_column_expr( column, column_expr, @@ -355,8 +346,8 @@ impl RequiredStatColumns { } } -impl From> for RequiredStatColumns { - fn from(columns: Vec<(Column, StatisticsType, Field)>) -> Self { +impl From> for RequiredStatColumns { + fn from(columns: Vec<(phys_expr::Column, StatisticsType, Field)>) -> Self { Self { columns } } } @@ -394,14 +385,15 @@ fn build_statistics_record_batch( let mut arrays = Vec::::new(); // For each needed statistics column: for (column, statistics_type, stat_field) in required_columns.iter() { + let column = Column::from_qualified_name(column.name()); let data_type = stat_field.data_type(); let num_containers = statistics.num_containers(); let array = match statistics_type { - StatisticsType::Min => statistics.min_values(column), - StatisticsType::Max => statistics.max_values(column), - StatisticsType::NullCount => statistics.null_counts(column), + StatisticsType::Min => statistics.min_values(&column), + StatisticsType::Max => statistics.max_values(&column), + StatisticsType::NullCount => statistics.null_counts(&column), }; let array = array.unwrap_or_else(|| new_null_array(data_type, num_containers)); @@ -438,25 +430,25 @@ fn build_statistics_record_batch( } struct PruningExpressionBuilder<'a> { - column: Column, - column_expr: Expr, + column: phys_expr::Column, + column_expr: Arc, op: Operator, - scalar_expr: Expr, + scalar_expr: Arc, field: &'a Field, required_columns: &'a mut RequiredStatColumns, } impl<'a> PruningExpressionBuilder<'a> { fn try_new( - left: &'a Expr, - right: &'a Expr, + left: &'a Arc, + right: &'a Arc, op: Operator, schema: &'a Schema, required_columns: &'a mut RequiredStatColumns, ) -> Result { // find column name; input could be a more complicated expression - let left_columns = left.to_columns()?; - let right_columns = right.to_columns()?; + let left_columns = get_phys_expr_columns(left); + let right_columns = get_phys_expr_columns(right); let (column_expr, scalar_expr, columns, correct_operator) = match (left_columns.len(), right_columns.len()) { (1, 0) => (left, right, left_columns, op), @@ -478,7 +470,7 @@ impl<'a> PruningExpressionBuilder<'a> { df_schema, )?; let column = columns.iter().next().unwrap().clone(); - let field = match schema.column_with_name(&column.flat_name()) { + let field = match schema.column_with_name(column.name()) { Some((_, f)) => f, _ => { return Err(DataFusionError::Plan( @@ -501,16 +493,16 @@ impl<'a> PruningExpressionBuilder<'a> { self.op } - fn scalar_expr(&self) -> &Expr { + fn scalar_expr(&self) -> &Arc { &self.scalar_expr } - fn min_column_expr(&mut self) -> Result { + fn min_column_expr(&mut self) -> Result> { self.required_columns .min_column_expr(&self.column, &self.column_expr, self.field) } - fn max_column_expr(&mut self) -> Result { + fn max_column_expr(&mut self) -> Result> { self.required_columns .max_column_expr(&self.column, &self.column_expr, self.field) } @@ -529,64 +521,83 @@ impl<'a> PruningExpressionBuilder<'a> { /// /// More rewrite rules are still in progress. fn rewrite_expr_to_prunable( - column_expr: &Expr, + column_expr: &PhysicalExprRef, op: Operator, - scalar_expr: &Expr, + scalar_expr: &PhysicalExprRef, schema: DFSchema, -) -> Result<(Expr, Operator, Expr)> { +) -> Result<(PhysicalExprRef, Operator, PhysicalExprRef)> { if !is_compare_op(op) { return Err(DataFusionError::Plan( "rewrite_expr_to_prunable only support compare expression".to_string(), )); } - match column_expr { + let column_expr_any = column_expr.as_any(); + + if column_expr_any + .downcast_ref::() + .is_some() + { // `col op lit()` - Expr::Column(_) => Ok((column_expr.clone(), op, scalar_expr.clone())), + Ok((column_expr.clone(), op, scalar_expr.clone())) + } else if let Some(cast) = column_expr_any.downcast_ref::() { // `cast(col) op lit()` - Expr::Cast(Cast { expr, data_type }) => { - let from_type = expr.get_type(&schema)?; - verify_support_type_for_prune(&from_type, data_type)?; - let (left, op, right) = - rewrite_expr_to_prunable(expr, op, scalar_expr, schema)?; - Ok((cast(left, data_type.clone()), op, right)) - } + let arrow_schema: SchemaRef = schema.clone().into(); + let from_type = cast.expr().data_type(&arrow_schema)?; + verify_support_type_for_prune(&from_type, cast.cast_type())?; + let (left, op, right) = + rewrite_expr_to_prunable(cast.expr(), op, scalar_expr, schema)?; + let left = Arc::new(phys_expr::CastExpr::new( + left, + cast.cast_type().clone(), + DEFAULT_CAST_OPTIONS, + )); + Ok((left, op, right)) + } else if let Some(try_cast) = + column_expr_any.downcast_ref::() + { // `try_cast(col) op lit()` - Expr::TryCast(TryCast { expr, data_type }) => { - let from_type = expr.get_type(&schema)?; - verify_support_type_for_prune(&from_type, data_type)?; - let (left, op, right) = - rewrite_expr_to_prunable(expr, op, scalar_expr, schema)?; - Ok((try_cast(left, data_type.clone()), op, right)) - } + let arrow_schema: SchemaRef = schema.clone().into(); + let from_type = try_cast.expr().data_type(&arrow_schema)?; + verify_support_type_for_prune(&from_type, try_cast.cast_type())?; + let (left, op, right) = + rewrite_expr_to_prunable(try_cast.expr(), op, scalar_expr, schema)?; + let left = Arc::new(phys_expr::TryCastExpr::new( + left, + try_cast.cast_type().clone(), + )); + Ok((left, op, right)) + } else if let Some(neg) = column_expr_any.downcast_ref::() { // `-col > lit()` --> `col < -lit()` - Expr::Negative(c) => { - let (left, op, right) = rewrite_expr_to_prunable(c, op, scalar_expr, schema)?; - Ok((left, reverse_operator(op)?, Expr::Negative(Box::new(right)))) - } + let (left, op, right) = + rewrite_expr_to_prunable(neg.arg(), op, scalar_expr, schema)?; + let right = Arc::new(phys_expr::NegativeExpr::new(right)); + Ok((left, reverse_operator(op)?, right)) + } else if let Some(not) = column_expr_any.downcast_ref::() { // `!col = true` --> `col = !true` - Expr::Not(c) => { - if op != Operator::Eq && op != Operator::NotEq { - return Err(DataFusionError::Plan( - "Not with operator other than Eq / NotEq is not supported" - .to_string(), - )); - } - return match c.as_ref() { - Expr::Column(_) => Ok(( - c.as_ref().clone(), - reverse_operator(op)?, - Expr::Not(Box::new(scalar_expr.clone())), - )), - _ => Err(DataFusionError::Plan(format!( - "Not with complex expression {column_expr:?} is not supported" - ))), - }; + if op != Operator::Eq && op != Operator::NotEq { + return Err(DataFusionError::Plan( + "Not with operator other than Eq / NotEq is not supported".to_string(), + )); } - - _ => Err(DataFusionError::Plan(format!( + if not + .arg() + .as_any() + .downcast_ref::() + .is_some() + { + let left = not.arg().clone(); + let right = Arc::new(phys_expr::NotExpr::new(scalar_expr.clone())); + Ok((left, reverse_operator(op)?, right)) + } else { + Err(DataFusionError::Plan(format!( + "Not with complex expression {column_expr:?} is not supported" + ))) + } + } else { + Err(DataFusionError::Plan(format!( "column expression {column_expr:?} is not supported" - ))), + ))) } } @@ -629,14 +640,32 @@ fn verify_support_type_for_prune(from_type: &DataType, to_type: &DataType) -> Re /// replaces a column with an old name with a new name in an expression fn rewrite_column_expr( - e: Expr, - column_old: &Column, - column_new: &Column, -) -> Result { - rewrite_expr(e, |expr| match expr { - Expr::Column(c) if c == *column_old => Ok(Expr::Column(column_new.clone())), - _ => Ok(expr), - }) + e: Arc, + column_old: &phys_expr::Column, + column_new: &phys_expr::Column, +) -> Result> { + let mut rewriter = RewriteColumnExpr { + column_old, + column_new, + }; + e.transform_using(&mut rewriter) +} + +struct RewriteColumnExpr<'a> { + column_old: &'a phys_expr::Column, + column_new: &'a phys_expr::Column, +} + +impl<'a> TreeNodeRewriter> for RewriteColumnExpr<'a> { + fn mutate(&mut self, expr: Arc) -> Result> { + if let Some(column) = expr.as_any().downcast_ref::() { + if column == self.column_old { + return Ok(Arc::new(self.column_new.clone())); + } + } + + Ok(expr) + } } fn reverse_operator(op: Operator) -> Result { @@ -652,15 +681,15 @@ fn reverse_operator(op: Operator) -> Result { /// if the column may contain values, and false if definitely does not /// contain values fn build_single_column_expr( - column: &Column, + column: &phys_expr::Column, schema: &Schema, required_columns: &mut RequiredStatColumns, is_not: bool, // if true, treat as !col -) -> Option { - let field = schema.field_with_name(&column.name).ok()?; +) -> Option> { + let field = schema.field_with_name(column.name()).ok()?; if matches!(field.data_type(), &DataType::Boolean) { - let col_ref = Expr::Column(column.clone()); + let col_ref = Arc::new(column.clone()) as _; let min = required_columns .min_column_expr(column, &col_ref, field) @@ -675,11 +704,13 @@ fn build_single_column_expr( if is_not { // The only way we know a column couldn't match is if both the min and max are true // !(min && max) - Some(!(min.and(max))) + Some(Arc::new(phys_expr::NotExpr::new(Arc::new( + phys_expr::BinaryExpr::new(min, Operator::And, max), + )))) } else { // the only way we know a column couldn't match is if both the min and max are false // !(!min && !max) --> min || max - Some(min.or(max)) + Some(Arc::new(phys_expr::BinaryExpr::new(min, Operator::Or, max))) } } else { None @@ -691,24 +722,27 @@ fn build_single_column_expr( /// if the column may contain null, and false if definitely does not /// contain null. fn build_is_null_column_expr( - expr: &Expr, + expr: &Arc, schema: &Schema, required_columns: &mut RequiredStatColumns, -) -> Option { - match expr { - Expr::Column(ref col) => { - let field = schema.field_with_name(&col.name).ok()?; - - let null_count_field = &Field::new(field.name(), DataType::UInt64, true); - required_columns - .null_count_column_expr(col, expr, null_count_field) - .map(|null_count_column_expr| { - // IsNull(column) => null_count > 0 - null_count_column_expr.gt(lit::(0)) - }) - .ok() - } - _ => None, +) -> Option> { + if let Some(col) = expr.as_any().downcast_ref::() { + let field = schema.field_with_name(col.name()).ok()?; + + let null_count_field = &Field::new(field.name(), DataType::UInt64, true); + required_columns + .null_count_column_expr(col, expr, null_count_field) + .map(|null_count_column_expr| { + // IsNull(column) => null_count > 0 + Arc::new(phys_expr::BinaryExpr::new( + null_count_column_expr, + Operator::Gt, + Arc::new(phys_expr::Literal::new(ScalarValue::UInt64(Some(0)))), + )) as _ + }) + .ok() + } else { + None } } @@ -718,70 +752,95 @@ fn build_is_null_column_expr( /// /// Returns the pruning predicate as an [`Expr`] fn build_predicate_expression( - expr: &Expr, + expr: &Arc, schema: &Schema, required_columns: &mut RequiredStatColumns, -) -> Expr { +) -> Arc { // Returned for unsupported expressions. Such expressions are // converted to TRUE. - let unhandled = lit(true); + let unhandled = Arc::new(phys_expr::Literal::new(ScalarValue::Boolean(Some(true)))); // predicate expression can only be a binary expression - let (left, op, right) = match expr { - Expr::BinaryExpr(BinaryExpr { left, op, right }) => (left, *op, right), - Expr::IsNull(expr) => { - return build_is_null_column_expr(expr, schema, required_columns) - .unwrap_or(unhandled); - } - Expr::Column(col) => { - return build_single_column_expr(col, schema, required_columns, false) + let expr_any = expr.as_any(); + if let Some(is_null) = expr_any.downcast_ref::() { + return build_is_null_column_expr(is_null.arg(), schema, required_columns) + .unwrap_or(unhandled); + } + if let Some(col) = expr_any.downcast_ref::() { + return build_single_column_expr(col, schema, required_columns, false) + .unwrap_or(unhandled); + } + if let Some(not) = expr_any.downcast_ref::() { + // match !col (don't do so recursively) + if let Some(col) = not.arg().as_any().downcast_ref::() { + return build_single_column_expr(col, schema, required_columns, true) .unwrap_or(unhandled); + } else { + return unhandled; } - // match !col (don't do so recursively) - Expr::Not(input) => { - if let Expr::Column(col) = input.as_ref() { - return build_single_column_expr(col, schema, required_columns, true) - .unwrap_or(unhandled); + } + if let Some(in_list) = expr_any.downcast_ref::() { + if !in_list.list().is_empty() && in_list.list().len() < 20 { + let eq_op = if in_list.negated() { + Operator::NotEq } else { - return unhandled; - } - } - Expr::InList { - expr, - list, - negated, - } if !list.is_empty() && list.len() < 20 => { - let eq_fun = if *negated { Expr::not_eq } else { Expr::eq }; - let re_fun = if *negated { Expr::and } else { Expr::or }; - let change_expr = list + Operator::Eq + }; + let re_op = if in_list.negated() { + Operator::And + } else { + Operator::Or + }; + let change_expr = in_list + .list() .iter() - .map(|e| eq_fun(*expr.clone(), e.clone())) - .reduce(re_fun) + .cloned() + .map(|e| { + Arc::new(phys_expr::BinaryExpr::new( + in_list.expr().clone(), + eq_op, + e.clone(), + )) as _ + }) + .reduce(|a, b| Arc::new(phys_expr::BinaryExpr::new(a, re_op, b)) as _) .unwrap(); return build_predicate_expression(&change_expr, schema, required_columns); + } else { + return unhandled; } - _ => { + } + + let (left, op, right) = { + if let Some(bin_expr) = expr_any.downcast_ref::() { + ( + bin_expr.left().clone(), + *bin_expr.op(), + bin_expr.right().clone(), + ) + } else { return unhandled; } }; if op == Operator::And || op == Operator::Or { - let left_expr = build_predicate_expression(left, schema, required_columns); - let right_expr = build_predicate_expression(right, schema, required_columns); + let left_expr = build_predicate_expression(&left, schema, required_columns); + let right_expr = build_predicate_expression(&right, schema, required_columns); // simplify boolean expression if applicable let expr = match (&left_expr, op, &right_expr) { - (left, Operator::And, _) if *left == unhandled => right_expr, - (_, Operator::And, right) if *right == unhandled => left_expr, - (left, Operator::Or, right) if *left == unhandled || *right == unhandled => { + (left, Operator::And, _) if is_always_true(left) => right_expr, + (_, Operator::And, right) if is_always_true(right) => left_expr, + (left, Operator::Or, right) + if is_always_true(left) || is_always_true(right) => + { unhandled } - _ => binary_expr(left_expr, op, right_expr), + _ => Arc::new(phys_expr::BinaryExpr::new(left_expr, op, right_expr)), }; return expr; } let expr_builder = - PruningExpressionBuilder::try_new(left, right, op, schema, required_columns); + PruningExpressionBuilder::try_new(&left, &right, op, schema, required_columns); let mut expr_builder = match expr_builder { Ok(builder) => builder, // allow partial failure in predicate expression generation @@ -794,8 +853,10 @@ fn build_predicate_expression( build_statistics_expr(&mut expr_builder).unwrap_or(unhandled) } -fn build_statistics_expr(expr_builder: &mut PruningExpressionBuilder) -> Result { - let statistics_expr = +fn build_statistics_expr( + expr_builder: &mut PruningExpressionBuilder, +) -> Result> { + let statistics_expr: Arc = match expr_builder.op() { Operator::NotEq => { // column != literal => (min, max) = literal => @@ -803,42 +864,70 @@ fn build_statistics_expr(expr_builder: &mut PruningExpressionBuilder) -> Result< // min != literal || literal != max let min_column_expr = expr_builder.min_column_expr()?; let max_column_expr = expr_builder.max_column_expr()?; - min_column_expr - .not_eq(expr_builder.scalar_expr().clone()) - .or(expr_builder.scalar_expr().clone().not_eq(max_column_expr)) + Arc::new(phys_expr::BinaryExpr::new( + Arc::new(phys_expr::BinaryExpr::new( + min_column_expr, + Operator::NotEq, + expr_builder.scalar_expr().clone(), + )), + Operator::Or, + Arc::new(phys_expr::BinaryExpr::new( + expr_builder.scalar_expr().clone(), + Operator::NotEq, + max_column_expr, + )), + )) } Operator::Eq => { // column = literal => (min, max) = literal => min <= literal && literal <= max // (column / 2) = 4 => (column_min / 2) <= 4 && 4 <= (column_max / 2) let min_column_expr = expr_builder.min_column_expr()?; let max_column_expr = expr_builder.max_column_expr()?; - min_column_expr - .lt_eq(expr_builder.scalar_expr().clone()) - .and(expr_builder.scalar_expr().clone().lt_eq(max_column_expr)) + Arc::new(phys_expr::BinaryExpr::new( + Arc::new(phys_expr::BinaryExpr::new( + min_column_expr, + Operator::LtEq, + expr_builder.scalar_expr().clone(), + )), + Operator::And, + Arc::new(phys_expr::BinaryExpr::new( + expr_builder.scalar_expr().clone(), + Operator::LtEq, + max_column_expr, + )), + )) } Operator::Gt => { // column > literal => (min, max) > literal => max > literal - expr_builder - .max_column_expr()? - .gt(expr_builder.scalar_expr().clone()) + Arc::new(phys_expr::BinaryExpr::new( + expr_builder.max_column_expr()?, + Operator::Gt, + expr_builder.scalar_expr().clone(), + )) } Operator::GtEq => { // column >= literal => (min, max) >= literal => max >= literal - expr_builder - .max_column_expr()? - .gt_eq(expr_builder.scalar_expr().clone()) + Arc::new(phys_expr::BinaryExpr::new( + expr_builder.max_column_expr()?, + Operator::GtEq, + expr_builder.scalar_expr().clone(), + )) } Operator::Lt => { // column < literal => (min, max) < literal => min < literal - expr_builder - .min_column_expr()? - .lt(expr_builder.scalar_expr().clone()) + Arc::new(phys_expr::BinaryExpr::new( + expr_builder.min_column_expr()?, + Operator::Lt, + expr_builder.scalar_expr().clone(), + )) } Operator::LtEq => { // column <= literal => (min, max) <= literal => min <= literal - expr_builder - .min_column_expr()? - .lt_eq(expr_builder.scalar_expr().clone()) + Arc::new(phys_expr::BinaryExpr::new( + expr_builder.min_column_expr()?, + Operator::LtEq, + expr_builder.scalar_expr().clone(), + )) } // other expressions are not supported _ => return Err(DataFusionError::Plan( @@ -867,8 +956,10 @@ mod tests { array::{BinaryArray, Int32Array, Int64Array, StringArray}, datatypes::{DataType, TimeUnit}, }; - use datafusion_common::ScalarValue; - use datafusion_expr::{cast, is_null}; + use datafusion_common::{ScalarValue, ToDFSchema}; + use datafusion_expr::{cast, is_null, try_cast, Expr}; + use datafusion_physical_expr::create_physical_expr; + use datafusion_physical_expr::execution_props::ExecutionProps; use std::collections::HashMap; #[derive(Debug)] @@ -1093,25 +1184,25 @@ mod tests { let required_columns = RequiredStatColumns::from(vec![ // min of original column s1, named s1_min ( - "s1".into(), + phys_expr::Column::new("s1", 1), StatisticsType::Min, Field::new("s1_min", DataType::Int32, true), ), // max of original column s2, named s2_max ( - "s2".into(), + phys_expr::Column::new("s2", 2), StatisticsType::Max, Field::new("s2_max", DataType::Int32, true), ), // max of original column s3, named s3_max ( - "s3".into(), + phys_expr::Column::new("s3", 3), StatisticsType::Max, Field::new("s3_max", DataType::Utf8, true), ), // min of original column s3, named s3_min ( - "s3".into(), + phys_expr::Column::new("s3", 3), StatisticsType::Min, Field::new("s3_min", DataType::Utf8, true), ), @@ -1163,7 +1254,7 @@ mod tests { // Request a record batch with of s1_min as a timestamp let required_columns = RequiredStatColumns::from(vec![( - "s3".into(), + phys_expr::Column::new("s3", 3), StatisticsType::Min, Field::new( "s1_min", @@ -1213,7 +1304,7 @@ mod tests { // Request a record batch with of s1_min as a timestamp let required_columns = RequiredStatColumns::from(vec![( - "s3".into(), + phys_expr::Column::new("s3", 3), StatisticsType::Min, Field::new("s1_min", DataType::Utf8, true), )]); @@ -1242,7 +1333,7 @@ mod tests { fn test_build_statistics_inconsistent_length() { // return an inconsistent length to the actual statistics arrays let required_columns = RequiredStatColumns::from(vec![( - "s1".into(), + phys_expr::Column::new("s1", 3), StatisticsType::Min, Field::new("s1_min", DataType::Int64, true), )]); @@ -1268,19 +1359,25 @@ mod tests { #[test] fn row_group_predicate_eq() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = "c1_min <= Int32(1) AND Int32(1) <= c1_max"; + let expected_expr = "c1_min@0 <= 1 AND 1 <= c1_max@1"; // test column on the left let expr = col("c1").eq(lit(1)); - let predicate_expr = - build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new()); - assert_eq!(format!("{predicate_expr:?}"), expected_expr); + let predicate_expr = test_build_predicate_expression( + &expr, + &schema, + &mut RequiredStatColumns::new(), + ); + assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).eq(col("c1")); - let predicate_expr = - build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new()); - assert_eq!(format!("{predicate_expr:?}"), expected_expr); + let predicate_expr = test_build_predicate_expression( + &expr, + &schema, + &mut RequiredStatColumns::new(), + ); + assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) } @@ -1288,19 +1385,25 @@ mod tests { #[test] fn row_group_predicate_not_eq() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = "c1_min != Int32(1) OR Int32(1) != c1_max"; + let expected_expr = "c1_min@0 != 1 OR 1 != c1_max@1"; // test column on the left let expr = col("c1").not_eq(lit(1)); - let predicate_expr = - build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new()); - assert_eq!(format!("{predicate_expr:?}"), expected_expr); + let predicate_expr = test_build_predicate_expression( + &expr, + &schema, + &mut RequiredStatColumns::new(), + ); + assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).not_eq(col("c1")); - let predicate_expr = - build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new()); - assert_eq!(format!("{predicate_expr:?}"), expected_expr); + let predicate_expr = test_build_predicate_expression( + &expr, + &schema, + &mut RequiredStatColumns::new(), + ); + assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) } @@ -1308,19 +1411,25 @@ mod tests { #[test] fn row_group_predicate_gt() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = "c1_max > Int32(1)"; + let expected_expr = "c1_max@0 > 1"; // test column on the left let expr = col("c1").gt(lit(1)); - let predicate_expr = - build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new()); - assert_eq!(format!("{predicate_expr:?}"), expected_expr); + let predicate_expr = test_build_predicate_expression( + &expr, + &schema, + &mut RequiredStatColumns::new(), + ); + assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).lt(col("c1")); - let predicate_expr = - build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new()); - assert_eq!(format!("{predicate_expr:?}"), expected_expr); + let predicate_expr = test_build_predicate_expression( + &expr, + &schema, + &mut RequiredStatColumns::new(), + ); + assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) } @@ -1328,18 +1437,24 @@ mod tests { #[test] fn row_group_predicate_gt_eq() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = "c1_max >= Int32(1)"; + let expected_expr = "c1_max@0 >= 1"; // test column on the left let expr = col("c1").gt_eq(lit(1)); - let predicate_expr = - build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new()); - assert_eq!(format!("{predicate_expr:?}"), expected_expr); + let predicate_expr = test_build_predicate_expression( + &expr, + &schema, + &mut RequiredStatColumns::new(), + ); + assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).lt_eq(col("c1")); - let predicate_expr = - build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new()); - assert_eq!(format!("{predicate_expr:?}"), expected_expr); + let predicate_expr = test_build_predicate_expression( + &expr, + &schema, + &mut RequiredStatColumns::new(), + ); + assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) } @@ -1347,19 +1462,25 @@ mod tests { #[test] fn row_group_predicate_lt() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = "c1_min < Int32(1)"; + let expected_expr = "c1_min@0 < 1"; // test column on the left let expr = col("c1").lt(lit(1)); - let predicate_expr = - build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new()); - assert_eq!(format!("{predicate_expr:?}"), expected_expr); + let predicate_expr = test_build_predicate_expression( + &expr, + &schema, + &mut RequiredStatColumns::new(), + ); + assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).gt(col("c1")); - let predicate_expr = - build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new()); - assert_eq!(format!("{predicate_expr:?}"), expected_expr); + let predicate_expr = test_build_predicate_expression( + &expr, + &schema, + &mut RequiredStatColumns::new(), + ); + assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) } @@ -1367,18 +1488,24 @@ mod tests { #[test] fn row_group_predicate_lt_eq() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = "c1_min <= Int32(1)"; + let expected_expr = "c1_min@0 <= 1"; // test column on the left let expr = col("c1").lt_eq(lit(1)); - let predicate_expr = - build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new()); - assert_eq!(format!("{predicate_expr:?}"), expected_expr); + let predicate_expr = test_build_predicate_expression( + &expr, + &schema, + &mut RequiredStatColumns::new(), + ); + assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(1).gt_eq(col("c1")); - let predicate_expr = - build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new()); - assert_eq!(format!("{predicate_expr:?}"), expected_expr); + let predicate_expr = test_build_predicate_expression( + &expr, + &schema, + &mut RequiredStatColumns::new(), + ); + assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) } @@ -1392,10 +1519,13 @@ mod tests { ]); // test AND operator joining supported c1 < 1 expression and unsupported c2 > c3 expression let expr = col("c1").lt(lit(1)).and(col("c2").lt(col("c3"))); - let expected_expr = "c1_min < Int32(1)"; - let predicate_expr = - build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new()); - assert_eq!(format!("{predicate_expr:?}"), expected_expr); + let expected_expr = "c1_min@0 < 1"; + let predicate_expr = test_build_predicate_expression( + &expr, + &schema, + &mut RequiredStatColumns::new(), + ); + assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) } @@ -1406,12 +1536,17 @@ mod tests { Field::new("c1", DataType::Int32, false), Field::new("c2", DataType::Int32, false), ]); - // test OR operator joining supported c1 < 1 expression and unsupported c2 % 2 expression - let expr = col("c1").lt(lit(1)).or(col("c2").modulus(lit(2))); - let expected_expr = "Boolean(true)"; - let predicate_expr = - build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new()); - assert_eq!(format!("{predicate_expr:?}"), expected_expr); + // test OR operator joining supported c1 < 1 expression and unsupported c2 % 2 = 0 expression + let expr = col("c1") + .lt(lit(1)) + .or(col("c2").modulus(lit(2)).eq(lit(0))); + let expected_expr = "true"; + let predicate_expr = test_build_predicate_expression( + &expr, + &schema, + &mut RequiredStatColumns::new(), + ); + assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) } @@ -1419,12 +1554,15 @@ mod tests { #[test] fn row_group_predicate_not() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); - let expected_expr = "Boolean(true)"; + let expected_expr = "true"; let expr = col("c1").not(); - let predicate_expr = - build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new()); - assert_eq!(format!("{predicate_expr:?}"), expected_expr); + let predicate_expr = test_build_predicate_expression( + &expr, + &schema, + &mut RequiredStatColumns::new(), + ); + assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) } @@ -1432,12 +1570,15 @@ mod tests { #[test] fn row_group_predicate_not_bool() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Boolean, false)]); - let expected_expr = "NOT c1_min AND c1_max"; + let expected_expr = "NOT c1_min@0 AND c1_max@1"; let expr = col("c1").not(); - let predicate_expr = - build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new()); - assert_eq!(format!("{predicate_expr:?}"), expected_expr); + let predicate_expr = test_build_predicate_expression( + &expr, + &schema, + &mut RequiredStatColumns::new(), + ); + assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) } @@ -1445,12 +1586,15 @@ mod tests { #[test] fn row_group_predicate_bool() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Boolean, false)]); - let expected_expr = "c1_min OR c1_max"; + let expected_expr = "c1_min@0 OR c1_max@1"; let expr = col("c1"); - let predicate_expr = - build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new()); - assert_eq!(format!("{predicate_expr:?}"), expected_expr); + let predicate_expr = test_build_predicate_expression( + &expr, + &schema, + &mut RequiredStatColumns::new(), + ); + assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) } @@ -1458,14 +1602,17 @@ mod tests { #[test] fn row_group_predicate_lt_bool() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Boolean, false)]); - let expected_expr = "c1_min < Boolean(true)"; + let expected_expr = "c1_min@0 < true"; // DF doesn't support arithmetic on boolean columns so // this predicate will error when evaluated let expr = col("c1").lt(lit(true)); - let predicate_expr = - build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new()); - assert_eq!(format!("{predicate_expr:?}"), expected_expr); + let predicate_expr = test_build_predicate_expression( + &expr, + &schema, + &mut RequiredStatColumns::new(), + ); + assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) } @@ -1481,26 +1628,38 @@ mod tests { let expr = col("c1") .lt(lit(1)) .and(col("c2").eq(lit(2)).or(col("c2").eq(lit(3)))); - let expected_expr = "c1_min < Int32(1) AND (c2_min <= Int32(2) AND Int32(2) <= c2_max OR c2_min <= Int32(3) AND Int32(3) <= c2_max)"; + let expected_expr = "c1_min@0 < 1 AND (c2_min@1 <= 2 AND 2 <= c2_max@2 OR c2_min@1 <= 3 AND 3 <= c2_max@2)"; let predicate_expr = - build_predicate_expression(&expr, &schema, &mut required_columns); - assert_eq!(format!("{predicate_expr:?}"), expected_expr); + test_build_predicate_expression(&expr, &schema, &mut required_columns); + assert_eq!(predicate_expr.to_string(), expected_expr); // c1 < 1 should add c1_min let c1_min_field = Field::new("c1_min", DataType::Int32, false); assert_eq!( required_columns.columns[0], - ("c1".into(), StatisticsType::Min, c1_min_field) + ( + phys_expr::Column::new("c1", 0), + StatisticsType::Min, + c1_min_field + ) ); // c2 = 2 should add c2_min and c2_max let c2_min_field = Field::new("c2_min", DataType::Int32, false); assert_eq!( required_columns.columns[1], - ("c2".into(), StatisticsType::Min, c2_min_field) + ( + phys_expr::Column::new("c2", 1), + StatisticsType::Min, + c2_min_field + ) ); let c2_max_field = Field::new("c2_max", DataType::Int32, false); assert_eq!( required_columns.columns[2], - ("c2".into(), StatisticsType::Max, c2_max_field) + ( + phys_expr::Column::new("c2", 1), + StatisticsType::Max, + c2_max_field + ) ); // c2 = 3 shouldn't add any new statistics fields assert_eq!(required_columns.columns.len(), 3); @@ -1520,10 +1679,13 @@ mod tests { list: vec![lit(1), lit(2), lit(3)], negated: false, }; - let expected_expr = "c1_min <= Int32(1) AND Int32(1) <= c1_max OR c1_min <= Int32(2) AND Int32(2) <= c1_max OR c1_min <= Int32(3) AND Int32(3) <= c1_max"; - let predicate_expr = - build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new()); - assert_eq!(format!("{predicate_expr:?}"), expected_expr); + let expected_expr = "c1_min@0 <= 1 AND 1 <= c1_max@1 OR c1_min@0 <= 2 AND 2 <= c1_max@1 OR c1_min@0 <= 3 AND 3 <= c1_max@1"; + let predicate_expr = test_build_predicate_expression( + &expr, + &schema, + &mut RequiredStatColumns::new(), + ); + assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) } @@ -1540,10 +1702,13 @@ mod tests { list: vec![], negated: false, }; - let expected_expr = "Boolean(true)"; - let predicate_expr = - build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new()); - assert_eq!(format!("{predicate_expr:?}"), expected_expr); + let expected_expr = "true"; + let predicate_expr = test_build_predicate_expression( + &expr, + &schema, + &mut RequiredStatColumns::new(), + ); + assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) } @@ -1560,12 +1725,15 @@ mod tests { list: vec![lit(1), lit(2), lit(3)], negated: true, }; - let expected_expr = "(c1_min != Int32(1) OR Int32(1) != c1_max) \ - AND (c1_min != Int32(2) OR Int32(2) != c1_max) \ - AND (c1_min != Int32(3) OR Int32(3) != c1_max)"; - let predicate_expr = - build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new()); - assert_eq!(format!("{predicate_expr:?}"), expected_expr); + let expected_expr = "(c1_min@0 != 1 OR 1 != c1_max@1) \ + AND (c1_min@0 != 2 OR 2 != c1_max@1) \ + AND (c1_min@0 != 3 OR 3 != c1_max@1)"; + let predicate_expr = test_build_predicate_expression( + &expr, + &schema, + &mut RequiredStatColumns::new(), + ); + assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) } @@ -1574,35 +1742,47 @@ mod tests { fn row_group_predicate_cast() -> Result<()> { let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); let expected_expr = - "CAST(c1_min AS Int64) <= Int64(1) AND Int64(1) <= CAST(c1_max AS Int64)"; + "CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64)"; // test column on the left let expr = cast(col("c1"), DataType::Int64).eq(lit(ScalarValue::Int64(Some(1)))); - let predicate_expr = - build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new()); - assert_eq!(format!("{predicate_expr:?}"), expected_expr); + let predicate_expr = test_build_predicate_expression( + &expr, + &schema, + &mut RequiredStatColumns::new(), + ); + assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(ScalarValue::Int64(Some(1))).eq(cast(col("c1"), DataType::Int64)); - let predicate_expr = - build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new()); - assert_eq!(format!("{predicate_expr:?}"), expected_expr); + let predicate_expr = test_build_predicate_expression( + &expr, + &schema, + &mut RequiredStatColumns::new(), + ); + assert_eq!(predicate_expr.to_string(), expected_expr); - let expected_expr = "TRY_CAST(c1_max AS Int64) > Int64(1)"; + let expected_expr = "TRY_CAST(c1_max@0 AS Int64) > 1"; // test column on the left let expr = try_cast(col("c1"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(1)))); - let predicate_expr = - build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new()); - assert_eq!(format!("{predicate_expr:?}"), expected_expr); + let predicate_expr = test_build_predicate_expression( + &expr, + &schema, + &mut RequiredStatColumns::new(), + ); + assert_eq!(predicate_expr.to_string(), expected_expr); // test column on the right let expr = lit(ScalarValue::Int64(Some(1))).lt(try_cast(col("c1"), DataType::Int64)); - let predicate_expr = - build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new()); - assert_eq!(format!("{predicate_expr:?}"), expected_expr); + let predicate_expr = test_build_predicate_expression( + &expr, + &schema, + &mut RequiredStatColumns::new(), + ); + assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) } @@ -1620,10 +1800,13 @@ mod tests { ], negated: false, }; - let expected_expr = "CAST(c1_min AS Int64) <= Int64(1) AND Int64(1) <= CAST(c1_max AS Int64) OR CAST(c1_min AS Int64) <= Int64(2) AND Int64(2) <= CAST(c1_max AS Int64) OR CAST(c1_min AS Int64) <= Int64(3) AND Int64(3) <= CAST(c1_max AS Int64)"; - let predicate_expr = - build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new()); - assert_eq!(format!("{predicate_expr:?}"), expected_expr); + let expected_expr = "CAST(c1_min@0 AS Int64) <= 1 AND 1 <= CAST(c1_max@1 AS Int64) OR CAST(c1_min@0 AS Int64) <= 2 AND 2 <= CAST(c1_max@1 AS Int64) OR CAST(c1_min@0 AS Int64) <= 3 AND 3 <= CAST(c1_max@1 AS Int64)"; + let predicate_expr = test_build_predicate_expression( + &expr, + &schema, + &mut RequiredStatColumns::new(), + ); + assert_eq!(predicate_expr.to_string(), expected_expr); let expr = Expr::InList { expr: Box::new(cast(col("c1"), DataType::Int64)), @@ -1635,12 +1818,15 @@ mod tests { negated: true, }; let expected_expr = - "(CAST(c1_min AS Int64) != Int64(1) OR Int64(1) != CAST(c1_max AS Int64)) \ - AND (CAST(c1_min AS Int64) != Int64(2) OR Int64(2) != CAST(c1_max AS Int64)) \ - AND (CAST(c1_min AS Int64) != Int64(3) OR Int64(3) != CAST(c1_max AS Int64))"; - let predicate_expr = - build_predicate_expression(&expr, &schema, &mut RequiredStatColumns::new()); - assert_eq!(format!("{predicate_expr:?}"), expected_expr); + "(CAST(c1_min@0 AS Int64) != 1 OR 1 != CAST(c1_max@1 AS Int64)) \ + AND (CAST(c1_min@0 AS Int64) != 2 OR 2 != CAST(c1_max@1 AS Int64)) \ + AND (CAST(c1_min@0 AS Int64) != 3 OR 3 != CAST(c1_max@1 AS Int64))"; + let predicate_expr = test_build_predicate_expression( + &expr, + &schema, + &mut RequiredStatColumns::new(), + ); + assert_eq!(predicate_expr.to_string(), expected_expr); Ok(()) } @@ -1655,6 +1841,7 @@ mod tests { )])); // s1 > 5 let expr = col("s1").gt(lit(ScalarValue::Decimal128(Some(500), 9, 2))); + let expr = logical2physical(&expr, &schema); // If the data is written by spark, the physical data type is INT32 in the parquet // So we use the INT32 type of statistic. let statistics = TestStatistics::new().with( @@ -1672,6 +1859,7 @@ mod tests { // with cast column to other type let expr = cast(col("s1"), DataType::Decimal128(14, 3)) .gt(lit(ScalarValue::Decimal128(Some(5000), 14, 3))); + let expr = logical2physical(&expr, &schema); let statistics = TestStatistics::new().with( "s1", ContainerStats::new_i32( @@ -1687,6 +1875,7 @@ mod tests { // with try cast column to other type let expr = try_cast(col("s1"), DataType::Decimal128(14, 3)) .gt(lit(ScalarValue::Decimal128(Some(5000), 14, 3))); + let expr = logical2physical(&expr, &schema); let statistics = TestStatistics::new().with( "s1", ContainerStats::new_i32( @@ -1707,6 +1896,7 @@ mod tests { )])); // s1 > 5 let expr = col("s1").gt(lit(ScalarValue::Decimal128(Some(500), 18, 2))); + let expr = logical2physical(&expr, &schema); // If the data is written by spark, the physical data type is INT64 in the parquet // So we use the INT32 type of statistic. let statistics = TestStatistics::new().with( @@ -1729,6 +1919,7 @@ mod tests { )])); // s1 > 5 let expr = col("s1").gt(lit(ScalarValue::Decimal128(Some(500), 23, 2))); + let expr = logical2physical(&expr, &schema); let statistics = TestStatistics::new().with( "s1", ContainerStats::new_decimal128( @@ -1753,6 +1944,7 @@ mod tests { // Prune using s2 > 5 let expr = col("s2").gt(lit(5)); + let expr = logical2physical(&expr, &schema); let statistics = TestStatistics::new().with( "s2", @@ -1774,6 +1966,7 @@ mod tests { // filter with cast let expr = cast(col("s2"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(5)))); + let expr = logical2physical(&expr, &schema); let p = PruningPredicate::try_new(expr, schema).unwrap(); let result = p.prune(&statistics).unwrap(); let expected = vec![false, true, true, true]; @@ -1786,6 +1979,7 @@ mod tests { // Prune using s2 != 'M' let expr = col("s1").not_eq(lit("M")); + let expr = logical2physical(&expr, &schema); let statistics = TestStatistics::new().with( "s1", @@ -1840,12 +2034,34 @@ mod tests { (schema, statistics, expected_true, expected_false) } + #[test] + fn prune_bool_const_expr() { + let (schema, statistics, _, _) = bool_setup(); + + // true + let expr = lit(true); + let expr = logical2physical(&expr, &schema); + let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); + let result = p.prune(&statistics).unwrap(); + assert_eq!(result, vec![true, true, true, true, true]); + + // false + // constant literals that do NOT refer to any columns are currently not evaluated at all, hence the result is + // "all true" + let expr = lit(false); + let expr = logical2physical(&expr, &schema); + let p = PruningPredicate::try_new(expr, schema).unwrap(); + let result = p.prune(&statistics).unwrap(); + assert_eq!(result, vec![true, true, true, true, true]); + } + #[test] fn prune_bool_column() { let (schema, statistics, expected_true, _) = bool_setup(); // b1 let expr = col("b1"); + let expr = logical2physical(&expr, &schema); let p = PruningPredicate::try_new(expr, schema).unwrap(); let result = p.prune(&statistics).unwrap(); assert_eq!(result, expected_true); @@ -1857,6 +2073,7 @@ mod tests { // !b1 let expr = col("b1").not(); + let expr = logical2physical(&expr, &schema); let p = PruningPredicate::try_new(expr, schema).unwrap(); let result = p.prune(&statistics).unwrap(); assert_eq!(result, expected_false); @@ -1868,6 +2085,7 @@ mod tests { // b1 = true let expr = col("b1").eq(lit(true)); + let expr = logical2physical(&expr, &schema); let p = PruningPredicate::try_new(expr, schema).unwrap(); let result = p.prune(&statistics).unwrap(); assert_eq!(result, expected_true); @@ -1879,6 +2097,7 @@ mod tests { // !b1 = true let expr = col("b1").not().eq(lit(true)); + let expr = logical2physical(&expr, &schema); let p = PruningPredicate::try_new(expr, schema).unwrap(); let result = p.prune(&statistics).unwrap(); assert_eq!(result, expected_false); @@ -1920,12 +2139,14 @@ mod tests { // i > 0 let expr = col("i").gt(lit(0)); + let expr = logical2physical(&expr, &schema); let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let result = p.prune(&statistics).unwrap(); assert_eq!(result, expected_ret); // -i < 0 let expr = Expr::Negative(Box::new(col("i"))).lt(lit(0)); + let expr = logical2physical(&expr, &schema); let p = PruningPredicate::try_new(expr, schema).unwrap(); let result = p.prune(&statistics).unwrap(); assert_eq!(result, expected_ret); @@ -1945,12 +2166,14 @@ mod tests { // i <= 0 let expr = col("i").lt_eq(lit(0)); + let expr = logical2physical(&expr, &schema); let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let result = p.prune(&statistics).unwrap(); assert_eq!(result, expected_ret); // -i >= 0 let expr = Expr::Negative(Box::new(col("i"))).gt_eq(lit(0)); + let expr = logical2physical(&expr, &schema); let p = PruningPredicate::try_new(expr, schema).unwrap(); let result = p.prune(&statistics).unwrap(); assert_eq!(result, expected_ret); @@ -1970,26 +2193,30 @@ mod tests { // cast(i as utf8) <= 0 let expr = cast(col("i"), DataType::Utf8).lt_eq(lit("0")); + let expr = logical2physical(&expr, &schema); let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let result = p.prune(&statistics).unwrap(); assert_eq!(result, expected_ret); // try_cast(i as utf8) <= 0 let expr = try_cast(col("i"), DataType::Utf8).lt_eq(lit("0")); + let expr = logical2physical(&expr, &schema); let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let result = p.prune(&statistics).unwrap(); assert_eq!(result, expected_ret); // cast(-i as utf8) >= 0 let expr = - Expr::Negative(Box::new(cast(col("i"), DataType::Utf8))).gt_eq(lit("0")); + cast(Expr::Negative(Box::new(col("i"))), DataType::Utf8).gt_eq(lit("0")); + let expr = logical2physical(&expr, &schema); let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let result = p.prune(&statistics).unwrap(); assert_eq!(result, expected_ret); // try_cast(-i as utf8) >= 0 let expr = - Expr::Negative(Box::new(try_cast(col("i"), DataType::Utf8))).gt_eq(lit("0")); + try_cast(Expr::Negative(Box::new(col("i"))), DataType::Utf8).gt_eq(lit("0")); + let expr = logical2physical(&expr, &schema); let p = PruningPredicate::try_new(expr, schema).unwrap(); let result = p.prune(&statistics).unwrap(); assert_eq!(result, expected_ret); @@ -2009,6 +2236,7 @@ mod tests { // i = 0 let expr = col("i").eq(lit(0)); + let expr = logical2physical(&expr, &schema); let p = PruningPredicate::try_new(expr, schema).unwrap(); let result = p.prune(&statistics).unwrap(); assert_eq!(result, expected_ret); @@ -2027,11 +2255,13 @@ mod tests { let expected_ret = vec![true, false, false, true, false]; let expr = cast(col("i"), DataType::Int64).eq(lit(0i64)); + let expr = logical2physical(&expr, &schema); let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let result = p.prune(&statistics).unwrap(); assert_eq!(result, expected_ret); let expr = try_cast(col("i"), DataType::Int64).eq(lit(0i64)); + let expr = logical2physical(&expr, &schema); let p = PruningPredicate::try_new(expr, schema).unwrap(); let result = p.prune(&statistics).unwrap(); assert_eq!(result, expected_ret); @@ -2053,6 +2283,7 @@ mod tests { let expected_ret = vec![true, true, true, true, true]; let expr = cast(col("i"), DataType::Utf8).eq(lit("0")); + let expr = logical2physical(&expr, &schema); let p = PruningPredicate::try_new(expr, schema).unwrap(); let result = p.prune(&statistics).unwrap(); assert_eq!(result, expected_ret); @@ -2072,12 +2303,14 @@ mod tests { // i > -1 let expr = col("i").gt(lit(-1)); + let expr = logical2physical(&expr, &schema); let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let result = p.prune(&statistics).unwrap(); assert_eq!(result, expected_ret); // -i < 1 let expr = Expr::Negative(Box::new(col("i"))).lt(lit(1)); + let expr = logical2physical(&expr, &schema); let p = PruningPredicate::try_new(expr, schema).unwrap(); let result = p.prune(&statistics).unwrap(); assert_eq!(result, expected_ret); @@ -2093,6 +2326,7 @@ mod tests { // i IS NULL, no null statistics let expr = col("i").is_null(); + let expr = logical2physical(&expr, &schema); let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let result = p.prune(&statistics).unwrap(); assert_eq!(result, expected_ret); @@ -2113,6 +2347,7 @@ mod tests { // i IS NULL, with actual null statistcs let expr = col("i").is_null(); + let expr = logical2physical(&expr, &schema); let p = PruningPredicate::try_new(expr, schema).unwrap(); let result = p.prune(&statistics).unwrap(); assert_eq!(result, expected_ret); @@ -2126,12 +2361,14 @@ mod tests { // i > int64(0) let expr = col("i").gt(cast(lit(ScalarValue::Int64(Some(0))), DataType::Int32)); + let expr = logical2physical(&expr, &schema); let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let result = p.prune(&statistics).unwrap(); assert_eq!(result, expected_ret); // cast(i as int64) > int64(0) let expr = cast(col("i"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(0)))); + let expr = logical2physical(&expr, &schema); let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let result = p.prune(&statistics).unwrap(); assert_eq!(result, expected_ret); @@ -2139,6 +2376,7 @@ mod tests { // try_cast(i as int64) > int64(0) let expr = try_cast(col("i"), DataType::Int64).gt(lit(ScalarValue::Int64(Some(0)))); + let expr = logical2physical(&expr, &schema); let p = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let result = p.prune(&statistics).unwrap(); assert_eq!(result, expected_ret); @@ -2146,6 +2384,7 @@ mod tests { // `-cast(i as int64) < 0` convert to `cast(i as int64) > -0` let expr = Expr::Negative(Box::new(cast(col("i"), DataType::Int64))) .lt(lit(ScalarValue::Int64(Some(0)))); + let expr = logical2physical(&expr, &schema); let p = PruningPredicate::try_new(expr, schema).unwrap(); let result = p.prune(&statistics).unwrap(); assert_eq!(result, expected_ret); @@ -2154,10 +2393,13 @@ mod tests { #[test] fn test_rewrite_expr_to_prunable() { let schema = Schema::new(vec![Field::new("a", DataType::Int32, true)]); - let df_schema = DFSchema::try_from(schema).unwrap(); + let df_schema = DFSchema::try_from(schema.clone()).unwrap(); + // column op lit let left_input = col("a"); + let left_input = logical2physical(&left_input, &schema); let right_input = lit(ScalarValue::Int32(Some(12))); + let right_input = logical2physical(&right_input, &schema); let (result_left, _, result_right) = rewrite_expr_to_prunable( &left_input, Operator::Eq, @@ -2165,11 +2407,14 @@ mod tests { df_schema.clone(), ) .unwrap(); - assert_eq!(result_left, left_input); - assert_eq!(result_right, right_input); + assert_eq!(result_left.to_string(), left_input.to_string()); + assert_eq!(result_right.to_string(), right_input.to_string()); + // cast op lit let left_input = cast(col("a"), DataType::Decimal128(20, 3)); + let left_input = logical2physical(&left_input, &schema); let right_input = lit(ScalarValue::Decimal128(Some(12), 20, 3)); + let right_input = logical2physical(&right_input, &schema); let (result_left, _, result_right) = rewrite_expr_to_prunable( &left_input, Operator::Gt, @@ -2177,16 +2422,20 @@ mod tests { df_schema.clone(), ) .unwrap(); - assert_eq!(result_left, left_input); - assert_eq!(result_right, right_input); + assert_eq!(result_left.to_string(), left_input.to_string()); + assert_eq!(result_right.to_string(), right_input.to_string()); + // try_cast op lit let left_input = try_cast(col("a"), DataType::Int64); + let left_input = logical2physical(&left_input, &schema); let right_input = lit(ScalarValue::Int64(Some(12))); + let right_input = logical2physical(&right_input, &schema); let (result_left, _, result_right) = rewrite_expr_to_prunable(&left_input, Operator::Gt, &right_input, df_schema) .unwrap(); - assert_eq!(result_left, left_input); - assert_eq!(result_right, right_input); + assert_eq!(result_left.to_string(), left_input.to_string()); + assert_eq!(result_right.to_string(), right_input.to_string()); + // TODO: add test for other case and op } @@ -2195,9 +2444,11 @@ mod tests { // cast string value to numeric value // this cast is not supported let schema = Schema::new(vec![Field::new("a", DataType::Utf8, true)]); - let df_schema = DFSchema::try_from(schema).unwrap(); + let df_schema = DFSchema::try_from(schema.clone()).unwrap(); let left_input = cast(col("a"), DataType::Int64); + let left_input = logical2physical(&left_input, &schema); let right_input = lit(ScalarValue::Int64(Some(12))); + let right_input = logical2physical(&right_input, &schema); let result = rewrite_expr_to_prunable( &left_input, Operator::Gt, @@ -2205,12 +2456,30 @@ mod tests { df_schema.clone(), ); assert!(result.is_err()); + // other expr let left_input = is_null(col("a")); + let left_input = logical2physical(&left_input, &schema); let right_input = lit(ScalarValue::Int64(Some(12))); + let right_input = logical2physical(&right_input, &schema); let result = rewrite_expr_to_prunable(&left_input, Operator::Gt, &right_input, df_schema); assert!(result.is_err()); // TODO: add other negative test for other case and op } + + fn test_build_predicate_expression( + expr: &Expr, + schema: &Schema, + required_columns: &mut RequiredStatColumns, + ) -> Arc { + let expr = logical2physical(expr, schema); + build_predicate_expression(&expr, schema, required_columns) + } + + fn logical2physical(expr: &Expr, schema: &Schema) -> Arc { + let df_schema = schema.clone().to_dfschema().unwrap(); + let execution_props = ExecutionProps::new(); + create_physical_expr(expr, &df_schema, schema, &execution_props).unwrap() + } } diff --git a/datafusion/core/src/physical_plan/file_format/parquet.rs b/datafusion/core/src/physical_plan/file_format/parquet.rs index e2d8cc94dcce..a12672756abc 100644 --- a/datafusion/core/src/physical_plan/file_format/parquet.rs +++ b/datafusion/core/src/physical_plan/file_format/parquet.rs @@ -18,6 +18,7 @@ //! Execution plan for reading Parquet files use arrow::datatypes::{DataType, SchemaRef}; +use datafusion_physical_expr::PhysicalExpr; use fmt::Debug; use std::any::Any; use std::cmp::min; @@ -47,7 +48,6 @@ use crate::{ }; use arrow::error::ArrowError; use bytes::Bytes; -use datafusion_expr::Expr; use futures::future::BoxFuture; use futures::{FutureExt, StreamExt, TryFutureExt, TryStreamExt}; use itertools::Itertools; @@ -97,7 +97,7 @@ pub struct ParquetExec { /// Execution metrics metrics: ExecutionPlanMetricsSet, /// Optional predicate for row filtering during parquet scan - predicate: Option>, + predicate: Option>, /// Optional predicate for pruning row groups pruning_predicate: Option>, /// Optional predicate for pruning pages @@ -112,7 +112,7 @@ impl ParquetExec { /// Create a new Parquet reader execution plan provided file list and schema. pub fn new( base_config: FileScanConfig, - predicate: Option, + predicate: Option>, metadata_size_hint: Option, ) -> Self { debug!("Creating ParquetExec, files: {:?}, projection {:?}, predicate: {:?}, limit: {:?}", @@ -151,9 +151,6 @@ impl ParquetExec { } }); - // Save original predicate - let predicate = predicate.map(Arc::new); - let (projected_schema, projected_statistics) = base_config.project(); Self { @@ -462,7 +459,7 @@ struct ParquetOpener { projection: Arc<[usize]>, batch_size: usize, limit: Option, - predicate: Option>, + predicate: Option>, pruning_predicate: Option>, page_pruning_predicate: Option>, table_schema: SchemaRef, @@ -511,6 +508,7 @@ impl FileOpener for ParquetOpener { .await?; let adapted_projections = schema_adapter.map_projections(builder.schema(), &projection)?; + // let predicate = predicate.map(|p| reassign_predicate_columns(p, builder.schema(), true)).transpose()?; let mask = ProjectionMask::roots( builder.parquet_schema(), @@ -520,7 +518,7 @@ impl FileOpener for ParquetOpener { // Filter pushdown: evaluate predicates during scan if let Some(predicate) = pushdown_filters.then_some(predicate).flatten() { let row_filter = row_filter::build_row_filter( - predicate.as_ref(), + &predicate, builder.schema().as_ref(), table_schema.as_ref(), builder.metadata(), @@ -823,9 +821,11 @@ mod tests { datatypes::{DataType, Field}, }; use chrono::{TimeZone, Utc}; - use datafusion_common::assert_contains; use datafusion_common::ScalarValue; - use datafusion_expr::{col, lit, when}; + use datafusion_common::{assert_contains, ToDFSchema}; + use datafusion_expr::{col, lit, when, Expr}; + use datafusion_physical_expr::create_physical_expr; + use datafusion_physical_expr::execution_props::ExecutionProps; use futures::StreamExt; use object_store::local::LocalFileSystem; use object_store::path::Path; @@ -917,6 +917,9 @@ mod tests { let (meta, _files) = store_parquet(batches, multi_page).await.unwrap(); let file_groups = meta.into_iter().map(Into::into).collect(); + // set up predicate (this is normally done by a layer higher up) + let predicate = predicate.map(|p| logical2physical(&p, &file_schema)); + // prepare the scan let mut parquet_exec = ParquetExec::new( FileScanConfig { @@ -1863,7 +1866,7 @@ mod tests { "pruning_predicate=c1_min@0 != bar OR bar != c1_max@1" ); - assert_contains!(&display, r#"predicate=c1 != Utf8("bar")"#); + assert_contains!(&display, r#"predicate=c1@0 != bar"#); assert_contains!(&display, "projection=[c1]"); } @@ -1903,7 +1906,8 @@ mod tests { // but does still has a pushdown down predicate let predicate = rt.parquet_exec.predicate.as_ref(); - assert_eq!(predicate.unwrap().as_ref(), &filter); + let filter_phys = logical2physical(&filter, rt.parquet_exec.schema().as_ref()); + assert_eq!(predicate.unwrap().to_string(), filter_phys.to_string()); } #[tokio::test] @@ -2256,4 +2260,10 @@ mod tests { Ok(()) } + + fn logical2physical(expr: &Expr, schema: &Schema) -> Arc { + let df_schema = schema.clone().to_dfschema().unwrap(); + let execution_props = ExecutionProps::new(); + create_physical_expr(expr, &df_schema, schema, &execution_props).unwrap() + } } diff --git a/datafusion/core/src/physical_plan/file_format/parquet/page_filter.rs b/datafusion/core/src/physical_plan/file_format/parquet/page_filter.rs index 585f0c886245..0853caabe12a 100644 --- a/datafusion/core/src/physical_plan/file_format/parquet/page_filter.rs +++ b/datafusion/core/src/physical_plan/file_format/parquet/page_filter.rs @@ -23,9 +23,9 @@ use arrow::array::{ }; use arrow::datatypes::DataType; use arrow::{array::ArrayRef, datatypes::SchemaRef, error::ArrowError}; -use datafusion_common::{Column, DataFusionError, Result}; -use datafusion_expr::Expr; -use datafusion_optimizer::utils::split_conjunction; +use datafusion_common::{DataFusionError, Result}; +use datafusion_physical_expr::expressions::Column; +use datafusion_physical_expr::{split_conjunction, PhysicalExpr}; use log::{debug, trace}; use parquet::schema::types::ColumnDescriptor; use parquet::{ @@ -107,7 +107,7 @@ pub(crate) struct PagePruningPredicate { impl PagePruningPredicate { /// Create a new [`PagePruningPredicate`] - pub fn try_new(expr: &Expr, schema: SchemaRef) -> Result { + pub fn try_new(expr: &Arc, schema: SchemaRef) -> Result { let predicates = split_conjunction(expr) .into_iter() .filter_map(|predicate| { @@ -253,7 +253,8 @@ fn find_column_index( if let Some(found_required_column) = found_required_column.as_ref() { // make sure it is the same name we have seen previously assert_eq!( - column.name, found_required_column.name, + column.name(), + found_required_column.name(), "Unexpected multi column predicate" ); } else { @@ -272,11 +273,11 @@ fn find_column_index( .columns() .iter() .enumerate() - .find(|(_idx, c)| c.column_descr().name() == column.name) + .find(|(_idx, c)| c.column_descr().name() == column.name()) .map(|(idx, _c)| idx); if col_idx.is_none() { - trace!("Can not find column {} in row group meta", column.name); + trace!("Can not find column {} in row group meta", column.name()); } col_idx @@ -506,11 +507,11 @@ macro_rules! get_min_max_values_for_page_index { } impl<'a> PruningStatistics for PagesPruningStatistics<'a> { - fn min_values(&self, _column: &Column) -> Option { + fn min_values(&self, _column: &datafusion_common::Column) -> Option { get_min_max_values_for_page_index!(self, min) } - fn max_values(&self, _column: &Column) -> Option { + fn max_values(&self, _column: &datafusion_common::Column) -> Option { get_min_max_values_for_page_index!(self, max) } @@ -518,7 +519,7 @@ impl<'a> PruningStatistics for PagesPruningStatistics<'a> { self.col_offset_indexes.len() } - fn null_counts(&self, _column: &Column) -> Option { + fn null_counts(&self, _column: &datafusion_common::Column) -> Option { match self.col_page_indexes { Index::NONE => None, Index::BOOLEAN(index) => Some(Arc::new(Int64Array::from_iter( diff --git a/datafusion/core/src/physical_plan/file_format/parquet/row_filter.rs b/datafusion/core/src/physical_plan/file_format/parquet/row_filter.rs index 92c0ddc8724e..e1feafec1588 100644 --- a/datafusion/core/src/physical_plan/file_format/parquet/row_filter.rs +++ b/datafusion/core/src/physical_plan/file_format/parquet/row_filter.rs @@ -20,14 +20,15 @@ use arrow::datatypes::{DataType, Schema}; use arrow::error::{ArrowError, Result as ArrowResult}; use arrow::record_batch::RecordBatch; use datafusion_common::cast::as_boolean_array; -use datafusion_common::{Column, DataFusionError, Result, ScalarValue, ToDFSchema}; -use datafusion_expr::expr_rewriter::{ExprRewritable, ExprRewriter, RewriteRecursion}; +use datafusion_common::{DataFusionError, Result, ScalarValue}; +use datafusion_physical_expr::expressions::{Column, Literal}; +use datafusion_physical_expr::rewrite::{ + RewriteRecursion, TreeNodeRewritable, TreeNodeRewriter, +}; +use datafusion_physical_expr::utils::reassign_predicate_columns; use std::collections::BTreeSet; -use datafusion_expr::Expr; -use datafusion_optimizer::utils::split_conjunction; -use datafusion_physical_expr::execution_props::ExecutionProps; -use datafusion_physical_expr::{create_physical_expr, PhysicalExpr}; +use datafusion_physical_expr::{split_conjunction, PhysicalExpr}; use parquet::arrow::arrow_reader::{ArrowPredicate, RowFilter}; use parquet::arrow::ProjectionMask; use parquet::file::metadata::ParquetMetaData; @@ -87,13 +88,8 @@ impl DatafusionArrowPredicate { rows_filtered: metrics::Count, time: metrics::Time, ) -> Result { - let props = ExecutionProps::default(); - - let schema = schema.project(&candidate.projection)?; - let df_schema = schema.clone().to_dfschema()?; - - let physical_expr = - create_physical_expr(&candidate.expr, &df_schema, &schema, &props)?; + let schema = Arc::new(schema.project(&candidate.projection)?); + let physical_expr = reassign_predicate_columns(candidate.expr, &schema, true)?; // ArrowPredicate::evaluate is passed columns in the order they appear in the file // If the predicate has multiple columns, we therefore must project the columns based @@ -153,7 +149,7 @@ impl ArrowPredicate for DatafusionArrowPredicate { /// expression as well as data to estimate the cost of evaluating /// the resulting expression. pub(crate) struct FilterCandidate { - expr: Expr, + expr: Arc, required_bytes: usize, can_use_index: bool, projection: Vec, @@ -167,7 +163,7 @@ pub(crate) struct FilterCandidate { /// and any given file may or may not contain all columns in the merged schema. If a particular column is not present /// we replace the column expression with a literal expression that produces a null value. struct FilterCandidateBuilder<'a> { - expr: Expr, + expr: Arc, file_schema: &'a Schema, table_schema: &'a Schema, required_column_indices: BTreeSet, @@ -176,7 +172,11 @@ struct FilterCandidateBuilder<'a> { } impl<'a> FilterCandidateBuilder<'a> { - pub fn new(expr: Expr, file_schema: &'a Schema, table_schema: &'a Schema) -> Self { + pub fn new( + expr: Arc, + file_schema: &'a Schema, + table_schema: &'a Schema, + ) -> Self { Self { expr, file_schema, @@ -192,7 +192,7 @@ impl<'a> FilterCandidateBuilder<'a> { metadata: &ParquetMetaData, ) -> Result> { let expr = self.expr.clone(); - let expr = expr.rewrite(&mut self)?; + let expr = expr.transform_using(&mut self)?; if self.non_primitive_columns || self.projected_columns { Ok(None) @@ -211,16 +211,16 @@ impl<'a> FilterCandidateBuilder<'a> { } } -impl<'a> ExprRewriter for FilterCandidateBuilder<'a> { - fn pre_visit(&mut self, expr: &Expr) -> Result { - if let Expr::Column(column) = expr { - if let Ok(idx) = self.file_schema.index_of(&column.name) { +impl<'a> TreeNodeRewriter> for FilterCandidateBuilder<'a> { + fn pre_visit(&mut self, node: &Arc) -> Result { + if let Some(column) = node.as_any().downcast_ref::() { + if let Ok(idx) = self.file_schema.index_of(column.name()) { self.required_column_indices.insert(idx); if DataType::is_nested(self.file_schema.field(idx).data_type()) { self.non_primitive_columns = true; } - } else if self.table_schema.index_of(&column.name).is_err() { + } else if self.table_schema.index_of(column.name()).is_err() { // If the column does not exist in the (un-projected) table schema then // it must be a projected column. self.projected_columns = true; @@ -229,15 +229,15 @@ impl<'a> ExprRewriter for FilterCandidateBuilder<'a> { Ok(RewriteRecursion::Continue) } - fn mutate(&mut self, expr: Expr) -> Result { - if let Expr::Column(Column { name, .. }) = &expr { - if self.file_schema.field_with_name(name).is_err() { + fn mutate(&mut self, expr: Arc) -> Result> { + if let Some(column) = expr.as_any().downcast_ref::() { + if self.file_schema.field_with_name(column.name()).is_err() { // the column expr must be in the table schema - return match self.table_schema.field_with_name(name) { + return match self.table_schema.field_with_name(column.name()) { Ok(field) => { // return the null value corresponding to the data type let null_value = ScalarValue::try_from(field.data_type())?; - Ok(Expr::Literal(null_value)) + Ok(Arc::new(Literal::new(null_value))) } Err(e) => { // If the column is not in the table schema, should throw the error @@ -308,7 +308,7 @@ fn columns_sorted( /// Build a [`RowFilter`] from the given predicate `Expr` pub fn build_row_filter( - expr: &Expr, + expr: &Arc, file_schema: &Schema, table_schema: &Schema, metadata: &ParquetMetaData, @@ -391,36 +391,14 @@ pub fn build_row_filter( mod test { use super::*; use arrow::datatypes::Field; - use datafusion_expr::{cast, col, lit}; + use datafusion_common::ToDFSchema; + use datafusion_expr::{cast, col, lit, Expr}; + use datafusion_physical_expr::create_physical_expr; + use datafusion_physical_expr::execution_props::ExecutionProps; use parquet::arrow::parquet_to_arrow_schema; use parquet::file::reader::{FileReader, SerializedFileReader}; use rand::prelude::*; - // Assume a column expression for a column not in the table schema is a projected column and ignore it - #[test] - #[should_panic(expected = "building candidate failed")] - fn test_filter_candidate_builder_ignore_projected_columns() { - let testdata = crate::test_util::parquet_test_data(); - let file = std::fs::File::open(format!("{testdata}/alltypes_plain.parquet")) - .expect("opening file"); - - let reader = SerializedFileReader::new(file).expect("creating reader"); - - let metadata = reader.metadata(); - - let table_schema = - parquet_to_arrow_schema(metadata.file_metadata().schema_descr(), None) - .expect("parsing schema"); - - let expr = col("projected_column").eq(lit("value")); - - let candidate = FilterCandidateBuilder::new(expr, &table_schema, &table_schema) - .build(metadata) - .expect("building candidate failed"); - - assert!(candidate.is_none()); - } - // We should ignore predicate that read non-primitive columns #[test] fn test_filter_candidate_builder_ignore_complex_types() { @@ -437,6 +415,7 @@ mod test { .expect("parsing schema"); let expr = col("int64_list").is_not_null(); + let expr = logical2physical(&expr, &table_schema); let candidate = FilterCandidateBuilder::new(expr, &table_schema, &table_schema) .build(metadata) @@ -467,8 +446,11 @@ mod test { // The parquet file with `file_schema` just has `bigint_col` and `float_col` column, and don't have the `int_col` let expr = col("bigint_col").eq(cast(col("int_col"), DataType::Int64)); + let expr = logical2physical(&expr, &table_schema); let expected_candidate_expr = col("bigint_col").eq(cast(lit(ScalarValue::Int32(None)), DataType::Int64)); + let expected_candidate_expr = + logical2physical(&expected_candidate_expr, &table_schema); let candidate = FilterCandidateBuilder::new(expr, &file_schema, &table_schema) .build(metadata) @@ -476,7 +458,10 @@ mod test { assert!(candidate.is_some()); - assert_eq!(candidate.unwrap().expr, expected_candidate_expr); + assert_eq!( + candidate.unwrap().expr.to_string(), + expected_candidate_expr.to_string() + ); } #[test] @@ -496,4 +481,10 @@ mod test { assert_eq!(projection, remapped) } } + + fn logical2physical(expr: &Expr, schema: &Schema) -> Arc { + let df_schema = schema.clone().to_dfschema().unwrap(); + let execution_props = ExecutionProps::new(); + create_physical_expr(expr, &df_schema, schema, &execution_props).unwrap() + } } diff --git a/datafusion/core/src/physical_plan/file_format/parquet/row_groups.rs b/datafusion/core/src/physical_plan/file_format/parquet/row_groups.rs index 101102c927f1..4ba60f08524f 100644 --- a/datafusion/core/src/physical_plan/file_format/parquet/row_groups.rs +++ b/datafusion/core/src/physical_plan/file_format/parquet/row_groups.rs @@ -244,7 +244,10 @@ mod tests { use arrow::datatypes::DataType::Decimal128; use arrow::datatypes::Schema; use arrow::datatypes::{DataType, Field}; - use datafusion_expr::{cast, col, lit}; + use datafusion_common::ToDFSchema; + use datafusion_expr::{cast, col, lit, Expr}; + use datafusion_physical_expr::execution_props::ExecutionProps; + use datafusion_physical_expr::{create_physical_expr, PhysicalExpr}; use parquet::basic::LogicalType; use parquet::data_type::{ByteArray, FixedLenByteArray}; use parquet::{ @@ -258,8 +261,9 @@ mod tests { fn row_group_pruning_predicate_simple_expr() { use datafusion_expr::{col, lit}; // int > 1 => c1_max > 1 - let expr = col("c1").gt(lit(15)); let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); + let expr = col("c1").gt(lit(15)); + let expr = logical2physical(&expr, &schema); let pruning_predicate = PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); let schema_descr = get_test_schema_descr(vec![( @@ -290,8 +294,9 @@ mod tests { fn row_group_pruning_predicate_missing_stats() { use datafusion_expr::{col, lit}; // int > 1 => c1_max > 1 - let expr = col("c1").gt(lit(15)); let schema = Schema::new(vec![Field::new("c1", DataType::Int32, false)]); + let expr = col("c1").gt(lit(15)); + let expr = logical2physical(&expr, &schema); let pruning_predicate = PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); @@ -324,12 +329,15 @@ mod tests { fn row_group_pruning_predicate_partial_expr() { use datafusion_expr::{col, lit}; // test row group predicate with partially supported expression - // int > 1 and int % 2 => c1_max > 1 and true - let expr = col("c1").gt(lit(15)).and(col("c2").modulus(lit(2))); + // (int > 1) and ((int % 2) = 0) => c1_max > 1 and true let schema = Arc::new(Schema::new(vec![ Field::new("c1", DataType::Int32, false), Field::new("c2", DataType::Int32, false), ])); + let expr = col("c1") + .gt(lit(15)) + .and(col("c2").modulus(lit(2)).eq(lit(0))); + let expr = logical2physical(&expr, &schema); let pruning_predicate = PruningPredicate::try_new(expr, schema.clone()).unwrap(); let schema_descr = get_test_schema_descr(vec![ @@ -362,7 +370,10 @@ mod tests { // if conditions in predicate are joined with OR and an unsupported expression is used // this bypasses the entire predicate expression and no row groups are filtered out - let expr = col("c1").gt(lit(15)).or(col("c2").modulus(lit(2))); + let expr = col("c1") + .gt(lit(15)) + .or(col("c2").modulus(lit(2)).eq(lit(0))); + let expr = logical2physical(&expr, &schema); let pruning_predicate = PruningPredicate::try_new(expr, schema).unwrap(); // if conditions in predicate are joined with OR and an unsupported expression is used @@ -399,11 +410,12 @@ mod tests { fn row_group_pruning_predicate_null_expr() { use datafusion_expr::{col, lit}; // int > 1 and IsNull(bool) => c1_max > 1 and bool_null_count > 0 - let expr = col("c1").gt(lit(15)).and(col("c2").is_null()); let schema = Arc::new(Schema::new(vec![ Field::new("c1", DataType::Int32, false), Field::new("c2", DataType::Boolean, false), ])); + let expr = col("c1").gt(lit(15)).and(col("c2").is_null()); + let expr = logical2physical(&expr, &schema); let pruning_predicate = PruningPredicate::try_new(expr, schema).unwrap(); let groups = gen_row_group_meta_data_for_pruning_predicate(); @@ -421,13 +433,14 @@ mod tests { // test row group predicate with an unknown (Null) expr // // int > 1 and bool = NULL => c1_max > 1 and null - let expr = col("c1") - .gt(lit(15)) - .and(col("c2").eq(lit(ScalarValue::Boolean(None)))); let schema = Arc::new(Schema::new(vec![ Field::new("c1", DataType::Int32, false), Field::new("c2", DataType::Boolean, false), ])); + let expr = col("c1") + .gt(lit(15)) + .and(col("c2").eq(lit(ScalarValue::Boolean(None)))); + let expr = logical2physical(&expr, &schema); let pruning_predicate = PruningPredicate::try_new(expr, schema).unwrap(); let groups = gen_row_group_meta_data_for_pruning_predicate(); @@ -448,7 +461,6 @@ mod tests { // INT32: c1 > 5, the c1 is decimal(9,2) // The type of scalar value if decimal(9,2), don't need to do cast - let expr = col("c1").gt(lit(ScalarValue::Decimal128(Some(500), 9, 2))); let schema = Schema::new(vec![Field::new("c1", DataType::Decimal128(9, 2), false)]); let schema_descr = get_test_schema_descr(vec![( @@ -462,6 +474,8 @@ mod tests { Some(2), None, )]); + let expr = col("c1").gt(lit(ScalarValue::Decimal128(Some(500), 9, 2))); + let expr = logical2physical(&expr, &schema); let pruning_predicate = PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); let rgm1 = get_row_group_meta_data( @@ -503,10 +517,6 @@ mod tests { // The c1 type is decimal(9,0) in the parquet file, and the type of scalar is decimal(5,2). // We should convert all type to the coercion type, which is decimal(11,2) // The decimal of arrow is decimal(5,2), the decimal of parquet is decimal(9,0) - let expr = cast(col("c1"), DataType::Decimal128(11, 2)).gt(cast( - lit(ScalarValue::Decimal128(Some(500), 5, 2)), - Decimal128(11, 2), - )); let schema = Schema::new(vec![Field::new("c1", DataType::Decimal128(9, 0), false)]); let schema_descr = get_test_schema_descr(vec![( @@ -520,6 +530,11 @@ mod tests { Some(0), None, )]); + let expr = cast(col("c1"), DataType::Decimal128(11, 2)).gt(cast( + lit(ScalarValue::Decimal128(Some(500), 5, 2)), + Decimal128(11, 2), + )); + let expr = logical2physical(&expr, &schema); let pruning_predicate = PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); let rgm1 = get_row_group_meta_data( @@ -564,7 +579,6 @@ mod tests { ); // INT64: c1 < 5, the c1 is decimal(18,2) - let expr = col("c1").lt(lit(ScalarValue::Decimal128(Some(500), 18, 2))); let schema = Schema::new(vec![Field::new("c1", DataType::Decimal128(18, 2), false)]); let schema_descr = get_test_schema_descr(vec![( @@ -578,6 +592,8 @@ mod tests { Some(2), None, )]); + let expr = col("c1").lt(lit(ScalarValue::Decimal128(Some(500), 18, 2))); + let expr = logical2physical(&expr, &schema); let pruning_predicate = PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); let rgm1 = get_row_group_meta_data( @@ -616,9 +632,6 @@ mod tests { // the type of parquet is decimal(18,2) let schema = Schema::new(vec![Field::new("c1", DataType::Decimal128(18, 2), false)]); - // cast the type of c1 to decimal(28,3) - let left = cast(col("c1"), DataType::Decimal128(28, 3)); - let expr = left.eq(lit(ScalarValue::Decimal128(Some(100000), 28, 3))); let schema_descr = get_test_schema_descr(vec![( "c1", PhysicalType::FIXED_LEN_BYTE_ARRAY, @@ -630,6 +643,10 @@ mod tests { Some(2), Some(16), )]); + // cast the type of c1 to decimal(28,3) + let left = cast(col("c1"), DataType::Decimal128(28, 3)); + let expr = left.eq(lit(ScalarValue::Decimal128(Some(100000), 28, 3))); + let expr = logical2physical(&expr, &schema); let pruning_predicate = PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); // we must use the big-endian when encode the i128 to bytes or vec[u8]. @@ -687,9 +704,6 @@ mod tests { // the type of parquet is decimal(18,2) let schema = Schema::new(vec![Field::new("c1", DataType::Decimal128(18, 2), false)]); - // cast the type of c1 to decimal(28,3) - let left = cast(col("c1"), DataType::Decimal128(28, 3)); - let expr = left.eq(lit(ScalarValue::Decimal128(Some(100000), 28, 3))); let schema_descr = get_test_schema_descr(vec![( "c1", PhysicalType::BYTE_ARRAY, @@ -701,6 +715,10 @@ mod tests { Some(2), Some(16), )]); + // cast the type of c1 to decimal(28,3) + let left = cast(col("c1"), DataType::Decimal128(28, 3)); + let expr = left.eq(lit(ScalarValue::Decimal128(Some(100000), 28, 3))); + let expr = logical2physical(&expr, &schema); let pruning_predicate = PruningPredicate::try_new(expr, Arc::new(schema)).unwrap(); // we must use the big-endian when encode the i128 to bytes or vec[u8]. @@ -821,4 +839,10 @@ mod tests { let metrics = Arc::new(ExecutionPlanMetricsSet::new()); ParquetFileMetrics::new(0, "file.parquet", &metrics) } + + fn logical2physical(expr: &Expr, schema: &Schema) -> Arc { + let df_schema = schema.clone().to_dfschema().unwrap(); + let execution_props = ExecutionProps::new(); + create_physical_expr(expr, &df_schema, schema, &execution_props).unwrap() + } } diff --git a/datafusion/core/tests/parquet/page_pruning.rs b/datafusion/core/tests/parquet/page_pruning.rs index 61e74e80d047..baf9d2d36a17 100644 --- a/datafusion/core/tests/parquet/page_pruning.rs +++ b/datafusion/core/tests/parquet/page_pruning.rs @@ -25,8 +25,10 @@ use datafusion::execution::context::SessionState; use datafusion::physical_plan::file_format::{FileScanConfig, ParquetExec}; use datafusion::physical_plan::ExecutionPlan; use datafusion::prelude::SessionContext; -use datafusion_common::{ScalarValue, Statistics}; +use datafusion_common::{ScalarValue, Statistics, ToDFSchema}; use datafusion_expr::{col, lit, Expr}; +use datafusion_physical_expr::create_physical_expr; +use datafusion_physical_expr::execution_props::ExecutionProps; use object_store::path::Path; use object_store::ObjectMeta; use tokio_stream::StreamExt; @@ -58,6 +60,11 @@ async fn get_parquet_exec(state: &SessionState, filter: Expr) -> ParquetExec { extensions: None, }; + let df_schema = schema.clone().to_dfschema().unwrap(); + let execution_props = ExecutionProps::new(); + let predicate = + create_physical_expr(&filter, &df_schema, &schema, &execution_props).unwrap(); + let parquet_exec = ParquetExec::new( FileScanConfig { object_store_url, @@ -71,7 +78,7 @@ async fn get_parquet_exec(state: &SessionState, filter: Expr) -> ParquetExec { output_ordering: None, infinite_source: false, }, - Some(filter), + Some(predicate), None, ); parquet_exec.with_enable_page_index(true) diff --git a/datafusion/core/tests/row.rs b/datafusion/core/tests/row.rs index c04deb92b92a..5eeb237e187e 100644 --- a/datafusion/core/tests/row.rs +++ b/datafusion/core/tests/row.rs @@ -115,7 +115,7 @@ async fn get_exec( output_ordering: None, infinite_source: false, }, - &[], + None, ) .await?; Ok(exec) diff --git a/datafusion/physical-expr/src/lib.rs b/datafusion/physical-expr/src/lib.rs index c9658a048ca8..7a2ea6872fa7 100644 --- a/datafusion/physical-expr/src/lib.rs +++ b/datafusion/physical-expr/src/lib.rs @@ -49,7 +49,7 @@ pub use aggregate::AggregateExpr; pub use datafusion_common::from_slice; pub use equivalence::EquivalenceProperties; pub use equivalence::EquivalentClass; -pub use physical_expr::{AnalysisContext, ExprBoundaries, PhysicalExpr}; +pub use physical_expr::{AnalysisContext, ExprBoundaries, PhysicalExpr, PhysicalExprRef}; pub use planner::create_physical_expr; pub use scalar_function::ScalarFunctionExpr; pub use sort_expr::PhysicalSortExpr; diff --git a/datafusion/physical-expr/src/physical_expr.rs b/datafusion/physical-expr/src/physical_expr.rs index 459ce8cd7b15..f4e9593c8264 100644 --- a/datafusion/physical-expr/src/physical_expr.rs +++ b/datafusion/physical-expr/src/physical_expr.rs @@ -83,6 +83,9 @@ pub trait PhysicalExpr: Send + Sync + Display + Debug + PartialEq { } } +/// Shared [`PhysicalExpr`]. +pub type PhysicalExprRef = Arc; + /// The shared context used during the analysis of an expression. Includes /// the boundaries for all known columns. #[derive(Clone, Debug, PartialEq)] diff --git a/datafusion/physical-expr/src/utils.rs b/datafusion/physical-expr/src/utils.rs index d6d5054ffef0..612d0e0b8ea0 100644 --- a/datafusion/physical-expr/src/utils.rs +++ b/datafusion/physical-expr/src/utils.rs @@ -19,14 +19,18 @@ use crate::equivalence::EquivalentClass; use crate::expressions::BinaryExpr; use crate::expressions::Column; use crate::expressions::UnKnownColumn; +use crate::rewrite::RewriteRecursion; use crate::rewrite::TreeNodeRewritable; +use crate::rewrite::TreeNodeRewriter; use crate::PhysicalSortExpr; use crate::{EquivalenceProperties, PhysicalExpr}; +use datafusion_common::DataFusionError; use datafusion_expr::Operator; use arrow::datatypes::SchemaRef; use std::collections::HashMap; +use std::collections::HashSet; use std::sync::Arc; /// Compare the two expr lists are equal no matter the order. @@ -235,6 +239,80 @@ pub fn ordering_satisfy_concrete EquivalenceProperties>( } } +/// Extract referenced [`Column`]s within a [`PhysicalExpr`]. +/// +/// This works recursively. +pub fn get_phys_expr_columns(pred: &Arc) -> HashSet { + let mut rewriter = ColumnCollector::default(); + pred.clone() + .transform_using(&mut rewriter) + .expect("never fail"); + rewriter.cols +} + +#[derive(Debug, Default)] +struct ColumnCollector { + cols: HashSet, +} + +impl TreeNodeRewriter> for ColumnCollector { + fn pre_visit( + &mut self, + node: &Arc, + ) -> Result { + if let Some(column) = node.as_any().downcast_ref::() { + self.cols.insert(column.clone()); + } + Ok(RewriteRecursion::Continue) + } + + fn mutate( + &mut self, + expr: Arc, + ) -> Result, DataFusionError> { + Ok(expr) + } +} + +/// Re-assign column indices referenced in predicate according to given schema. +/// +/// This may be helpful when dealing with projections. +pub fn reassign_predicate_columns( + pred: Arc, + schema: &SchemaRef, + ignore_not_found: bool, +) -> Result, DataFusionError> { + let mut rewriter = ColumnAssigner { + schema: schema.clone(), + ignore_not_found, + }; + pred.clone().transform_using(&mut rewriter) +} + +#[derive(Debug)] +struct ColumnAssigner { + schema: SchemaRef, + ignore_not_found: bool, +} + +impl TreeNodeRewriter> for ColumnAssigner { + fn mutate( + &mut self, + expr: Arc, + ) -> Result, DataFusionError> { + if let Some(column) = expr.as_any().downcast_ref::() { + let index = match self.schema.index_of(column.name()) { + Ok(idx) => idx, + Err(_) if self.ignore_not_found => usize::MAX, + Err(e) => return Err(e.into()), + }; + return Ok(Arc::new(Column::new(column.name(), index))); + } + + Ok(expr) + } +} + #[cfg(test)] mod tests { diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index 103fe3face51..d0deb567fad4 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -1165,7 +1165,11 @@ message FileScanExecConf { message ParquetScanExecNode { FileScanExecConf base_conf = 1; - LogicalExprNode pruning_predicate = 2; + + // Was pruning predicate based on a logical expr. + reserved 2; + + PhysicalExprNode predicate = 3; } message CsvScanExecNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 335f6f1c59da..7246b5f2b2cf 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -11832,15 +11832,15 @@ impl serde::Serialize for ParquetScanExecNode { if self.base_conf.is_some() { len += 1; } - if self.pruning_predicate.is_some() { + if self.predicate.is_some() { len += 1; } let mut struct_ser = serializer.serialize_struct("datafusion.ParquetScanExecNode", len)?; if let Some(v) = self.base_conf.as_ref() { struct_ser.serialize_field("baseConf", v)?; } - if let Some(v) = self.pruning_predicate.as_ref() { - struct_ser.serialize_field("pruningPredicate", v)?; + if let Some(v) = self.predicate.as_ref() { + struct_ser.serialize_field("predicate", v)?; } struct_ser.end() } @@ -11854,14 +11854,13 @@ impl<'de> serde::Deserialize<'de> for ParquetScanExecNode { const FIELDS: &[&str] = &[ "base_conf", "baseConf", - "pruning_predicate", - "pruningPredicate", + "predicate", ]; #[allow(clippy::enum_variant_names)] enum GeneratedField { BaseConf, - PruningPredicate, + Predicate, } impl<'de> serde::Deserialize<'de> for GeneratedField { fn deserialize(deserializer: D) -> std::result::Result @@ -11884,7 +11883,7 @@ impl<'de> serde::Deserialize<'de> for ParquetScanExecNode { { match value { "baseConf" | "base_conf" => Ok(GeneratedField::BaseConf), - "pruningPredicate" | "pruning_predicate" => Ok(GeneratedField::PruningPredicate), + "predicate" => Ok(GeneratedField::Predicate), _ => Err(serde::de::Error::unknown_field(value, FIELDS)), } } @@ -11905,7 +11904,7 @@ impl<'de> serde::Deserialize<'de> for ParquetScanExecNode { V: serde::de::MapAccess<'de>, { let mut base_conf__ = None; - let mut pruning_predicate__ = None; + let mut predicate__ = None; while let Some(k) = map.next_key()? { match k { GeneratedField::BaseConf => { @@ -11914,17 +11913,17 @@ impl<'de> serde::Deserialize<'de> for ParquetScanExecNode { } base_conf__ = map.next_value()?; } - GeneratedField::PruningPredicate => { - if pruning_predicate__.is_some() { - return Err(serde::de::Error::duplicate_field("pruningPredicate")); + GeneratedField::Predicate => { + if predicate__.is_some() { + return Err(serde::de::Error::duplicate_field("predicate")); } - pruning_predicate__ = map.next_value()?; + predicate__ = map.next_value()?; } } } Ok(ParquetScanExecNode { base_conf: base_conf__, - pruning_predicate: pruning_predicate__, + predicate: predicate__, }) } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 029380a99ed5..da95fd558fef 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -1680,8 +1680,8 @@ pub struct FileScanExecConf { pub struct ParquetScanExecNode { #[prost(message, optional, tag = "1")] pub base_conf: ::core::option::Option, - #[prost(message, optional, tag = "2")] - pub pruning_predicate: ::core::option::Option, + #[prost(message, optional, tag = "3")] + pub predicate: ::core::option::Option, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] diff --git a/datafusion/proto/src/physical_plan/mod.rs b/datafusion/proto/src/physical_plan/mod.rs index 0898ec416791..8c2ce822f174 100644 --- a/datafusion/proto/src/physical_plan/mod.rs +++ b/datafusion/proto/src/physical_plan/mod.rs @@ -53,7 +53,6 @@ use prost::Message; use crate::common::proto_error; use crate::common::{csv_delimiter_to_string, str_to_byte}; -use crate::logical_plan; use crate::physical_plan::from_proto::{ parse_physical_expr, parse_protobuf_file_scan_config, }; @@ -156,19 +155,22 @@ impl AsExecutionPlan for PhysicalPlanNode { FileCompressionType::UNCOMPRESSED, ))), PhysicalPlanType::ParquetScan(scan) => { + let base_config = parse_protobuf_file_scan_config( + scan.base_conf.as_ref().unwrap(), + registry, + )?; let predicate = scan - .pruning_predicate + .predicate .as_ref() - .map(|expr| logical_plan::from_proto::parse_expr(expr, registry)) + .map(|expr| { + parse_physical_expr( + expr, + registry, + base_config.file_schema.as_ref(), + ) + }) .transpose()?; - Ok(Arc::new(ParquetExec::new( - parse_protobuf_file_scan_config( - scan.base_conf.as_ref().unwrap(), - registry, - )?, - predicate, - None, - ))) + Ok(Arc::new(ParquetExec::new(base_config, predicate, None))) } PhysicalPlanType::AvroScan(scan) => { Ok(Arc::new(AvroExec::new(parse_protobuf_file_scan_config( @@ -956,13 +958,13 @@ impl AsExecutionPlan for PhysicalPlanNode { } else if let Some(exec) = plan.downcast_ref::() { let pruning_expr = exec .pruning_predicate() - .map(|pred| pred.logical_expr().try_into()) + .map(|pred| pred.orig_expr().clone().try_into()) .transpose()?; Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::ParquetScan( protobuf::ParquetScanExecNode { base_conf: Some(exec.base_config().try_into()?), - pruning_predicate: pruning_expr, + predicate: pruning_expr, }, )), }) @@ -1218,7 +1220,7 @@ mod roundtrip_tests { use datafusion::physical_expr::expressions::DateTimeIntervalExpr; use datafusion::physical_expr::ScalarFunctionExpr; use datafusion::physical_plan::aggregates::PhysicalGroupBy; - use datafusion::physical_plan::expressions::{like, GetIndexedFieldExpr}; + use datafusion::physical_plan::expressions::{like, BinaryExpr, GetIndexedFieldExpr}; use datafusion::physical_plan::functions; use datafusion::physical_plan::functions::make_scalar_function; use datafusion::physical_plan::projection::ProjectionExec; @@ -1510,7 +1512,11 @@ mod roundtrip_tests { infinite_source: false, }; - let predicate = datafusion::prelude::col("col").eq(datafusion::prelude::lit("1")); + let predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("col", 1)), + Operator::Eq, + lit("1"), + )); roundtrip_test(Arc::new(ParquetExec::new( scan_config, Some(predicate), diff --git a/parquet-test-utils/src/lib.rs b/parquet-test-utils/src/lib.rs index e1b3c5c18a77..59c3024d78b0 100644 --- a/parquet-test-utils/src/lib.rs +++ b/parquet-test-utils/src/lib.rs @@ -158,7 +158,11 @@ impl TestParquetFile { &ExecutionProps::default(), )?; - let parquet_exec = Arc::new(ParquetExec::new(scan_config, Some(filter), None)); + let parquet_exec = Arc::new(ParquetExec::new( + scan_config, + Some(physical_filter_expr.clone()), + None, + )); let exec = Arc::new(FilterExec::try_new(physical_filter_expr, parquet_exec)?);