Skip to content

Commit

Permalink
Short term way to make AggregateStatistics still work when min/max …
Browse files Browse the repository at this point in the history
…is converted to udaf (#11261)

* impl the short term solution.

* add todos.
  • Loading branch information
Rachelint authored Jul 12, 2024
1 parent d5367f3 commit 02335eb
Showing 1 changed file with 85 additions and 51 deletions.
136 changes: 85 additions & 51 deletions datafusion/core/src/physical_optimizer/aggregate_statistics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,31 +140,29 @@ fn take_optimizable_column_and_table_count(
stats: &Statistics,
) -> Option<(ScalarValue, String)> {
let col_stats = &stats.column_statistics;
if let Some(agg_expr) = agg_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
if agg_expr.fun().name() == "count" && !agg_expr.is_distinct() {
if let Precision::Exact(num_rows) = stats.num_rows {
let exprs = agg_expr.expressions();
if exprs.len() == 1 {
// TODO optimize with exprs other than Column
if let Some(col_expr) =
exprs[0].as_any().downcast_ref::<expressions::Column>()
{
let current_val = &col_stats[col_expr.index()].null_count;
if let &Precision::Exact(val) = current_val {
return Some((
ScalarValue::Int64(Some((num_rows - val) as i64)),
agg_expr.name().to_string(),
));
}
} else if let Some(lit_expr) =
exprs[0].as_any().downcast_ref::<expressions::Literal>()
{
if lit_expr.value() == &COUNT_STAR_EXPANSION {
return Some((
ScalarValue::Int64(Some(num_rows as i64)),
agg_expr.name().to_string(),
));
}
if is_non_distinct_count(agg_expr) {
if let Precision::Exact(num_rows) = stats.num_rows {
let exprs = agg_expr.expressions();
if exprs.len() == 1 {
// TODO optimize with exprs other than Column
if let Some(col_expr) =
exprs[0].as_any().downcast_ref::<expressions::Column>()
{
let current_val = &col_stats[col_expr.index()].null_count;
if let &Precision::Exact(val) = current_val {
return Some((
ScalarValue::Int64(Some((num_rows - val) as i64)),
agg_expr.name().to_string(),
));
}
} else if let Some(lit_expr) =
exprs[0].as_any().downcast_ref::<expressions::Literal>()
{
if lit_expr.value() == &COUNT_STAR_EXPANSION {
return Some((
ScalarValue::Int64(Some(num_rows as i64)),
agg_expr.name().to_string(),
));
}
}
}
Expand All @@ -182,34 +180,30 @@ fn take_optimizable_min(
match *num_rows {
0 => {
// MIN/MAX with 0 rows is always null
if let Some(casted_expr) =
agg_expr.as_any().downcast_ref::<expressions::Min>()
{
if is_min(agg_expr) {
if let Ok(min_data_type) =
ScalarValue::try_from(casted_expr.field().unwrap().data_type())
ScalarValue::try_from(agg_expr.field().unwrap().data_type())
{
return Some((min_data_type, casted_expr.name().to_string()));
return Some((min_data_type, agg_expr.name().to_string()));
}
}
}
value if value > 0 => {
let col_stats = &stats.column_statistics;
if let Some(casted_expr) =
agg_expr.as_any().downcast_ref::<expressions::Min>()
{
if casted_expr.expressions().len() == 1 {
if is_min(agg_expr) {
let exprs = agg_expr.expressions();
if exprs.len() == 1 {
// TODO optimize with exprs other than Column
if let Some(col_expr) = casted_expr.expressions()[0]
.as_any()
.downcast_ref::<expressions::Column>()
if let Some(col_expr) =
exprs[0].as_any().downcast_ref::<expressions::Column>()
{
if let Precision::Exact(val) =
&col_stats[col_expr.index()].min_value
{
if !val.is_null() {
return Some((
val.clone(),
casted_expr.name().to_string(),
agg_expr.name().to_string(),
));
}
}
Expand All @@ -232,34 +226,30 @@ fn take_optimizable_max(
match *num_rows {
0 => {
// MIN/MAX with 0 rows is always null
if let Some(casted_expr) =
agg_expr.as_any().downcast_ref::<expressions::Max>()
{
if is_max(agg_expr) {
if let Ok(max_data_type) =
ScalarValue::try_from(casted_expr.field().unwrap().data_type())
ScalarValue::try_from(agg_expr.field().unwrap().data_type())
{
return Some((max_data_type, casted_expr.name().to_string()));
return Some((max_data_type, agg_expr.name().to_string()));
}
}
}
value if value > 0 => {
let col_stats = &stats.column_statistics;
if let Some(casted_expr) =
agg_expr.as_any().downcast_ref::<expressions::Max>()
{
if casted_expr.expressions().len() == 1 {
if is_max(agg_expr) {
let exprs = agg_expr.expressions();
if exprs.len() == 1 {
// TODO optimize with exprs other than Column
if let Some(col_expr) = casted_expr.expressions()[0]
.as_any()
.downcast_ref::<expressions::Column>()
if let Some(col_expr) =
exprs[0].as_any().downcast_ref::<expressions::Column>()
{
if let Precision::Exact(val) =
&col_stats[col_expr.index()].max_value
{
if !val.is_null() {
return Some((
val.clone(),
casted_expr.name().to_string(),
agg_expr.name().to_string(),
));
}
}
Expand All @@ -273,6 +263,50 @@ fn take_optimizable_max(
None
}

// TODO: Move this check into AggregateUDFImpl
// https://github.com/apache/datafusion/issues/11153
/// Returns `true` when `agg_expr` downcasts to an [`AggregateFunctionExpr`]
/// whose function is named `"count"` and which is not marked distinct.
///
/// Any expression that is not an [`AggregateFunctionExpr`] yields `false`.
fn is_non_distinct_count(agg_expr: &dyn AggregateExpr) -> bool {
    match agg_expr.as_any().downcast_ref::<AggregateFunctionExpr>() {
        // The name is compared first; distinctness is only consulted for "count".
        Some(udaf_expr) => udaf_expr.fun().name() == "count" && !udaf_expr.is_distinct(),
        None => false,
    }
}

// TODO: Move this check into AggregateUDFImpl
// https://github.com/apache/datafusion/issues/11153
/// Returns `true` when `agg_expr` is a `min` aggregate in either of the two
/// forms recognised here:
///
/// * the concrete `expressions::Min` type, or
/// * an [`AggregateFunctionExpr`] whose function is named `"min"`.
fn is_min(agg_expr: &dyn AggregateExpr) -> bool {
    let any_expr = agg_expr.as_any();

    // The concrete type is checked first; the name-based check only runs
    // when that fails.
    any_expr.is::<expressions::Min>()
        || any_expr
            .downcast_ref::<AggregateFunctionExpr>()
            .map_or(false, |udaf_expr| udaf_expr.fun().name() == "min")
}

// TODO: Move this check into AggregateUDFImpl
// https://github.com/apache/datafusion/issues/11153
/// Returns `true` when `agg_expr` is a `max` aggregate in either of the two
/// forms recognised here:
///
/// * the concrete `expressions::Max` type, or
/// * an [`AggregateFunctionExpr`] whose function is named `"max"`.
fn is_max(agg_expr: &dyn AggregateExpr) -> bool {
    let any_expr = agg_expr.as_any();

    // The concrete type is checked first; the name-based check only runs
    // when that fails.
    any_expr.is::<expressions::Max>()
        || any_expr
            .downcast_ref::<AggregateFunctionExpr>()
            .map_or(false, |udaf_expr| udaf_expr.fun().name() == "max")
}

#[cfg(test)]
pub(crate) mod tests {
use super::*;
Expand Down

0 comments on commit 02335eb

Please sign in to comment.