Commit
Fix grouping sets behavior when data contains nulls
eejbyfeldt committed Sep 21, 2024
1 parent 21ec332 commit f4a220b
Showing 4 changed files with 159 additions and 78 deletions.
12 changes: 1 addition & 11 deletions datafusion/core/src/physical_planner.rs
@@ -707,10 +707,6 @@ impl DefaultPhysicalPlanner {
physical_input_schema.clone(),
)?);

// update group column indices based on partial aggregate plan evaluation
let final_group: Vec<Arc<dyn PhysicalExpr>> =
initial_aggr.output_group_expr();

let can_repartition = !groups.is_empty()
&& session_state.config().target_partitions() > 1
&& session_state.config().repartition_aggregations();
@@ -731,13 +727,7 @@
AggregateMode::Final
};

let final_grouping_set = PhysicalGroupBy::new_single(
final_group
.iter()
.enumerate()
.map(|(i, expr)| (expr.clone(), groups.expr()[i].1.clone()))
.collect(),
);
let final_grouping_set = initial_aggr.group_expr().as_final();

Arc::new(AggregateExec::try_new(
next_partition_mode,
202 changes: 139 additions & 63 deletions datafusion/physical-plan/src/aggregates/mod.rs
@@ -108,6 +108,8 @@ impl AggregateMode {
}
}

const INTERNAL_GROUPING_ID: &str = "grouping_id";

/// Represents `GROUP BY` clause in the plan (including the more general GROUPING SET)
/// In the case of a simple `GROUP BY a, b` clause, this will contain the expression [a, b]
/// and a single group [false, false].
@@ -137,6 +139,10 @@ pub struct PhysicalGroupBy {
/// expression in null_expr. If `groups[i][j]` is true, then the
/// j-th expression in the i-th group is NULL, otherwise it is `expr[j]`.
groups: Vec<Vec<bool>>,
/// The number of expressions that are output by this `PhysicalGroupBy`.
/// Internal expressions, such as the one used to implement grouping sets, are not
/// part of the output.
num_output_exprs: usize,
}
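
For orientation (this example is not part of the commit; the column types and null literals are assumptions), a query such as `SELECT a, b, count(*) ... GROUP BY GROUPING SETS ((a), (b), (a, b))` could be encoded with this struct roughly as follows, mirroring the test fixtures later in this diff:

let group_by = PhysicalGroupBy::new(
    // the distinct grouping expressions
    vec![
        (col("a", &schema)?, "a".to_string()),
        (col("b", &schema)?, "b".to_string()),
    ],
    // null placeholders, used to align grouping sets that omit a column
    vec![
        (lit(ScalarValue::Int32(None)), "a".to_string()),
        (lit(ScalarValue::Utf8(None)), "b".to_string()),
    ],
    // one mask per grouping set: `true` means "emit NULL for this column"
    vec![
        vec![false, true],  // (a)    -> b is NULL
        vec![true, false],  // (b)    -> a is NULL
        vec![false, false], // (a, b)
    ],
);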

impl PhysicalGroupBy {
@@ -146,10 +152,12 @@ impl PhysicalGroupBy {
null_expr: Vec<(Arc<dyn PhysicalExpr>, String)>,
groups: Vec<Vec<bool>>,
) -> Self {
let num_output_exprs = expr.len() + if !null_expr.is_empty() { 1 } else { 0 };
Self {
expr,
null_expr,
groups,
num_output_exprs,
}
}

@@ -161,6 +169,7 @@
expr,
null_expr: vec![],
groups: vec![vec![false; num_exprs]],
num_output_exprs: num_exprs,
}
}

@@ -212,11 +221,78 @@

/// Return grouping expressions as they occur in the output schema.
pub fn output_exprs(&self) -> Vec<Arc<dyn PhysicalExpr>> {
self.expr
.iter()
.enumerate()
.map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _)
.collect()
let mut output_exprs = Vec::with_capacity(self.num_output_exprs);
output_exprs.extend(
self.expr
.iter()
.take(self.num_output_exprs)
.enumerate()
.map(|(index, (_, name))| Arc::new(Column::new(name, index)) as _),
);
if !self.is_single() {
output_exprs
.push(Arc::new(Column::new(INTERNAL_GROUPING_ID, self.expr.len())) as _);
}
output_exprs
}

pub fn expr_count(&self) -> usize {
self.expr.len() + self.internal_expr_count()
}

pub fn group_fields(&self, input_schema: &Schema) -> Result<Vec<Field>> {
let mut fields = Vec::with_capacity(self.expr_count());
for ((expr, name), group_expr_nullable) in
self.expr.iter().zip(self.exprs_nullable().into_iter())
{
fields.push(Field::new(
name,
expr.data_type(input_schema)?,
group_expr_nullable || expr.nullable(input_schema)?,
))
}
if !self.is_single() {
fields.push(Field::new(
INTERNAL_GROUPING_ID,
arrow::datatypes::DataType::UInt32,
false,
));
}
Ok(fields)
}

fn output_fields(&self, input_schema: &Schema) -> Result<Vec<Field>> {
let mut fields = self.group_fields(input_schema)?;
fields.truncate(self.num_output_exprs);
Ok(fields)
}

fn internal_expr_count(&self) -> usize {
if self.is_single() {
0
} else {
1
}
}

pub fn as_final(&self) -> PhysicalGroupBy {
let expr: Vec<_> = self
.output_exprs()
.into_iter()
.zip(
self.expr
.iter()
.map(|t| t.1.clone())
.chain(std::iter::once(INTERNAL_GROUPING_ID.to_owned())),
)
.collect();
let num_exprs = expr.len();
Self {
expr,
null_expr: vec![],
groups: vec![vec![false; num_exprs]],
num_output_exprs: num_exprs - self.internal_expr_count(),
}
}
}
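
To make the Partial-to-Final hand-off concrete, the following sketch (not part of the commit; it continues the `group_by` value from the example above and assumes these methods are callable from the example's context) shows how `as_final` is meant to be used. The Partial stage exposes `a`, `b`, and the internal grouping id as its group columns; the Final stage then groups on exactly those output columns while keeping the grouping id out of its own visible output:

// Continuing the grouping-set sketch above.
// Partial stage: two user columns plus the internal grouping id.
assert_eq!(group_by.expr_count(), 3);

// The Final stage groups on the Partial stage's output columns:
//   Column("a", 0), Column("b", 1), Column("grouping_id", 2)
let final_group_by = group_by.as_final();
assert_eq!(final_group_by.expr_count(), 3);
// ...but only `a` and `b` appear in the Final aggregate's output schema,
// because the grouping id is an internal expression (num_output_exprs == 2).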

@@ -320,13 +396,7 @@ impl AggregateExec {
input: Arc<dyn ExecutionPlan>,
input_schema: SchemaRef,
) -> Result<Self> {
let schema = create_schema(
&input.schema(),
&group_by.expr,
&aggr_expr,
group_by.exprs_nullable(),
mode,
)?;
let schema = create_schema(&input.schema(), &group_by, &aggr_expr, mode)?;

let schema = Arc::new(schema);
AggregateExec::try_new_with_schema(
@@ -786,22 +856,12 @@ impl ExecutionPlan for AggregateExec {

fn create_schema(
input_schema: &Schema,
group_expr: &[(Arc<dyn PhysicalExpr>, String)],
group_by: &PhysicalGroupBy,
aggr_expr: &[AggregateFunctionExpr],
group_expr_nullable: Vec<bool>,
mode: AggregateMode,
) -> Result<Schema> {
let mut fields = Vec::with_capacity(group_expr.len() + aggr_expr.len());
for (index, (expr, name)) in group_expr.iter().enumerate() {
fields.push(Field::new(
name,
expr.data_type(input_schema)?,
// In cases where we have multiple grouping sets, we will use NULL expressions in
// order to align the grouping sets. So the field must be nullable even if the underlying
// schema field is not.
group_expr_nullable[index] || expr.nullable(input_schema)?,
))
}
let mut fields = Vec::with_capacity(group_by.num_output_exprs + aggr_expr.len());
fields.extend(group_by.output_fields(input_schema)?);

match mode {
AggregateMode::Partial => {
@@ -824,9 +884,8 @@ fn create_schema(
Ok(Schema::new(fields))
}

fn group_schema(schema: &Schema, group_count: usize) -> SchemaRef {
let group_fields = schema.fields()[0..group_count].to_vec();
Arc::new(Schema::new(group_fields))
fn group_schema(input_schema: &Schema, group_by: &PhysicalGroupBy) -> Result<SchemaRef> {
Ok(Arc::new(Schema::new(group_by.group_fields(input_schema)?)))
}

/// Determines the lexical ordering requirement for an aggregate expression.
@@ -1133,6 +1192,23 @@ fn evaluate_optional(
.collect()
}

fn group_id_array(group: &[bool], batch: &RecordBatch) -> Result<ArrayRef> {
if group.len() > 32 {
return not_impl_err!(
"Grouping sets with more than 32 columns are not supported"
);
}
let group_id = group.iter().fold(0u32, |acc, &is_null| {
(acc << 1) | if is_null { 1 } else { 0 }
});
Ok(Arc::new(arrow::array::UInt32Array::from(vec![
group_id;
batch.num_rows()
])))
}
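
As a worked example of the encoding above (not part of the commit): the grouping id packs the null mask of one grouping set into a `u32`, one bit per column, with the first grouping column landing in the most significant of the used bits, which is why more than 32 columns is rejected. A standalone re-implementation of the same fold:

fn grouping_id(group: &[bool]) -> u32 {
    // identical fold to `group_id_array` above: shift left, then set the low
    // bit when this column is replaced by NULL in the current grouping set
    group.iter().fold(0u32, |acc, &is_null| (acc << 1) | u32::from(is_null))
}

#[test]
fn grouping_id_examples() {
    assert_eq!(grouping_id(&[false, true]), 0b01);  // (a, NULL) -> 1
    assert_eq!(grouping_id(&[true, false]), 0b10);  // (NULL, b) -> 2
    assert_eq!(grouping_id(&[false, false]), 0b00); // (a, b)    -> 0
}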

/// Evaluate a group by expression against a `RecordBatch`
///
/// Arguments:
@@ -1165,23 +1241,24 @@ pub(crate) fn evaluate_group_by(
})
.collect::<Result<Vec<_>>>()?;

Ok(group_by
group_by
.groups
.iter()
.map(|group| {
group
.iter()
.enumerate()
.map(|(idx, is_null)| {
if *is_null {
Arc::clone(&null_exprs[idx])
} else {
Arc::clone(&exprs[idx])
}
})
.collect()
let mut group_values = Vec::with_capacity(group_by.expr_count());
group_values.extend(group.iter().enumerate().map(|(idx, is_null)| {
if *is_null {
Arc::clone(&null_exprs[idx])
} else {
Arc::clone(&exprs[idx])
}
}));
if !group_by.is_single() {
group_values.push(group_id_array(group, batch)?);
}
Ok(group_values)
})
.collect())
.collect()
}
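
For intuition (an illustrative trace, not part of the commit; the batch contents are assumed), given a two-row batch with a = [1, 2] and b = ["x", "y"] and the three grouping sets from the sketch above, the function returns one group-value vector per grouping set, each ending with its constant grouping-id column:

// set (a):    [ [1, 2],       [NULL, NULL], [1, 1] ]  // grouping_id = 0b01
// set (b):    [ [NULL, NULL], ["x", "y"],   [2, 2] ]  // grouping_id = 0b10
// set (a, b): [ [1, 2],       ["x", "y"],   [0, 0] ]  // grouping_id = 0b00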

#[cfg(test)]
@@ -1336,21 +1413,21 @@ mod tests {
) -> Result<()> {
let input_schema = input.schema();

let grouping_set = PhysicalGroupBy {
expr: vec![
let grouping_set = PhysicalGroupBy::new(
vec![
(col("a", &input_schema)?, "a".to_string()),
(col("b", &input_schema)?, "b".to_string()),
],
null_expr: vec![
vec![
(lit(ScalarValue::UInt32(None)), "a".to_string()),
(lit(ScalarValue::Float64(None)), "b".to_string()),
],
groups: vec![
vec![
vec![false, true], // (a, NULL)
vec![true, false], // (NULL, b)
vec![false, false], // (a,b)
],
};
);

let aggregates = vec![AggregateExprBuilder::new(count_udaf(), vec![lit(1i8)])
.schema(Arc::clone(&input_schema))
@@ -1488,11 +1565,11 @@
async fn check_aggregates(input: Arc<dyn ExecutionPlan>, spill: bool) -> Result<()> {
let input_schema = input.schema();

let grouping_set = PhysicalGroupBy {
expr: vec![(col("a", &input_schema)?, "a".to_string())],
null_expr: vec![],
groups: vec![vec![false]],
};
let grouping_set = PhysicalGroupBy::new(
vec![(col("a", &input_schema)?, "a".to_string())],
vec![],
vec![vec![false]],
);

let aggregates: Vec<AggregateFunctionExpr> =
vec![
@@ -1810,11 +1887,11 @@
let task_ctx = Arc::new(task_ctx);

let groups_none = PhysicalGroupBy::default();
let groups_some = PhysicalGroupBy {
expr: vec![(col("a", &input_schema)?, "a".to_string())],
null_expr: vec![],
groups: vec![vec![false]],
};
let groups_some = PhysicalGroupBy::new(
vec![(col("a", &input_schema)?, "a".to_string())],
vec![],
vec![vec![false]],
);

// something that allocates within the aggregator
let aggregates_v0: Vec<AggregateFunctionExpr> =
@@ -2502,25 +2579,24 @@ mod tests {
.build()?,
];

let grouping_set = PhysicalGroupBy {
expr: vec![
let grouping_set = PhysicalGroupBy::new(
vec![
(col("a", &input_schema)?, "a".to_string()),
(col("b", &input_schema)?, "b".to_string()),
],
null_expr: vec![
vec![
(lit(ScalarValue::Float32(None)), "a".to_string()),
(lit(ScalarValue::Float32(None)), "b".to_string()),
],
groups: vec![
vec![
vec![false, true], // (a, NULL)
vec![false, false], // (a,b)
],
};
);
let aggr_schema = create_schema(
&input_schema,
&grouping_set.expr,
&grouping_set,
&aggr_expr,
grouping_set.exprs_nullable(),
AggregateMode::Final,
)?;
let expected_schema = Schema::new(vec![
9 changes: 6 additions & 3 deletions datafusion/physical-plan/src/aggregates/row_hash.rs
@@ -456,13 +456,13 @@ impl GroupedHashAggregateStream {
let aggregate_arguments = aggregates::aggregate_expressions(
&agg.aggr_expr,
&agg.mode,
agg_group_by.expr.len(),
agg_group_by.expr_count(),
)?;
// arguments for aggregating spilled data are the same as those for final aggregation
let merging_aggregate_arguments = aggregates::aggregate_expressions(
&agg.aggr_expr,
&AggregateMode::Final,
agg_group_by.expr.len(),
agg_group_by.expr_count(),
)?;

let filter_expressions = match agg.mode {
Expand All @@ -480,7 +480,7 @@ impl GroupedHashAggregateStream {
.map(create_group_accumulator)
.collect::<Result<_>>()?;

let group_schema = group_schema(&agg_schema, agg_group_by.expr.len());
let group_schema = group_schema(&agg.input().schema(), &agg_group_by)?;
let spill_expr = group_schema
.fields
.into_iter()
@@ -851,6 +851,9 @@
}

let mut output = self.group_values.emit(emit_to)?;
if !spilling {
output.truncate(self.group_by.num_output_exprs);
}
if let EmitTo::First(n) = emit_to {
self.group_ordering.remove_groups(n);
}