@@ -318,6 +318,7 @@ impl GroupedHashAggregateStream {
318
318
..
319
319
} = & mut self . row_aggr_state ;
320
320
321
+ let mut accumulator_set_init_size = None ;
321
322
for ( row, hash) in batch_hashes. into_iter ( ) . enumerate ( ) {
322
323
let entry = row_map. get_mut ( hash, |( _hash, group_idx) | {
323
324
// verify that a group that we are inserting with hash is
@@ -364,13 +365,15 @@ impl GroupedHashAggregateStream {
364
365
+ ( std:: mem:: size_of :: < u32 > ( ) * group_state. indices . capacity ( ) ) ;
365
366
366
367
// Allocation done by normal accumulators
367
- * allocated += ( std:: mem:: size_of :: < Box < dyn Accumulator > > ( )
368
- * group_state. accumulator_set . capacity ( ) )
369
- + group_state
370
- . accumulator_set
371
- . iter ( )
372
- . map ( |accu| accu. size ( ) )
373
- . sum :: < usize > ( ) ;
368
+ * allocated += * accumulator_set_init_size. get_or_insert_with ( || {
369
+ std:: mem:: size_of :: < Box < dyn Accumulator > > ( )
370
+ * group_state. accumulator_set . capacity ( )
371
+ + group_state
372
+ . accumulator_set
373
+ . iter ( )
374
+ . map ( |accu| accu. size ( ) )
375
+ . sum :: < usize > ( )
376
+ } ) ;
374
377
375
378
// for hasher function, use precomputed hash value
376
379
row_map. insert_accounted (
@@ -389,14 +392,23 @@ impl GroupedHashAggregateStream {
389
392
}
390
393
391
394
// Update the accumulator results, according to row_aggr_state.
392
- fn update_accumulators (
395
+ #[ allow( clippy:: too_many_arguments) ]
396
+ fn update_accumulators < F1 , F2 > (
393
397
& mut self ,
394
398
groups_with_rows : & [ usize ] ,
395
399
offsets : & [ usize ] ,
396
400
row_values : & [ Vec < ArrayRef > ] ,
397
401
normal_values : & [ Vec < ArrayRef > ] ,
402
+ func_row : F1 ,
403
+ func_normal : F2 ,
398
404
allocated : & mut usize ,
399
- ) -> Result < ( ) > {
405
+ ) -> Result < ( ) >
406
+ where
407
+ F1 : Fn ( & mut RowAccumulatorItem , & mut RowAccessor , & [ ArrayRef ] ) -> Result < ( ) > ,
408
+ F2 : Fn ( & mut AccumulatorItem , & [ ArrayRef ] ) -> Result < ( ) > ,
409
+ {
410
+ let accumulator_set_pre =
411
+ get_accumulator_set_size ( groups_with_rows, & self . row_aggr_state . group_states ) ;
400
412
// 2.1 for each key in this batch
401
413
// 2.2 for each aggregation
402
414
// 2.3 `slice` from each of its arrays the keys' values
@@ -428,15 +440,7 @@ impl GroupedHashAggregateStream {
428
440
RowAccessor :: new_from_layout ( self . row_aggr_layout . clone ( ) ) ;
429
441
state_accessor
430
442
. point_to ( 0 , group_state. aggregation_buffer . as_mut_slice ( ) ) ;
431
- match self . mode {
432
- AggregateMode :: Partial => {
433
- accumulator. update_batch ( & values, & mut state_accessor)
434
- }
435
- AggregateMode :: FinalPartitioned | AggregateMode :: Final => {
436
- // note: the aggregation here is over states, not values, thus the merge
437
- accumulator. merge_batch ( & values, & mut state_accessor)
438
- }
439
- }
443
+ func_row ( accumulator, & mut state_accessor, & values)
440
444
} )
441
445
// 2.5
442
446
. and ( Ok ( ( ) ) ) ?;
@@ -458,24 +462,17 @@ impl GroupedHashAggregateStream {
458
462
)
459
463
} )
460
464
. try_for_each ( |( accumulator, values) | {
461
- let size_pre = accumulator. size ( ) ;
462
- let res = match self . mode {
463
- AggregateMode :: Partial => accumulator. update_batch ( & values) ,
464
- AggregateMode :: FinalPartitioned | AggregateMode :: Final => {
465
- // note: the aggregation here is over states, not values, thus the merge
466
- accumulator. merge_batch ( & values)
467
- }
468
- } ;
469
- let size_post = accumulator. size ( ) ;
470
- * allocated += size_post. saturating_sub ( size_pre) ;
471
- res
465
+ func_normal ( accumulator, & values)
472
466
} )
473
467
// 2.5
474
468
. and ( {
475
469
group_state. indices . clear ( ) ;
476
470
Ok ( ( ) )
477
471
} )
478
472
} ) ?;
473
+ let accumulator_set_post =
474
+ get_accumulator_set_size ( groups_with_rows, & self . row_aggr_state . group_states ) ;
475
+ * allocated += accumulator_set_post. saturating_sub ( accumulator_set_pre) ;
479
476
Ok ( ( ) )
480
477
}
481
478
@@ -517,13 +514,39 @@ impl GroupedHashAggregateStream {
517
514
let row_values = get_at_indices ( & row_aggr_input_values, & batch_indices) ?;
518
515
let normal_values =
519
516
get_at_indices ( & normal_aggr_input_values, & batch_indices) ?;
520
- self . update_accumulators (
521
- & groups_with_rows,
522
- & offsets,
523
- & row_values,
524
- & normal_values,
525
- & mut allocated,
526
- ) ?;
517
+ match self . mode {
518
+ AggregateMode :: Partial => self . update_accumulators (
519
+ & groups_with_rows,
520
+ & offsets,
521
+ & row_values,
522
+ & normal_values,
523
+ |accumulator : & mut RowAccumulatorItem ,
524
+ state_accessor : & mut RowAccessor ,
525
+ values : & [ ArrayRef ] | {
526
+ accumulator. update_batch ( values, state_accessor)
527
+ } ,
528
+ |accumulator : & mut AccumulatorItem , values : & [ ArrayRef ] | {
529
+ accumulator. update_batch ( values)
530
+ } ,
531
+ & mut allocated,
532
+ ) ?,
533
+ AggregateMode :: FinalPartitioned | AggregateMode :: Final => self
534
+ . update_accumulators (
535
+ & groups_with_rows,
536
+ & offsets,
537
+ & row_values,
538
+ & normal_values,
539
+ |accumulator : & mut RowAccumulatorItem ,
540
+ state_accessor : & mut RowAccessor ,
541
+ values : & [ ArrayRef ] | {
542
+ accumulator. merge_batch ( values, state_accessor)
543
+ } ,
544
+ |accumulator : & mut AccumulatorItem , values : & [ ArrayRef ] | {
545
+ accumulator. merge_batch ( values)
546
+ } ,
547
+ & mut allocated,
548
+ ) ?,
549
+ } ;
527
550
}
528
551
allocated += self
529
552
. row_converter
@@ -533,6 +556,19 @@ impl GroupedHashAggregateStream {
533
556
}
534
557
}
535
558
559
+ fn get_accumulator_set_size (
560
+ groups_with_rows : & [ usize ] ,
561
+ row_group_states : & [ RowGroupState ] ,
562
+ ) -> usize {
563
+ groups_with_rows. iter ( ) . fold ( 0usize , |acc, group_idx| {
564
+ let group_state = & row_group_states[ * group_idx] ;
565
+ group_state
566
+ . accumulator_set
567
+ . iter ( )
568
+ . fold ( acc, |acc, accumulator| acc + accumulator. size ( ) )
569
+ } )
570
+ }
571
+
536
572
/// The state that is built for each output group.
537
573
#[ derive( Debug ) ]
538
574
pub struct RowGroupState {
0 commit comments