Skip to content

Commit 26a8000

Browse files
authored
Fix group by aliased expression in LogicalPlanBuilder::aggregate (#8629)
1 parent 78832f1 commit 26a8000

File tree

2 files changed

+73
-21
lines changed

2 files changed

+73
-21
lines changed

datafusion/core/src/dataframe/mod.rs

Lines changed: 34 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1769,8 +1769,8 @@ mod tests {
17691769
let df_results = df.collect().await?;
17701770

17711771
#[rustfmt::skip]
1772-
assert_batches_sorted_eq!(
1773-
[ "+----+",
1772+
assert_batches_sorted_eq!([
1773+
"+----+",
17741774
"| id |",
17751775
"+----+",
17761776
"| 1 |",
@@ -1781,6 +1781,38 @@ mod tests {
17811781
Ok(())
17821782
}
17831783

1784+
#[tokio::test]
1785+
async fn test_aggregate_alias() -> Result<()> {
1786+
let df = test_table().await?;
1787+
1788+
let df = df
1789+
// GROUP BY `c2 + 1`
1790+
.aggregate(vec![col("c2") + lit(1)], vec![])?
1791+
// SELECT `c2 + 1` as c2
1792+
.select(vec![(col("c2") + lit(1)).alias("c2")])?
1793+
// GROUP BY c2 as "c2" (alias in expr is not supported by SQL)
1794+
.aggregate(vec![col("c2").alias("c2")], vec![])?;
1795+
1796+
let df_results = df.collect().await?;
1797+
1798+
#[rustfmt::skip]
1799+
assert_batches_sorted_eq!([
1800+
"+----+",
1801+
"| c2 |",
1802+
"+----+",
1803+
"| 2 |",
1804+
"| 3 |",
1805+
"| 4 |",
1806+
"| 5 |",
1807+
"| 6 |",
1808+
"+----+",
1809+
],
1810+
&df_results
1811+
);
1812+
1813+
Ok(())
1814+
}
1815+
17841816
#[tokio::test]
17851817
async fn test_distinct() -> Result<()> {
17861818
let t = test_table().await?;

datafusion/expr/src/logical_plan/builder.rs

Lines changed: 39 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -904,27 +904,11 @@ impl LogicalPlanBuilder {
904904
group_expr: impl IntoIterator<Item = impl Into<Expr>>,
905905
aggr_expr: impl IntoIterator<Item = impl Into<Expr>>,
906906
) -> Result<Self> {
907-
let mut group_expr = normalize_cols(group_expr, &self.plan)?;
907+
let group_expr = normalize_cols(group_expr, &self.plan)?;
908908
let aggr_expr = normalize_cols(aggr_expr, &self.plan)?;
909909

910-
// Rewrite groupby exprs according to functional dependencies
911-
let group_by_expr_names = group_expr
912-
.iter()
913-
.map(|group_by_expr| group_by_expr.display_name())
914-
.collect::<Result<Vec<_>>>()?;
915-
let schema = self.plan.schema();
916-
if let Some(target_indices) =
917-
get_target_functional_dependencies(schema, &group_by_expr_names)
918-
{
919-
for idx in target_indices {
920-
let field = schema.field(idx);
921-
let expr =
922-
Expr::Column(Column::new(field.qualifier().cloned(), field.name()));
923-
if !group_expr.contains(&expr) {
924-
group_expr.push(expr);
925-
}
926-
}
927-
}
910+
let group_expr =
911+
add_group_by_exprs_from_dependencies(group_expr, self.plan.schema())?;
928912
Aggregate::try_new(Arc::new(self.plan), group_expr, aggr_expr)
929913
.map(LogicalPlan::Aggregate)
930914
.map(Self::from)
@@ -1189,6 +1173,42 @@ pub fn build_join_schema(
11891173
schema.with_functional_dependencies(func_dependencies)
11901174
}
11911175

1176+
/// Add additional "synthetic" group by expressions based on functional
1177+
/// dependencies.
1178+
///
1179+
/// For example, if we are grouping on `[c1]`, and we know from
1180+
/// functional dependencies that column `c1` determines `c2`, this function
1181+
/// adds `c2` to the group by list.
1182+
///
1183+
/// This allows MySQL style selects like
1184+
/// `SELECT col FROM t WHERE pk = 5` if col is unique
1185+
fn add_group_by_exprs_from_dependencies(
1186+
mut group_expr: Vec<Expr>,
1187+
schema: &DFSchemaRef,
1188+
) -> Result<Vec<Expr>> {
1189+
// Names of the fields produced by the GROUP BY exprs for example, `GROUP BY
1190+
// c1 + 1` produces an output field named `"c1 + 1"`
1191+
let mut group_by_field_names = group_expr
1192+
.iter()
1193+
.map(|e| e.display_name())
1194+
.collect::<Result<Vec<_>>>()?;
1195+
1196+
if let Some(target_indices) =
1197+
get_target_functional_dependencies(schema, &group_by_field_names)
1198+
{
1199+
for idx in target_indices {
1200+
let field = schema.field(idx);
1201+
let expr =
1202+
Expr::Column(Column::new(field.qualifier().cloned(), field.name()));
1203+
let expr_name = expr.display_name()?;
1204+
if !group_by_field_names.contains(&expr_name) {
1205+
group_by_field_names.push(expr_name);
1206+
group_expr.push(expr);
1207+
}
1208+
}
1209+
}
1210+
Ok(group_expr)
1211+
}
11921212
/// Errors if one or more expressions have equal names.
11931213
pub(crate) fn validate_unique_names<'a>(
11941214
node_name: &str,

0 commit comments

Comments
 (0)