@@ -30,7 +30,7 @@ use crate::aggregates::{
 use crate::metrics::{BaselineMetrics, MetricBuilder, RecordOutput};
 use crate::sorts::sort::sort_batch;
 use crate::sorts::streaming_merge::StreamingMergeBuilder;
-use crate::spill::{read_spill_as_stream, spill_record_batch_by_size};
+use crate::spill::spill_manager::SpillManager;
 use crate::stream::RecordBatchStreamAdapter;
 use crate::{aggregates, metrics, ExecutionPlan, PhysicalExpr};
 use crate::{RecordBatchStream, SendableRecordBatchStream};
@@ -42,7 +42,6 @@ use datafusion_common::{internal_err, DataFusionError, Result};
 use datafusion_execution::disk_manager::RefCountedTempFile;
 use datafusion_execution::memory_pool::proxy::VecAllocExt;
 use datafusion_execution::memory_pool::{MemoryConsumer, MemoryReservation};
-use datafusion_execution::runtime_env::RuntimeEnv;
 use datafusion_execution::TaskContext;
 use datafusion_expr::{EmitTo, GroupsAccumulator};
 use datafusion_physical_expr::expressions::Column;
@@ -91,6 +90,9 @@ struct SpillState {
     /// GROUP BY expressions for merging spilled data
     merging_group_by: PhysicalGroupBy,
 
+    /// Manages the process of spilling and reading back intermediate data
+    spill_manager: SpillManager,
+
     // ========================================================================
     // STATES:
     // Fields changes during execution. Can be buffer, or state flags that
@@ -109,12 +111,7 @@ struct SpillState {
     /// Peak memory used for buffered data.
     /// Calculated as sum of peak memory values across partitions
     peak_mem_used: metrics::Gauge,
-    /// count of spill files during the execution of the operator
-    spill_count: metrics::Count,
-    /// total spilled bytes during the execution of the operator
-    spilled_bytes: metrics::Count,
-    /// total spilled rows during the execution of the operator
-    spilled_rows: metrics::Count,
+    // Metrics related to spilling are managed inside `spill_manager`
 }
 
 /// Tracks if the aggregate should skip partial aggregations
@@ -435,9 +432,6 @@ pub(crate) struct GroupedHashAggregateStream {
 
     /// Execution metrics
     baseline_metrics: BaselineMetrics,
-
-    /// The [`RuntimeEnv`] associated with the [`TaskContext`] argument
-    runtime: Arc<RuntimeEnv>,
 }
 
 impl GroupedHashAggregateStream {
@@ -544,6 +538,12 @@ impl GroupedHashAggregateStream {
 
         let exec_state = ExecutionState::ReadingInput;
 
+        let spill_manager = SpillManager::new(
+            context.runtime_env(),
+            metrics::SpillMetrics::new(&agg.metrics, partition),
+            Arc::clone(&partial_agg_schema),
+        );
+
         let spill_state = SpillState {
             spills: vec![],
             spill_expr,
@@ -553,9 +553,7 @@ impl GroupedHashAggregateStream {
             merging_group_by: PhysicalGroupBy::new_single(agg_group_by.expr.clone()),
             peak_mem_used: MetricBuilder::new(&agg.metrics)
                 .gauge("peak_mem_used", partition),
-            spill_count: MetricBuilder::new(&agg.metrics).spill_count(partition),
-            spilled_bytes: MetricBuilder::new(&agg.metrics).spilled_bytes(partition),
-            spilled_rows: MetricBuilder::new(&agg.metrics).spilled_rows(partition),
+            spill_manager,
         };
 
         // Skip aggregation is supported if:
@@ -604,7 +602,6 @@ impl GroupedHashAggregateStream {
             batch_size,
             group_ordering,
             input_done: false,
-            runtime: context.runtime_env(),
             spill_state,
             group_values_soft_limit: agg.limit,
             skip_aggregation_probe,
@@ -981,28 +978,30 @@ impl GroupedHashAggregateStream {
         Ok(())
     }
 
-    /// Emit all rows, sort them, and store them on disk.
+    /// Emit all intermediate aggregation states, sort them, and store them on disk.
+    /// This process helps in reducing memory pressure by allowing the data to be
+    /// read back with streaming merge.
     fn spill(&mut self) -> Result<()> {
+        // Emit and sort intermediate aggregation state
         let Some(emit) = self.emit(EmitTo::All, true)? else {
             return Ok(());
         };
         let sorted = sort_batch(&emit, self.spill_state.spill_expr.as_ref(), None)?;
-        let spillfile = self.runtime.disk_manager.create_tmp_file("HashAggSpill")?;
-        // TODO: slice large `sorted` and write to multiple files in parallel
-        spill_record_batch_by_size(
+
+        // Spill sorted state to disk
+        let spillfile = self.spill_state.spill_manager.spill_record_batch_by_size(
             &sorted,
-            spillfile.path().into(),
-            sorted.schema(),
+            "HashAggSpill",
             self.batch_size,
         )?;
-        self.spill_state.spills.push(spillfile);
-
-        // Update metrics
-        self.spill_state.spill_count.add(1);
-        self.spill_state
-            .spilled_bytes
-            .add(sorted.get_array_memory_size());
-        self.spill_state.spilled_rows.add(sorted.num_rows());
+        match spillfile {
+            Some(spillfile) => self.spill_state.spills.push(spillfile),
+            None => {
+                return internal_err!(
+                    "Calling spill with no intermediate batch to spill"
+                );
+            }
+        }
 
         Ok(())
     }
@@ -1058,7 +1057,7 @@ impl GroupedHashAggregateStream {
             })),
         )));
         for spill in self.spill_state.spills.drain(..) {
-            let stream = read_spill_as_stream(spill, Arc::clone(&schema), 2)?;
+            let stream = self.spill_state.spill_manager.read_spill_as_stream(spill)?;
             streams.push(stream);
         }
         self.spill_state.is_stream_merging = true;
0 commit comments