From 5a28b029631569d103d7a3f0c51c3252de03b8c1 Mon Sep 17 00:00:00 2001 From: Dan King Date: Thu, 21 Nov 2024 16:41:06 -0500 Subject: [PATCH] feat: support Identity in pruner (#1441) --- vortex-expr/src/identity.rs | 9 +- vortex-file/src/pruning.rs | 253 +++++++++++++++++++++++++++--------- 2 files changed, 203 insertions(+), 59 deletions(-) diff --git a/vortex-expr/src/identity.rs b/vortex-expr/src/identity.rs index 3d9a89b247..e0be36ea6a 100644 --- a/vortex-expr/src/identity.rs +++ b/vortex-expr/src/identity.rs @@ -1,14 +1,21 @@ use std::any::Any; use std::fmt::Display; +use std::sync::Arc; use vortex_array::ArrayData; use vortex_error::VortexResult; -use crate::{unbox_any, VortexExpr}; +use crate::{unbox_any, ExprRef, VortexExpr}; #[derive(Debug, Eq, PartialEq)] pub struct Identity; +impl Identity { + pub fn new_expr() -> ExprRef { + Arc::new(Identity) + } +} + impl Display for Identity { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { write!(f, "[]") diff --git a/vortex-file/src/pruning.rs b/vortex-file/src/pruning.rs index b022149331..b67c3c53e2 100644 --- a/vortex-file/src/pruning.rs +++ b/vortex-file/src/pruning.rs @@ -12,7 +12,7 @@ use vortex_array::ArrayData; use vortex_dtype::field::Field; use vortex_dtype::Nullability; use vortex_error::{VortexExpect as _, VortexResult}; -use vortex_expr::{BinaryExpr, Column, ExprRef, Literal, Not, Operator}; +use vortex_expr::{BinaryExpr, Column, ExprRef, Identity, Literal, Not, Operator}; use vortex_scalar::Scalar; use crate::RowFilter; @@ -70,7 +70,7 @@ impl Relation { #[derive(Debug, Clone)] pub struct PruningPredicate { expr: ExprRef, - required_stats: Relation, + required_stats: Relation, } impl Display for PruningPredicate { @@ -116,7 +116,7 @@ impl PruningPredicate { &self.expr } - pub fn required_stats(&self) -> &HashMap> { + pub fn required_stats(&self) -> &HashMap> { &self.required_stats.map } @@ -136,7 +136,7 @@ impl PruningPredicate { let required_stats = self .required_stats() .iter() - .flat_map(|(key, value)| value.iter().map(|stat| stat_column_name_string(key, *stat))) + .flat_map(|(key, value)| value.iter().map(|stat| key.stat_column_name_string(*stat))) .collect::>(); let missing_stats = required_stats.difference(&known_stats).collect::>(); @@ -180,24 +180,40 @@ fn convert_to_pruning_expression(expr: &ExprRef) -> PruningPredicateStats { } if let Some(col) = bexp.lhs().as_any().downcast_ref::() { - return PruningPredicateRewriter::try_new(col.field().clone(), bexp.op(), bexp.rhs()) - .and_then(PruningPredicateRewriter::rewrite) - .unwrap_or_else(not_prunable); + return PruningPredicateRewriter::rewrite_binary_op( + FieldOrIdentity::Field(col.field().clone()), + bexp.op(), + bexp.rhs(), + ); }; if let Some(col) = bexp.rhs().as_any().downcast_ref::() { - return PruningPredicateRewriter::try_new( - col.field().clone(), + return PruningPredicateRewriter::rewrite_binary_op( + FieldOrIdentity::Field(col.field().clone()), bexp.op().swap(), bexp.lhs(), - ) - .and_then(PruningPredicateRewriter::rewrite) - .unwrap_or_else(not_prunable); + ); } + + if bexp.lhs().as_any().downcast_ref::().is_some() { + return PruningPredicateRewriter::rewrite_binary_op( + FieldOrIdentity::Identity, + bexp.op(), + bexp.rhs(), + ); + }; + + if bexp.rhs().as_any().downcast_ref::().is_some() { + return PruningPredicateRewriter::rewrite_binary_op( + FieldOrIdentity::Identity, + bexp.op().swap(), + bexp.lhs(), + ); + }; } if let Some(RowFilter { conjunction }) = expr.as_any().downcast_ref::() { - let (rewritten_conjunction, refses): (Vec, Vec>) = + let (rewritten_conjunction, refses): (Vec, Vec>) = conjunction .iter() .map(convert_to_pruning_expression) @@ -233,21 +249,27 @@ fn convert_column_reference(expr: &ExprRef, invert: bool) -> PruningPredicateSta } struct PruningPredicateRewriter<'a> { - column: Field, + column: FieldOrIdentity, operator: Operator, other_exp: &'a ExprRef, - stats_to_fetch: Relation, + stats_to_fetch: Relation, } -type PruningPredicateStats = (ExprRef, Relation); +type PruningPredicateStats = (ExprRef, Relation); impl<'a> PruningPredicateRewriter<'a> { - pub fn try_new(column: Field, operator: Operator, other_exp: &'a ExprRef) -> Option { + pub fn try_new( + column: FieldOrIdentity, + operator: Operator, + other_exp: &'a ExprRef, + ) -> Option { // TODO(robert): Simplify expression to guarantee that each column is not compared to itself // For majority of cases self column references are likely not prunable - if other_exp.references().contains(&column) { - return None; - } + if let FieldOrIdentity::Field(field) = &column { + if other_exp.references().contains(field) { + return None; + } + }; Some(Self { column, @@ -257,8 +279,18 @@ impl<'a> PruningPredicateRewriter<'a> { }) } + pub fn rewrite_binary_op( + column: FieldOrIdentity, + operator: Operator, + other_exp: &'a ExprRef, + ) -> PruningPredicateStats { + Self::try_new(column, operator, other_exp) + .and_then(Self::rewrite) + .unwrap_or_else(not_prunable) + } + fn add_stat_reference(&mut self, stat: Stat) -> Field { - let new_field = stat_column_name(&self.column, stat); + let new_field = self.column.stat_column_field(stat); self.stats_to_fetch.insert(self.column.clone(), stat); new_field } @@ -327,11 +359,11 @@ impl<'a> PruningPredicateRewriter<'a> { fn replace_column_with_stat( expr: &ExprRef, stat: Stat, - stats_to_fetch: &mut Relation, + stats_to_fetch: &mut Relation, ) -> Option { if let Some(col) = expr.as_any().downcast_ref::() { - let new_field = stat_column_name(col.field(), stat); - stats_to_fetch.insert(col.field().clone(), stat); + let new_field = stat_column_field(col.field(), stat); + stats_to_fetch.insert(FieldOrIdentity::Field(col.field().clone()), stat); return Some(Column::new_expr(new_field)); } @@ -356,7 +388,13 @@ fn replace_column_with_stat( None } -pub(crate) fn stat_column_name(field: &Field, stat: Stat) -> Field { +#[derive(Debug, Clone, Hash, PartialEq, Eq)] +pub enum FieldOrIdentity { + Field(Field), + Identity, +} + +pub(crate) fn stat_column_field(field: &Field, stat: Stat) -> Field { Field::Name(stat_column_name_string(field, stat)) } @@ -367,15 +405,48 @@ pub(crate) fn stat_column_name_string(field: &Field, stat: Stat) -> String { } } +impl FieldOrIdentity { + pub(crate) fn stat_column_field(&self, stat: Stat) -> Field { + Field::Name(self.stat_column_name_string(stat)) + } + + pub(crate) fn stat_column_name_string(&self, stat: Stat) -> String { + match self { + FieldOrIdentity::Field(field) => stat_column_name_string(field, stat), + FieldOrIdentity::Identity => stat.to_string(), + } + } +} + +impl Display for FieldOrIdentity { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + FieldOrIdentity::Field(field) => write!(f, "{}", field), + FieldOrIdentity::Identity => write!(f, "$[]"), + } + } +} + +impl From for FieldOrIdentity +where + Field: From, +{ + fn from(value: T) -> Self { + FieldOrIdentity::Field(Field::from(value)) + } +} + #[cfg(test)] mod tests { use vortex_array::aliases::hash_map::HashMap; use vortex_array::aliases::hash_set::HashSet; use vortex_array::stats::Stat; use vortex_dtype::field::Field; - use vortex_expr::{BinaryExpr, Column, Literal, Not, Operator}; + use vortex_expr::{BinaryExpr, Column, Identity, Literal, Not, Operator}; - use crate::pruning::{convert_to_pruning_expression, stat_column_name, PruningPredicate}; + use crate::pruning::{ + convert_to_pruning_expression, stat_column_field, FieldOrIdentity, PruningPredicate, + }; #[test] pub fn pruning_equals() { @@ -389,11 +460,14 @@ mod tests { let (converted, refs) = convert_to_pruning_expression(&eq_expr); assert_eq!( refs.into_map(), - HashMap::from_iter([(column.clone(), HashSet::from_iter([Stat::Min, Stat::Max]))]) + HashMap::from_iter([( + FieldOrIdentity::Field(column.clone()), + HashSet::from_iter([Stat::Min, Stat::Max]) + )]) ); let expected_expr = BinaryExpr::new_expr( BinaryExpr::new_expr( - Column::new_expr(stat_column_name(&column, Stat::Min)), + Column::new_expr(stat_column_field(&column, Stat::Min)), Operator::Gt, literal_eq.clone(), ), @@ -401,7 +475,7 @@ mod tests { BinaryExpr::new_expr( literal_eq, Operator::Gt, - Column::new_expr(stat_column_name(&column, Stat::Max)), + Column::new_expr(stat_column_field(&column, Stat::Max)), ), ); assert_eq!(*converted, *expected_expr.as_any()); @@ -421,24 +495,27 @@ mod tests { assert_eq!( refs.into_map(), HashMap::from_iter([ - (column.clone(), HashSet::from_iter([Stat::Min, Stat::Max])), ( - other_col.clone(), + FieldOrIdentity::Field(column.clone()), + HashSet::from_iter([Stat::Min, Stat::Max]) + ), + ( + FieldOrIdentity::Field(other_col.clone()), HashSet::from_iter([Stat::Max, Stat::Min]) ) ]) ); let expected_expr = BinaryExpr::new_expr( BinaryExpr::new_expr( - Column::new_expr(stat_column_name(&column, Stat::Min)), + Column::new_expr(stat_column_field(&column, Stat::Min)), Operator::Gt, - Column::new_expr(stat_column_name(&other_col, Stat::Max)), + Column::new_expr(stat_column_field(&other_col, Stat::Max)), ), Operator::Or, BinaryExpr::new_expr( - Column::new_expr(stat_column_name(&other_col, Stat::Min)), + Column::new_expr(stat_column_field(&other_col, Stat::Min)), Operator::Gt, - Column::new_expr(stat_column_name(&column, Stat::Max)), + Column::new_expr(stat_column_field(&column, Stat::Max)), ), ); assert_eq!(*converted, *expected_expr.as_any()); @@ -458,9 +535,12 @@ mod tests { assert_eq!( refs.into_map(), HashMap::from_iter([ - (column.clone(), HashSet::from_iter([Stat::Min, Stat::Max])), ( - other_col.clone(), + FieldOrIdentity::Field(column.clone()), + HashSet::from_iter([Stat::Min, Stat::Max]) + ), + ( + FieldOrIdentity::Field(other_col.clone()), HashSet::from_iter([Stat::Max, Stat::Min]) ) ]) @@ -468,22 +548,22 @@ mod tests { let expected_expr = BinaryExpr::new_expr( BinaryExpr::new_expr( BinaryExpr::new_expr( - Column::new_expr(stat_column_name(&column, Stat::Min)), + Column::new_expr(stat_column_field(&column, Stat::Min)), Operator::Eq, - Column::new_expr(stat_column_name(&column, Stat::Max)), + Column::new_expr(stat_column_field(&column, Stat::Max)), ), Operator::And, BinaryExpr::new_expr( - Column::new_expr(stat_column_name(&other_col, Stat::Min)), + Column::new_expr(stat_column_field(&other_col, Stat::Min)), Operator::Eq, - Column::new_expr(stat_column_name(&other_col, Stat::Max)), + Column::new_expr(stat_column_field(&other_col, Stat::Max)), ), ), Operator::And, BinaryExpr::new_expr( - Column::new_expr(stat_column_name(&column, Stat::Min)), + Column::new_expr(stat_column_field(&column, Stat::Min)), Operator::Eq, - Column::new_expr(stat_column_name(&other_col, Stat::Min)), + Column::new_expr(stat_column_field(&other_col, Stat::Min)), ), ); @@ -505,14 +585,20 @@ mod tests { assert_eq!( refs.into_map(), HashMap::from_iter([ - (column.clone(), HashSet::from_iter([Stat::Max])), - (other_col.clone(), HashSet::from_iter([Stat::Min])) + ( + FieldOrIdentity::Field(column.clone()), + HashSet::from_iter([Stat::Max]) + ), + ( + FieldOrIdentity::Field(other_col.clone()), + HashSet::from_iter([Stat::Min]) + ) ]) ); let expected_expr = BinaryExpr::new_expr( - Column::new_expr(stat_column_name(&column, Stat::Max)), + Column::new_expr(stat_column_field(&column, Stat::Max)), Operator::Lte, - Column::new_expr(stat_column_name(&other_col, Stat::Min)), + Column::new_expr(stat_column_field(&other_col, Stat::Min)), ); assert_eq!(*converted, *expected_expr.as_any()); } @@ -530,10 +616,13 @@ mod tests { let (converted, refs) = convert_to_pruning_expression(¬_eq_expr); assert_eq!( refs.into_map(), - HashMap::from_iter([(column.clone(), HashSet::from_iter([Stat::Max])),]) + HashMap::from_iter([( + FieldOrIdentity::Field(column.clone()), + HashSet::from_iter([Stat::Max]) + ),]) ); let expected_expr = BinaryExpr::new_expr( - Column::new_expr(stat_column_name(&column, Stat::Max)), + Column::new_expr(stat_column_field(&column, Stat::Max)), Operator::Lte, other_col.clone(), ); @@ -555,14 +644,20 @@ mod tests { assert_eq!( refs.into_map(), HashMap::from_iter([ - (column.clone(), HashSet::from_iter([Stat::Min])), - (other_col.clone(), HashSet::from_iter([Stat::Max])) + ( + FieldOrIdentity::Field(column.clone()), + HashSet::from_iter([Stat::Min]) + ), + ( + FieldOrIdentity::Field(other_col.clone()), + HashSet::from_iter([Stat::Max]) + ) ]) ); let expected_expr = BinaryExpr::new_expr( - Column::new_expr(stat_column_name(&column, Stat::Min)), + Column::new_expr(stat_column_field(&column, Stat::Min)), Operator::Gte, - Column::new_expr(stat_column_name(&other_col, Stat::Max)), + Column::new_expr(stat_column_field(&other_col, Stat::Max)), ); assert_eq!(*converted, *expected_expr.as_any()); } @@ -580,10 +675,13 @@ mod tests { let (converted, refs) = convert_to_pruning_expression(¬_eq_expr); assert_eq!( refs.into_map(), - HashMap::from_iter([(column.clone(), HashSet::from_iter([Stat::Min]))]) + HashMap::from_iter([( + FieldOrIdentity::Field(column.clone()), + HashSet::from_iter([Stat::Min]) + )]) ); let expected_expr = BinaryExpr::new_expr( - Column::new_expr(stat_column_name(&column, Stat::Min)), + Column::new_expr(stat_column_field(&column, Stat::Min)), Operator::Gte, other_col.clone(), ); @@ -621,7 +719,10 @@ mod tests { BinaryExpr::new_expr(column, Operator::Gt, Literal::new_expr(50.into())), ); - let expected = HashMap::from([(Field::from("a"), HashSet::from([Stat::Min, Stat::Max]))]); + let expected = HashMap::from([( + FieldOrIdentity::from("a"), + HashSet::from([Stat::Min, Stat::Max]), + )]); assert_eq!( PruningPredicate::try_new(&expr).unwrap().required_stats(), @@ -638,11 +739,47 @@ mod tests { BinaryExpr::new_expr(column, Operator::Lt, Literal::new_expr(10.into())), ); - let expected = HashMap::from([(Field::from("a"), HashSet::from([Stat::Min, Stat::Max]))]); + let expected = HashMap::from([( + FieldOrIdentity::from("a"), + HashSet::from([Stat::Min, Stat::Max]), + )]); assert_eq!( PruningPredicate::try_new(&expr).unwrap().required_stats(), &expected ); } + + #[test] + fn pruning_identity() { + let column = Identity::new_expr(); + let expr = BinaryExpr::new_expr( + BinaryExpr::new_expr(column.clone(), Operator::Lt, Literal::new_expr(10.into())), + Operator::Or, + BinaryExpr::new_expr(column, Operator::Gt, Literal::new_expr(50.into())), + ); + + let expected = HashMap::from([( + FieldOrIdentity::Identity, + HashSet::from([Stat::Min, Stat::Max]), + )]); + + let predicate = PruningPredicate::try_new(&expr).unwrap(); + assert_eq!(predicate.required_stats(), &expected); + + let expected_expr = BinaryExpr::new_expr( + BinaryExpr::new_expr( + Column::new_expr(Field::Name("min".to_string())), + Operator::Gte, + Literal::new_expr(10.into()), + ), + Operator::Or, + BinaryExpr::new_expr( + Column::new_expr(Field::Name("max".to_string())), + Operator::Lte, + Literal::new_expr(50.into()), + ), + ); + assert_eq!(*predicate.expr().clone(), *expected_expr.as_any(),) + } }