Skip to content

Commit

Permalink
Fix group by aliased expression in LogicalPlanBuilder::aggregate (#8629)
Browse files Browse the repository at this point in the history
  • Loading branch information
alamb authored Dec 26, 2023
1 parent 78832f1 commit 26a8000
Show file tree
Hide file tree
Showing 2 changed files with 73 additions and 21 deletions.
36 changes: 34 additions & 2 deletions datafusion/core/src/dataframe/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1769,8 +1769,8 @@ mod tests {
let df_results = df.collect().await?;

#[rustfmt::skip]
assert_batches_sorted_eq!(
[ "+----+",
assert_batches_sorted_eq!([
"+----+",
"| id |",
"+----+",
"| 1 |",
Expand All @@ -1781,6 +1781,38 @@ mod tests {
Ok(())
}

/// Checks that grouping by an aliased expression works when the alias
/// matches the name of a column produced by an earlier projection.
#[tokio::test]
async fn test_aggregate_alias() -> Result<()> {
    // First aggregation: GROUP BY `c2 + 1`
    let grouped = test_table()
        .await?
        .aggregate(vec![col("c2") + lit(1)], vec![])?;

    // Projection: SELECT `c2 + 1` as c2
    let projected = grouped.select(vec![(col("c2") + lit(1)).alias("c2")])?;

    // Second aggregation: GROUP BY c2 as "c2"
    // (an alias inside a GROUP BY expression is not supported by SQL)
    let regrouped = projected.aggregate(vec![col("c2").alias("c2")], vec![])?;

    let df_results = regrouped.collect().await?;

    // The output is a single `c2` column holding the values 2 through 6.
    #[rustfmt::skip]
    assert_batches_sorted_eq!([
        "+----+",
        "| c2 |",
        "+----+",
        "| 2 |",
        "| 3 |",
        "| 4 |",
        "| 5 |",
        "| 6 |",
        "+----+",
        ],
        &df_results
    );

    Ok(())
}

#[tokio::test]
async fn test_distinct() -> Result<()> {
let t = test_table().await?;
Expand Down
58 changes: 39 additions & 19 deletions datafusion/expr/src/logical_plan/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -904,27 +904,11 @@ impl LogicalPlanBuilder {
group_expr: impl IntoIterator<Item = impl Into<Expr>>,
aggr_expr: impl IntoIterator<Item = impl Into<Expr>>,
) -> Result<Self> {
let mut group_expr = normalize_cols(group_expr, &self.plan)?;
let group_expr = normalize_cols(group_expr, &self.plan)?;
let aggr_expr = normalize_cols(aggr_expr, &self.plan)?;

// Rewrite groupby exprs according to functional dependencies
let group_by_expr_names = group_expr
.iter()
.map(|group_by_expr| group_by_expr.display_name())
.collect::<Result<Vec<_>>>()?;
let schema = self.plan.schema();
if let Some(target_indices) =
get_target_functional_dependencies(schema, &group_by_expr_names)
{
for idx in target_indices {
let field = schema.field(idx);
let expr =
Expr::Column(Column::new(field.qualifier().cloned(), field.name()));
if !group_expr.contains(&expr) {
group_expr.push(expr);
}
}
}
let group_expr =
add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())?;
Aggregate::try_new(Arc::new(self.plan), group_expr, aggr_expr)
.map(LogicalPlan::Aggregate)
.map(Self::from)
Expand Down Expand Up @@ -1189,6 +1173,42 @@ pub fn build_join_schema(
schema.with_functional_dependencies(func_dependencies)
}

/// Extends a GROUP BY list with additional "synthetic" expressions derived
/// from functional dependencies.
///
/// For example, when grouping on `[c1]` and the functional dependencies
/// state that column `c1` determines `c2`, the returned list also
/// contains `c2`.
///
/// This allows MySQL style selects like
/// `SELECT col FROM t WHERE pk = 5` if col is unique
///
/// A dependent column is appended only when no expression already in the
/// list has the same display name, so the resulting names stay distinct.
///
/// # Errors
///
/// Returns an error if `display_name()` fails for any of the expressions.
fn add_group_by_exprs_from_dependencies(
    mut group_expr: Vec<Expr>,
    schema: &DFSchemaRef,
) -> Result<Vec<Expr>> {
    // Names of the output fields generated by the GROUP BY expressions; for
    // instance `GROUP BY c1 + 1` yields an output field named `"c1 + 1"`.
    let mut seen_names = group_expr
        .iter()
        .map(Expr::display_name)
        .collect::<Result<Vec<_>>>()?;

    let dependent_indices =
        match get_target_functional_dependencies(schema, &seen_names) {
            Some(indices) => indices,
            // Nothing to add: hand the list back untouched.
            None => return Ok(group_expr),
        };

    for idx in dependent_indices {
        let field = schema.field(idx);
        let column = Column::new(field.qualifier().cloned(), field.name());
        let candidate = Expr::Column(column);
        let candidate_name = candidate.display_name()?;

        // Compare by name rather than by expression so that an aliased
        // expression already producing this name is not duplicated.
        if seen_names.contains(&candidate_name) {
            continue;
        }
        seen_names.push(candidate_name);
        group_expr.push(candidate);
    }

    Ok(group_expr)
}
/// Errors if one or more expressions have equal names.
pub(crate) fn validate_unique_names<'a>(
node_name: &str,
Expand Down

0 comments on commit 26a8000

Please sign in to comment.