Skip to content

Commit 00fea4e

Browse files
committed
Refine the size() calculation of accumulator
1 parent c97048d commit 00fea4e

File tree

3 files changed

+74
-38
lines changed

3 files changed

+74
-38
lines changed

datafusion/core/src/physical_plan/aggregates/row_hash.rs

Lines changed: 72 additions & 36 deletions
Original file line number | Diff line number | Diff line change
@@ -318,6 +318,7 @@ impl GroupedHashAggregateStream {
318318
..
319319
} = &mut self.row_aggr_state;
320320

321+
let mut accumulator_set_init_size = None;
321322
for (row, hash) in batch_hashes.into_iter().enumerate() {
322323
let entry = row_map.get_mut(hash, |(_hash, group_idx)| {
323324
// verify that a group that we are inserting with hash is
@@ -364,13 +365,15 @@ impl GroupedHashAggregateStream {
364365
+ (std::mem::size_of::<u32>() * group_state.indices.capacity());
365366

366367
// Allocation done by normal accumulators
367-
*allocated += (std::mem::size_of::<Box<dyn Accumulator>>()
368-
* group_state.accumulator_set.capacity())
369-
+ group_state
370-
.accumulator_set
371-
.iter()
372-
.map(|accu| accu.size())
373-
.sum::<usize>();
368+
*allocated += *accumulator_set_init_size.get_or_insert_with(|| {
369+
std::mem::size_of::<Box<dyn Accumulator>>()
370+
* group_state.accumulator_set.capacity()
371+
+ group_state
372+
.accumulator_set
373+
.iter()
374+
.map(|accu| accu.size())
375+
.sum::<usize>()
376+
});
374377

375378
// for hasher function, use precomputed hash value
376379
row_map.insert_accounted(
@@ -389,14 +392,23 @@ impl GroupedHashAggregateStream {
389392
}
390393

391394
// Update the accumulator results, according to row_aggr_state.
392-
fn update_accumulators(
395+
#[allow(clippy::too_many_arguments)]
396+
fn update_accumulators<F1, F2>(
393397
&mut self,
394398
groups_with_rows: &[usize],
395399
offsets: &[usize],
396400
row_values: &[Vec<ArrayRef>],
397401
normal_values: &[Vec<ArrayRef>],
402+
func_row: F1,
403+
func_normal: F2,
398404
allocated: &mut usize,
399-
) -> Result<()> {
405+
) -> Result<()>
406+
where
407+
F1: Fn(&mut RowAccumulatorItem, &mut RowAccessor, &[ArrayRef]) -> Result<()>,
408+
F2: Fn(&mut AccumulatorItem, &[ArrayRef]) -> Result<()>,
409+
{
410+
let accumulator_set_pre =
411+
get_accumulator_set_size(groups_with_rows, &self.row_aggr_state.group_states);
400412
// 2.1 for each key in this batch
401413
// 2.2 for each aggregation
402414
// 2.3 `slice` from each of its arrays the keys' values
@@ -428,15 +440,7 @@ impl GroupedHashAggregateStream {
428440
RowAccessor::new_from_layout(self.row_aggr_layout.clone());
429441
state_accessor
430442
.point_to(0, group_state.aggregation_buffer.as_mut_slice());
431-
match self.mode {
432-
AggregateMode::Partial => {
433-
accumulator.update_batch(&values, &mut state_accessor)
434-
}
435-
AggregateMode::FinalPartitioned | AggregateMode::Final => {
436-
// note: the aggregation here is over states, not values, thus the merge
437-
accumulator.merge_batch(&values, &mut state_accessor)
438-
}
439-
}
443+
func_row(accumulator, &mut state_accessor, &values)
440444
})
441445
// 2.5
442446
.and(Ok(()))?;
@@ -458,24 +462,17 @@ impl GroupedHashAggregateStream {
458462
)
459463
})
460464
.try_for_each(|(accumulator, values)| {
461-
let size_pre = accumulator.size();
462-
let res = match self.mode {
463-
AggregateMode::Partial => accumulator.update_batch(&values),
464-
AggregateMode::FinalPartitioned | AggregateMode::Final => {
465-
// note: the aggregation here is over states, not values, thus the merge
466-
accumulator.merge_batch(&values)
467-
}
468-
};
469-
let size_post = accumulator.size();
470-
*allocated += size_post.saturating_sub(size_pre);
471-
res
465+
func_normal(accumulator, &values)
472466
})
473467
// 2.5
474468
.and({
475469
group_state.indices.clear();
476470
Ok(())
477471
})
478472
})?;
473+
let accumulator_set_post =
474+
get_accumulator_set_size(groups_with_rows, &self.row_aggr_state.group_states);
475+
*allocated += accumulator_set_post.saturating_sub(accumulator_set_pre);
479476
Ok(())
480477
}
481478

@@ -517,13 +514,39 @@ impl GroupedHashAggregateStream {
517514
let row_values = get_at_indices(&row_aggr_input_values, &batch_indices)?;
518515
let normal_values =
519516
get_at_indices(&normal_aggr_input_values, &batch_indices)?;
520-
self.update_accumulators(
521-
&groups_with_rows,
522-
&offsets,
523-
&row_values,
524-
&normal_values,
525-
&mut allocated,
526-
)?;
517+
match self.mode {
518+
AggregateMode::Partial => self.update_accumulators(
519+
&groups_with_rows,
520+
&offsets,
521+
&row_values,
522+
&normal_values,
523+
|accumulator: &mut RowAccumulatorItem,
524+
state_accessor: &mut RowAccessor,
525+
values: &[ArrayRef]| {
526+
accumulator.update_batch(values, state_accessor)
527+
},
528+
|accumulator: &mut AccumulatorItem, values: &[ArrayRef]| {
529+
accumulator.update_batch(values)
530+
},
531+
&mut allocated,
532+
)?,
533+
AggregateMode::FinalPartitioned | AggregateMode::Final => self
534+
.update_accumulators(
535+
&groups_with_rows,
536+
&offsets,
537+
&row_values,
538+
&normal_values,
539+
|accumulator: &mut RowAccumulatorItem,
540+
state_accessor: &mut RowAccessor,
541+
values: &[ArrayRef]| {
542+
accumulator.merge_batch(values, state_accessor)
543+
},
544+
|accumulator: &mut AccumulatorItem, values: &[ArrayRef]| {
545+
accumulator.merge_batch(values)
546+
},
547+
&mut allocated,
548+
)?,
549+
};
527550
}
528551
allocated += self
529552
.row_converter
@@ -533,6 +556,19 @@ impl GroupedHashAggregateStream {
533556
}
534557
}
535558

