Skip to content

Commit

Permalink
refactor: unify replace count(*) analyzer by removing it in sql crate (#6660)
Browse files Browse the repository at this point in the history

refactor: unify replace count(*) analyzer by removing it in sql crate
fix: CountWildcardRule ignore Expr::Alias
  • Loading branch information
jackwener authored Jun 14, 2023
1 parent b586d4e commit 6194d58
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -565,4 +565,3 @@ SELECT sqrt(column1),sqrt(column2),sqrt(column3),sqrt(column4),sqrt(column5),sqr

statement ok
drop table t

7 changes: 7 additions & 0 deletions datafusion/optimizer/src/analyzer/count_wildcard_rule.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,13 @@ impl TreeNodeRewriter for CountWildcardRewriter {

fn mutate(&mut self, old_expr: Expr) -> Result<Expr> {
let new_expr = match old_expr.clone() {
Expr::Alias(expr, alias) if alias.contains(COUNT_STAR) => Expr::Alias(
expr,
alias.replace(
COUNT_STAR,
count(lit(COUNT_STAR_EXPANSION)).to_string().as_str(),
),
),
Expr::Column(Column { name, relation }) if name.contains(COUNT_STAR) => {
Expr::Column(Column {
name: name.replace(
Expand Down
37 changes: 4 additions & 33 deletions datafusion/sql/src/expr/function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ use crate::planner::{ContextProvider, PlannerContext, SqlToRel};
use datafusion_common::{DFSchema, DataFusionError, Result};
use datafusion_expr::expr::{ScalarFunction, ScalarUDF};
use datafusion_expr::function::suggest_valid_function;
use datafusion_expr::utils::COUNT_STAR_EXPANSION;
use datafusion_expr::window_frame::regularize;
use datafusion_expr::{
expr, window_function, AggregateFunction, BuiltinScalarFunction, Expr, WindowFrame,
Expand Down Expand Up @@ -96,8 +95,7 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
if let Ok(fun) = self.find_window_func(&name) {
let expr = match fun {
WindowFunction::AggregateFunction(aggregate_fun) => {
let (aggregate_fun, args) = self.aggregate_fn_to_expr(
aggregate_fun,
let args = self.function_args_to_expr(
function.args,
schema,
planner_context,
Expand Down Expand Up @@ -135,12 +133,9 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
planner_context,
)?;
let order_by = (!order_by.is_empty()).then_some(order_by);
let (fun, args) = self.aggregate_fn_to_expr(
fun,
function.args,
schema,
planner_context,
)?;
let args =
self.function_args_to_expr(function.args, schema, planner_context)?;

return Ok(Expr::AggregateFunction(expr::AggregateFunction::new(
fun, args, distinct, None, order_by,
)));
Expand Down Expand Up @@ -228,28 +223,4 @@ impl<'a, S: ContextProvider> SqlToRel<'a, S> {
.map(|a| self.sql_fn_arg_to_logical_expr(a, schema, planner_context))
.collect::<Result<Vec<Expr>>>()
}

/// Converts the SQL arguments of an aggregate call into logical expressions and
/// returns them together with the (unchanged) aggregate function.
///
/// For `AggregateFunction::Count`, an unnamed wildcard argument (`COUNT(*)`) is
/// replaced by a literal built from `COUNT_STAR_EXPANSION`; every other argument
/// is converted with `sql_fn_arg_to_logical_expr`. All other aggregate functions
/// delegate to `function_args_to_expr`.
///
/// # Errors
///
/// Propagates the first error raised while converting an argument.
pub(super) fn aggregate_fn_to_expr(
    &self,
    fun: AggregateFunction,
    args: Vec<FunctionArg>,
    schema: &DFSchema,
    planner_context: &mut PlannerContext,
) -> Result<(AggregateFunction, Vec<Expr>)> {
    // Only COUNT needs the wildcard special case; everything else takes the
    // generic argument-conversion path.
    if !matches!(fun, AggregateFunction::Count) {
        let converted = self.function_args_to_expr(args, schema, planner_context)?;
        return Ok((fun, converted));
    }

    let mut converted = Vec::with_capacity(args.len());
    for arg in args {
        let expr = if matches!(arg, FunctionArg::Unnamed(FunctionArgExpr::Wildcard)) {
            // Special case rewrite COUNT(*) to COUNT(constant)
            Expr::Literal(COUNT_STAR_EXPANSION.clone())
        } else {
            self.sql_fn_arg_to_logical_expr(arg, schema, planner_context)?
        };
        converted.push(expr);
    }

    Ok((fun, converted))
}
}
30 changes: 15 additions & 15 deletions datafusion/sql/tests/integration_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -881,7 +881,7 @@ fn select_aggregate_with_having_referencing_column_not_in_select() {
assert_eq!(
"Plan(\"HAVING clause references non-aggregate values: \
Expression person.first_name could not be resolved from available columns: \
COUNT(UInt8(1))\")",
COUNT(*)\")",
format!("{err:?}")
);
}
Expand Down Expand Up @@ -1084,8 +1084,8 @@ fn select_aggregate_with_group_by_with_having_using_count_star_not_in_select() {
GROUP BY first_name
HAVING MAX(age) > 100 AND COUNT(*) < 50";
let expected = "Projection: person.first_name, MAX(person.age)\
\n Filter: MAX(person.age) > Int64(100) AND COUNT(UInt8(1)) < Int64(50)\
\n Aggregate: groupBy=[[person.first_name]], aggr=[[MAX(person.age), COUNT(UInt8(1))]]\
\n Filter: MAX(person.age) > Int64(100) AND COUNT(*) < Int64(50)\
\n Aggregate: groupBy=[[person.first_name]], aggr=[[MAX(person.age), COUNT(*)]]\
\n TableScan: person";
quick_test(sql, expected);
}
Expand Down Expand Up @@ -1665,8 +1665,8 @@ fn select_group_by_columns_not_in_select() {
#[test]
fn select_group_by_count_star() {
let sql = "SELECT state, COUNT(*) FROM person GROUP BY state";
let expected = "Projection: person.state, COUNT(UInt8(1))\
\n Aggregate: groupBy=[[person.state]], aggr=[[COUNT(UInt8(1))]]\
let expected = "Projection: person.state, COUNT(*)\
\n Aggregate: groupBy=[[person.state]], aggr=[[COUNT(*)]]\
\n TableScan: person";

quick_test(sql, expected);
Expand Down Expand Up @@ -2884,8 +2884,8 @@ fn scalar_subquery_reference_outer_field() {
let expected = "Projection: j1.j1_string, j2.j2_string\
\n Filter: j1.j1_id = j2.j2_id - Int64(1) AND j2.j2_id < (<subquery>)\
\n Subquery:\
\n Projection: COUNT(UInt8(1))\
\n Aggregate: groupBy=[[]], aggr=[[COUNT(UInt8(1))]]\
\n Projection: COUNT(*)\
\n Aggregate: groupBy=[[]], aggr=[[COUNT(*)]]\
\n Filter: outer_ref(j2.j2_id) = j1.j1_id AND j1.j1_id = j3.j3_id\
\n CrossJoin:\
\n TableScan: j1\
Expand Down Expand Up @@ -2983,8 +2983,8 @@ fn cte_unbalanced_number_of_columns() {
// Builds a `COUNT(*)` query grouped by `id, ROLLUP (state, age)` and hands the SQL
// plus the expected plan text to `quick_test`.
// NOTE(review): two `let expected` variants appear below (`COUNT(UInt8(1))` and
// `COUNT(*)`); this looks like the before/after lines of a diff shown together —
// confirm which one is current.
fn aggregate_with_rollup() {
let sql =
"SELECT id, state, age, COUNT(*) FROM person GROUP BY id, ROLLUP (state, age)";
let expected = "Projection: person.id, person.state, person.age, COUNT(UInt8(1))\
\n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.state, person.age))]], aggr=[[COUNT(UInt8(1))]]\
let expected = "Projection: person.id, person.state, person.age, COUNT(*)\
\n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.state, person.age))]], aggr=[[COUNT(*)]]\
\n TableScan: person";
quick_test(sql, expected);
}
Expand All @@ -2993,8 +2993,8 @@ fn aggregate_with_rollup() {
// Builds a ROLLUP query that also selects `grouping(state)`, `grouping(age)`, their
// sum, and `COUNT(*)`, then hands the SQL plus the expected plan text to `quick_test`.
// NOTE(review): two `let expected` variants appear below (`COUNT(UInt8(1))` and
// `COUNT(*)`); this looks like the before/after lines of a diff shown together —
// confirm which one is current.
fn aggregate_with_rollup_with_grouping() {
let sql = "SELECT id, state, age, grouping(state), grouping(age), grouping(state) + grouping(age), COUNT(*) \
FROM person GROUP BY id, ROLLUP (state, age)";
let expected = "Projection: person.id, person.state, person.age, GROUPING(person.state), GROUPING(person.age), GROUPING(person.state) + GROUPING(person.age), COUNT(UInt8(1))\
\n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.state, person.age))]], aggr=[[GROUPING(person.state), GROUPING(person.age), COUNT(UInt8(1))]]\
let expected = "Projection: person.id, person.state, person.age, GROUPING(person.state), GROUPING(person.age), GROUPING(person.state) + GROUPING(person.age), COUNT(*)\
\n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.state, person.age))]], aggr=[[GROUPING(person.state), GROUPING(person.age), COUNT(*)]]\
\n TableScan: person";
quick_test(sql, expected);
}
Expand Down Expand Up @@ -3025,8 +3025,8 @@ fn rank_partition_grouping() {
// Builds a `COUNT(*)` query grouped by `id, CUBE (state, age)` and hands the SQL
// plus the expected plan text to `quick_test`.
// NOTE(review): two `let expected` variants appear below (`COUNT(UInt8(1))` and
// `COUNT(*)`); this looks like the before/after lines of a diff shown together —
// confirm which one is current.
fn aggregate_with_cube() {
let sql =
"SELECT id, state, age, COUNT(*) FROM person GROUP BY id, CUBE (state, age)";
let expected = "Projection: person.id, person.state, person.age, COUNT(UInt8(1))\
\n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.age), (person.id, person.state, person.age))]], aggr=[[COUNT(UInt8(1))]]\
let expected = "Projection: person.id, person.state, person.age, COUNT(*)\
\n Aggregate: groupBy=[[GROUPING SETS ((person.id), (person.id, person.state), (person.id, person.age), (person.id, person.state, person.age))]], aggr=[[COUNT(*)]]\
\n TableScan: person";
quick_test(sql, expected);
}
Expand All @@ -3042,8 +3042,8 @@ fn round_decimal() {
// Builds a `COUNT(*)` query grouped by `id` combined with explicit GROUPING SETS and
// hands the SQL plus the expected plan text to `quick_test`.
// NOTE(review): two `let expected` variants appear below (`COUNT(UInt8(1))` and
// `COUNT(*)`); this looks like the before/after lines of a diff shown together —
// confirm which one is current.
#[test]
fn aggregate_with_grouping_sets() {
let sql = "SELECT id, state, age, COUNT(*) FROM person GROUP BY id, GROUPING SETS ((state), (state, age), (id, state))";
let expected = "Projection: person.id, person.state, person.age, COUNT(UInt8(1))\
\n Aggregate: groupBy=[[GROUPING SETS ((person.id, person.state), (person.id, person.state, person.age), (person.id, person.id, person.state))]], aggr=[[COUNT(UInt8(1))]]\
let expected = "Projection: person.id, person.state, person.age, COUNT(*)\
\n Aggregate: groupBy=[[GROUPING SETS ((person.id, person.state), (person.id, person.state, person.age), (person.id, person.id, person.state))]], aggr=[[COUNT(*)]]\
\n TableScan: person";
quick_test(sql, expected);
}
Expand Down

0 comments on commit 6194d58

Please sign in to comment.