Skip to content

Commit 636e5fe

Browse files
committed
fix COUNT(UInt8(1))
1 parent c3319da commit 636e5fe

File tree

2 files changed

+26
-9
lines changed

2 files changed

+26
-9
lines changed

datafusion/core/tests/sql/window.rs

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -686,14 +686,15 @@ async fn test_remove_unnecessary_sort_in_sub_query() -> Result<()> {
686686
" CoalescePartitionsExec",
687687
" AggregateExec: mode=Partial, gby=[], aggr=[COUNT(UInt8(1))]",
688688
" RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=8",
689-
" AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[]",
690-
" CoalesceBatchesExec: target_batch_size=4096",
691-
" RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 8), input_partitions=8",
692-
" AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[]",
693-
" ProjectionExec: expr=[c1@0 as c1]",
694-
" CoalesceBatchesExec: target_batch_size=4096",
695-
" FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434",
696-
" RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1",
689+
" ProjectionExec: expr=[c1@0 as c1]",
690+
" AggregateExec: mode=FinalPartitioned, gby=[c1@0 as c1], aggr=[COUNT(UInt8(1))]",
691+
" CoalesceBatchesExec: target_batch_size=4096",
692+
" RepartitionExec: partitioning=Hash([Column { name: \"c1\", index: 0 }], 8), input_partitions=8",
693+
" AggregateExec: mode=Partial, gby=[c1@0 as c1], aggr=[COUNT(UInt8(1))]",
694+
" ProjectionExec: expr=[c1@0 as c1]",
695+
" CoalesceBatchesExec: target_batch_size=4096",
696+
" FilterExec: c13@1 != C2GT5KVyOPZpgKVl110TyZO0NcJ434",
697+
" RepartitionExec: partitioning=RoundRobinBatch(8), input_partitions=1",
697698
]
698699
};
699700

datafusion/optimizer/src/push_down_projection.rs

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,11 @@ use crate::optimizer::ApplyOrder;
2323
use crate::push_down_filter::replace_cols_by_name;
2424
use crate::{OptimizerConfig, OptimizerRule};
2525
use arrow::error::Result as ArrowResult;
26+
use datafusion_common::ScalarValue::UInt8;
2627
use datafusion_common::{
2728
Column, DFField, DFSchema, DFSchemaRef, DataFusionError, Result, ToDFSchema,
2829
};
30+
use datafusion_expr::expr::AggregateFunction;
2931
use datafusion_expr::utils::exprlist_to_fields;
3032
use datafusion_expr::{
3133
logical_plan::{Aggregate, LogicalPlan, Projection, TableScan, Union},
@@ -262,7 +264,6 @@ impl OptimizerRule for PushDownProjection {
262264
generate_plan!(projection_is_empty, plan, new_alias)
263265
}
264266
LogicalPlan::Aggregate(agg) => {
265-
// remove any aggregate expression that is not required
266267
let mut required_columns = HashSet::new();
267268
exprlist_to_columns(&projection.expr, &mut required_columns)?;
268269
// Gather all columns needed for expressions in this Aggregate
@@ -274,6 +275,21 @@ impl OptimizerRule for PushDownProjection {
274275
}
275276
}
276277

278+
// if new_aggr_expr emtpy and aggr is COUNT(UInt8(1)), push it
279+
if new_aggr_expr.is_empty() && agg.aggr_expr.len() == 1 {
280+
if let Expr::AggregateFunction(AggregateFunction {
281+
fun, args, ..
282+
}) = &agg.aggr_expr[0]
283+
{
284+
if matches!(fun, datafusion_expr::AggregateFunction::Count)
285+
&& args.len() == 1
286+
&& args[0] == Expr::Literal(UInt8(Some(1)))
287+
{
288+
new_aggr_expr.push(agg.aggr_expr[0].clone());
289+
}
290+
}
291+
}
292+
277293
let new_agg = LogicalPlan::Aggregate(Aggregate::try_new(
278294
agg.input.clone(),
279295
agg.group_expr.clone(),

0 commit comments

Comments
 (0)