559+
fn get_accumulator_set_size(
560+
groups_with_rows: &[usize],
561+
row_group_states: &[RowGroupState],
562+
) -> usize {
563+
groups_with_rows.iter().fold(0usize, |acc, group_idx| {
564+
let group_state = &row_group_states[*group_idx];
565+
group_state
566+
.accumulator_set
567+
.iter()
568+
.fold(acc, |acc, accumulator| acc + accumulator.size())
569+
})
570+
}
571+
536572
/// The state that is built for each output group.
537573
#[derive(Debug)]
538574
pub struct RowGroupState {

datafusion/physical-expr/src/aggregate/average.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -242,7 +242,7 @@ impl Accumulator for AvgAccumulator {
242242
}
243243

244244
fn size(&self) -> usize {
245-
std::mem::size_of_val(self) - std::mem::size_of_val(&self.sum) + self.sum.size()
245+
std::mem::size_of_val(self)
246246
}
247247
}
248248

datafusion/physical-expr/src/aggregate/sum.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -289,7 +289,7 @@ impl Accumulator for SumAccumulator {
289289
}
290290

291291
fn size(&self) -> usize {
292-
std::mem::size_of_val(self) - std::mem::size_of_val(&self.sum) + self.sum.size()
292+
std::mem::size_of_val(self)
293293
}
294294
}
295295

0 commit comments

Comments
 (0)