
Commit ed92673

Dandandan and alamb authored
Implement hash partitioned aggregation (#320)
* Implement hash partitioned aggregation
* Ballista
* Make configurable and use configured concurrency
* WIP
* Add some hash types
* Fmt
* Disable repartition aggregations in ballista
* fmt
* Clippy, ballista
* Fix test
* Revert test code
* Update datafusion/src/physical_plan/hash_aggregate.rs
* Add info about required child partitioning
* Add test
* Test fix
* Set concurrency

Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
1 parent 1c50371 commit ed92673
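This change implements two-phase hash partitioned aggregation: each input partition is aggregated with `AggregateMode::Partial`, the partial states are hash-repartitioned on the group-by expressions so that all rows for a given group key land in the same partition, and the new `AggregateMode::FinalPartitioned` mode then merges states in every partition in parallel, instead of merging everything into a single `Final` partition. The behavior is controlled by the new `repartition_aggregations` config flag (on by default in DataFusion, disabled in Ballista because the distributed planner does not yet plan the hash repartition). An illustrative plan shape, rendered in the same style as the planner test output below; the column, table, and operator renderings here are examples, not output from this commit:

HashAggregateExec: mode=FinalPartitioned, groupBy=["c1"], aggrExpr=["SUM(c2)"]
 RepartitionExec: Hash([c1], concurrency)
  HashAggregateExec: mode=Partial, groupBy=["c1"], aggrExpr=["SUM(c2)"]
   CsvExec: testdata/example; partitions=2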

File tree

15 files changed: +229 −53 lines


ballista/rust/core/proto/ballista.proto

Lines changed: 1 addition & 0 deletions
@@ -396,6 +396,7 @@ message ProjectionExecNode {
 enum AggregateMode {
   PARTIAL = 0;
   FINAL = 1;
+  FINAL_PARTITIONED = 2;
 }

 message HashAggregateExecNode {

ballista/rust/core/src/serde/physical_plan/from_proto.rs

Lines changed: 3 additions & 0 deletions
@@ -201,6 +201,9 @@ impl TryInto<Arc<dyn ExecutionPlan>> for &protobuf::PhysicalPlanNode {
         let agg_mode: AggregateMode = match mode {
             protobuf::AggregateMode::Partial => AggregateMode::Partial,
             protobuf::AggregateMode::Final => AggregateMode::Final,
+            protobuf::AggregateMode::FinalPartitioned => {
+                AggregateMode::FinalPartitioned
+            }
         };

         let group = hash_agg

ballista/rust/core/src/serde/physical_plan/to_proto.rs

Lines changed: 3 additions & 0 deletions
@@ -172,6 +172,9 @@ impl TryInto<protobuf::PhysicalPlanNode> for Arc<dyn ExecutionPlan> {
         let agg_mode = match exec.mode() {
             AggregateMode::Partial => protobuf::AggregateMode::Partial,
             AggregateMode::Final => protobuf::AggregateMode::Final,
+            AggregateMode::FinalPartitioned => {
+                protobuf::AggregateMode::FinalPartitioned
+            }
         };
         let input_schema = exec.input_schema();
         let input: protobuf::PhysicalPlanNode = exec.input().to_owned().try_into()?;

ballista/rust/core/src/utils.rs

Lines changed: 1 addition & 0 deletions
@@ -322,6 +322,7 @@ pub fn create_datafusion_context() -> ExecutionContext {
     let config = ExecutionConfig::new()
         .with_concurrency(1)
         .with_repartition_joins(false)
+        .with_repartition_aggregations(false)
         .with_physical_optimizer_rules(rules);
     ExecutionContext::with_config(config)
 }

ballista/rust/scheduler/src/planner.rs

Lines changed: 2 additions & 5 deletions
@@ -128,7 +128,7 @@ impl DistributedPlanner {
             //TODO should insert query stages in more generic way based on partitioning metadata
             // and not specifically for this operator
             match agg.mode() {
-                AggregateMode::Final => {
+                AggregateMode::Final | AggregateMode::FinalPartitioned => {
                     let mut new_children: Vec<Arc<dyn ExecutionPlan>> = vec![];
                     for child in &children {
                         let new_stage = create_query_stage(
@@ -237,10 +237,9 @@ mod test {
     use ballista_core::serde::protobuf;
     use ballista_core::utils::format_plan;
     use datafusion::physical_plan::hash_aggregate::HashAggregateExec;
-    use datafusion::physical_plan::merge::MergeExec;
-    use datafusion::physical_plan::projection::ProjectionExec;
     use datafusion::physical_plan::sort::SortExec;
     use datafusion::physical_plan::ExecutionPlan;
+    use datafusion::physical_plan::{merge::MergeExec, projection::ProjectionExec};
     use std::convert::TryInto;
     use std::sync::Arc;
     use uuid::Uuid;
@@ -278,11 +277,9 @@
     QueryStageExec: job=f011432e-e424-4016-915d-e3d8b84f6dbd, stage=1
      HashAggregateExec: groupBy=["l_returnflag"], aggrExpr=["SUM(l_extendedprice Multiply Int64(1)) [\"l_extendedprice * CAST(1 AS Float64)\"]"]
       CsvExec: testdata/lineitem; partitions=2
-
     QueryStageExec: job=f011432e-e424-4016-915d-e3d8b84f6dbd, stage=2
      MergeExec
       UnresolvedShuffleExec: stages=[1]
-
     QueryStageExec: job=f011432e-e424-4016-915d-e3d8b84f6dbd, stage=3
      SortExec { input: ProjectionExec { expr: [(Column { name: "l_returnflag" }, "l_returnflag"), (Column { name: "SUM(l_ext
       ProjectionExec { expr: [(Column { name: "l_returnflag" }, "l_returnflag"), (Column { name: "SUM(l_extendedprice Multip

ballista/rust/scheduler/src/test_utils.rs

Lines changed: 4 additions & 2 deletions
@@ -33,10 +33,12 @@ pub const TPCH_TABLES: &[&str] = &[
 pub fn datafusion_test_context(path: &str) -> Result<ExecutionContext> {
     // remove Repartition rule because that isn't supported yet
     let rules: Vec<Arc<dyn PhysicalOptimizerRule + Send + Sync>> = vec![
-        Arc::new(CoalesceBatches::new()),
         Arc::new(AddMergeExec::new()),
+        Arc::new(CoalesceBatches::new()),
     ];
-    let config = ExecutionConfig::new().with_physical_optimizer_rules(rules);
+    let config = ExecutionConfig::new()
+        .with_physical_optimizer_rules(rules)
+        .with_repartition_aggregations(false);
     let mut ctx = ExecutionContext::with_config(config);

     for table in TPCH_TABLES {

datafusion/src/execution/context.rs

Lines changed: 9 additions & 10 deletions
@@ -636,6 +636,9 @@ pub struct ExecutionConfig {
     /// Should DataFusion repartition data using the join keys to execute joins in parallel
     /// using the provided `concurrency` level
     pub repartition_joins: bool,
+    /// Should DataFusion repartition data using the aggregate keys to execute aggregates in parallel
+    /// using the provided `concurrency` level
+    pub repartition_aggregations: bool,
 }

 impl ExecutionConfig {
@@ -663,6 +666,7 @@ impl ExecutionConfig {
             create_default_catalog_and_schema: true,
             information_schema: false,
             repartition_joins: true,
+            repartition_aggregations: true,
         }
     }

@@ -746,6 +750,11 @@ impl ExecutionConfig {
         self.repartition_joins = enabled;
         self
     }
+    /// Enables or disables the use of repartitioning for aggregations to improve parallelism
+    pub fn with_repartition_aggregations(mut self, enabled: bool) -> Self {
+        self.repartition_aggregations = enabled;
+        self
+    }
 }

 /// Holds per-execution properties and data (such as starting timestamps, etc).
@@ -1351,7 +1360,6 @@ mod tests {
     #[tokio::test]
     async fn aggregate_grouped() -> Result<()> {
         let results = execute("SELECT c1, SUM(c2) FROM test GROUP BY c1", 4).await?;
-        assert_eq!(results.len(), 1);

         let expected = vec![
             "+----+---------+",
@@ -1371,7 +1379,6 @@
     #[tokio::test]
     async fn aggregate_grouped_avg() -> Result<()> {
         let results = execute("SELECT c1, AVG(c2) FROM test GROUP BY c1", 4).await?;
-        assert_eq!(results.len(), 1);

         let expected = vec![
             "+----+---------+",
@@ -1392,7 +1399,6 @@
     async fn boolean_literal() -> Result<()> {
         let results =
             execute("SELECT c1, c3 FROM test WHERE c1 > 2 AND c3 = true", 4).await?;
-        assert_eq!(results.len(), 1);

         let expected = vec![
             "+----+------+",
@@ -1414,7 +1420,6 @@
     async fn aggregate_grouped_empty() -> Result<()> {
         let results =
             execute("SELECT c1, AVG(c2) FROM test WHERE c1 = 123 GROUP BY c1", 4).await?;
-        assert_eq!(results.len(), 1);

         let expected = vec!["++", "||", "++", "++"];
         assert_batches_sorted_eq!(expected, &results);
@@ -1425,7 +1430,6 @@
     #[tokio::test]
     async fn aggregate_grouped_max() -> Result<()> {
         let results = execute("SELECT c1, MAX(c2) FROM test GROUP BY c1", 4).await?;
-        assert_eq!(results.len(), 1);

         let expected = vec![
             "+----+---------+",
@@ -1445,7 +1449,6 @@
     #[tokio::test]
     async fn aggregate_grouped_min() -> Result<()> {
         let results = execute("SELECT c1, MIN(c2) FROM test GROUP BY c1", 4).await?;
-        assert_eq!(results.len(), 1);

         let expected = vec![
             "+----+---------+",
@@ -1629,7 +1632,6 @@
     #[tokio::test]
     async fn count_aggregated() -> Result<()> {
         let results = execute("SELECT c1, COUNT(c2) FROM test GROUP BY c1", 4).await?;
-        assert_eq!(results.len(), 1);

         let expected = vec![
             "+----+-----------+",
@@ -1681,7 +1683,6 @@
             &mut ctx,
             "SELECT date_trunc('week', t1) as week, SUM(c2) FROM test GROUP BY date_trunc('week', t1)",
         ).await?;
-        assert_eq!(results.len(), 1);

         let expected = vec![
             "+---------------------+---------+",
@@ -1925,7 +1926,6 @@
         ];

         let results = run_count_distinct_integers_aggregated_scenario(partitions).await?;
-        assert_eq!(results.len(), 1);

         let expected = vec![
             "+---------+-----------------+------------------------+-------------------------+-------------------------+-------------------------+-------------------------+--------------------------+--------------------------+--------------------------+",
@@ -1952,7 +1952,6 @@
         ];

         let results = run_count_distinct_integers_aggregated_scenario(partitions).await?;
-        assert_eq!(results.len(), 1);

         let expected = vec![
             "+---------+-----------------+------------------------+-------------------------+-------------------------+-------------------------+-------------------------+--------------------------+--------------------------+--------------------------+",

datafusion/src/physical_optimizer/merge_exec.rs

Lines changed: 1 addition & 0 deletions
@@ -52,6 +52,7 @@ impl PhysicalOptimizerRule for AddMergeExec {
             .collect::<Result<Vec<_>>>()?;
         match plan.required_child_distribution() {
             Distribution::UnspecifiedDistribution => plan.with_new_children(children),
+            Distribution::HashPartitioned(_) => plan.with_new_children(children),
             Distribution::SinglePartition => plan.with_new_children(
                 children
                     .iter()

datafusion/src/physical_optimizer/repartition.rs

Lines changed: 4 additions & 1 deletion
@@ -52,7 +52,10 @@ fn optimize_concurrency(
             .map(|child| {
                 optimize_concurrency(
                     concurrency,
-                    plan.required_child_distribution() == Distribution::SinglePartition,
+                    matches!(
+                        plan.required_child_distribution(),
+                        Distribution::SinglePartition
+                    ),
                     child.clone(),
                 )
             })
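The switch from `==` to `matches!` is needed because `Distribution` now has a `HashPartitioned(Vec<Arc<dyn PhysicalExpr>>)` variant, and trait objects like `dyn PhysicalExpr` do not implement `PartialEq`, so the enum can no longer derive equality; `matches!` only inspects the variant. A self-contained illustration of the same pattern, using a stand-in enum rather than the DataFusion type:

// Stand-in mirroring the shape of `Distribution` after this change.
enum Dist {
    SinglePartition,
    // The real variant carries Vec<Arc<dyn PhysicalExpr>>, which cannot derive PartialEq.
    HashPartitioned(Vec<String>),
}

fn requires_single_partition(d: &Dist) -> bool {
    // `matches!` checks the variant without needing `PartialEq` on the enum.
    matches!(d, Dist::SinglePartition)
}

fn main() {
    assert!(requires_single_partition(&Dist::SinglePartition));
    assert!(!requires_single_partition(&Dist::HashPartitioned(vec!["c1".to_string()])));
}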

datafusion/src/physical_plan/hash_aggregate.rs

Lines changed: 17 additions & 5 deletions
@@ -78,6 +78,13 @@ pub enum AggregateMode {
     Partial,
     /// Final aggregate that produces a single partition of output
     Final,
+    /// Final aggregate that works on pre-partitioned data.
+    ///
+    /// This requires the invariant that all rows with a particular
+    /// grouping key are in the same partitions, such as is the case
+    /// with Hash repartitioning on the group keys. If a group key is
+    /// duplicated, duplicate groups would be produced
+    FinalPartitioned,
 }

 /// Hash aggregate execution plan
@@ -123,7 +130,7 @@ fn create_schema(
                 fields.extend(expr.state_fields()?.iter().cloned())
             }
         }
-        AggregateMode::Final => {
+        AggregateMode::Final | AggregateMode::FinalPartitioned => {
             // in final mode, the field with the final result of the accumulator
             for expr in aggr_expr {
                 fields.push(expr.field()?)
@@ -204,6 +211,9 @@ impl ExecutionPlan for HashAggregateExec {
     fn required_child_distribution(&self) -> Distribution {
         match &self.mode {
             AggregateMode::Partial => Distribution::UnspecifiedDistribution,
+            AggregateMode::FinalPartitioned => Distribution::HashPartitioned(
+                self.group_expr.iter().map(|x| x.0.clone()).collect(),
+            ),
             AggregateMode::Final => Distribution::SinglePartition,
         }
     }
@@ -454,7 +464,7 @@ fn group_aggregate_batch(
                 })
                 .try_for_each(|(accumulator, values)| match mode {
                     AggregateMode::Partial => accumulator.update_batch(&values),
-                    AggregateMode::Final => {
+                    AggregateMode::FinalPartitioned | AggregateMode::Final => {
                         // note: the aggregation here is over states, not values, thus the merge
                         accumulator.merge_batch(&values)
                     }
@@ -807,7 +817,7 @@ fn aggregate_expressions(
             Ok(aggr_expr.iter().map(|agg| agg.expressions()).collect())
         }
         // in this mode, we build the merge expressions of the aggregation
-        AggregateMode::Final => Ok(aggr_expr
+        AggregateMode::Final | AggregateMode::FinalPartitioned => Ok(aggr_expr
             .iter()
             .map(|agg| merge_expressions(agg))
             .collect::<Result<Vec<_>>>()?),
@@ -901,7 +911,9 @@ fn aggregate_batch(
             // 1.3
             match mode {
                 AggregateMode::Partial => accum.update_batch(values),
-                AggregateMode::Final => accum.merge_batch(values),
+                AggregateMode::Final | AggregateMode::FinalPartitioned => {
+                    accum.merge_batch(values)
+                }
             }
         })
 }
@@ -1074,7 +1086,7 @@ fn finalize_aggregation(
                 .collect::<Result<Vec<_>>>()?;
             Ok(a.iter().flatten().cloned().collect::<Vec<_>>())
         }
-        AggregateMode::Final => {
+        AggregateMode::Final | AggregateMode::FinalPartitioned => {
            // merge the state to the final value
            accumulators
                .iter()
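`required_child_distribution` is the only contract the operator exposes: a `FinalPartitioned` aggregate asks for input hash-partitioned on its group expressions, and it is up to the physical planner/optimizer to place a hash `RepartitionExec` between the `Partial` and `FinalPartitioned` aggregates when `repartition_aggregations` and a concurrency above 1 are configured. A rough sketch of how such a requirement could be satisfied; this helper is illustrative, not code from this commit, and it assumes `RepartitionExec::try_new` and `Partitioning::Hash` as they exist in DataFusion:

use std::sync::Arc;

use datafusion::error::Result;
use datafusion::physical_plan::repartition::RepartitionExec;
use datafusion::physical_plan::{Distribution, ExecutionPlan, Partitioning};

/// Hypothetical helper: wrap `child` so that it satisfies the parent's
/// required input distribution.
fn satisfy_child_distribution(
    parent: &dyn ExecutionPlan,
    child: Arc<dyn ExecutionPlan>,
    concurrency: usize,
) -> Result<Arc<dyn ExecutionPlan>> {
    match parent.required_child_distribution() {
        // A FinalPartitioned HashAggregateExec asks for rows hash-partitioned on
        // its group expressions, so repartition the child on those keys.
        Distribution::HashPartitioned(keys) => Ok(Arc::new(RepartitionExec::try_new(
            child,
            Partitioning::Hash(keys, concurrency),
        )?)),
        // SinglePartition / UnspecifiedDistribution are handled elsewhere
        // (for example by the AddMergeExec rule above).
        _ => Ok(child),
    }
}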

0 commit comments
