@@ -654,9 +654,13 @@ impl Stream for GroupedHashAggregateStream {
654654 }
655655
656656 if let Some ( to_emit) = self . group_ordering . emit_to ( ) {
657- let batch = extract_ok ! ( self . emit( to_emit, false ) ) ;
658- self . exec_state = ExecutionState :: ProducingOutput ( batch) ;
659657 timer. done ( ) ;
658+ if let Some ( batch) =
659+ extract_ok ! ( self . emit( to_emit, false ) )
660+ {
661+ self . exec_state =
662+ ExecutionState :: ProducingOutput ( batch) ;
663+ } ;
660664 // make sure the exec_state just set is not overwritten below
661665 break ' reading_input;
662666 }
@@ -693,9 +697,13 @@ impl Stream for GroupedHashAggregateStream {
693697 }
694698
695699 if let Some ( to_emit) = self . group_ordering . emit_to ( ) {
696- let batch = extract_ok ! ( self . emit( to_emit, false ) ) ;
697- self . exec_state = ExecutionState :: ProducingOutput ( batch) ;
698700 timer. done ( ) ;
701+ if let Some ( batch) =
702+ extract_ok ! ( self . emit( to_emit, false ) )
703+ {
704+ self . exec_state =
705+ ExecutionState :: ProducingOutput ( batch) ;
706+ } ;
699707 // make sure the exec_state just set is not overwritten below
700708 break ' reading_input;
701709 }
@@ -768,6 +776,9 @@ impl Stream for GroupedHashAggregateStream {
768776 let output = batch. slice ( 0 , size) ;
769777 ( ExecutionState :: ProducingOutput ( remaining) , output)
770778 } ;
779+ // Empty record batches should not be emitted.
780+ // They need to be represented as `None` of an [`Option<RecordBatch>`] and handled separately
781+ debug_assert ! ( output_batch. num_rows( ) > 0 ) ;
771782 return Poll :: Ready ( Some ( Ok (
772783 output_batch. record_output ( & self . baseline_metrics )
773784 ) ) ) ;
@@ -902,14 +913,14 @@ impl GroupedHashAggregateStream {
902913
903914 /// Create an output RecordBatch with the group keys and
904915 /// accumulator states/values specified in emit_to
905- fn emit ( & mut self , emit_to : EmitTo , spilling : bool ) -> Result < RecordBatch > {
916+ fn emit ( & mut self , emit_to : EmitTo , spilling : bool ) -> Result < Option < RecordBatch > > {
906917 let schema = if spilling {
907918 Arc :: clone ( & self . spill_state . spill_schema )
908919 } else {
909920 self . schema ( )
910921 } ;
911922 if self . group_values . is_empty ( ) {
912- return Ok ( RecordBatch :: new_empty ( schema ) ) ;
923+ return Ok ( None ) ;
913924 }
914925
915926 let mut output = self . group_values . emit ( emit_to) ?;
@@ -937,7 +948,8 @@ impl GroupedHashAggregateStream {
937948 // over the target memory size after emission, we can emit again rather than returning Err.
938949 let _ = self . update_memory_reservation ( ) ;
939950 let batch = RecordBatch :: try_new ( schema, output) ?;
940- Ok ( batch)
951+ debug_assert ! ( batch. num_rows( ) > 0 ) ;
952+ Ok ( Some ( batch) )
941953 }
942954
942954 /// Optimistically, [`Self::group_aggregate_batch`] allows the memory target to be exceeded slightly
@@ -963,7 +975,9 @@ impl GroupedHashAggregateStream {
963975
964976 /// Emit all rows, sort them, and store them on disk.
965977 fn spill ( & mut self ) -> Result < ( ) > {
966- let emit = self . emit ( EmitTo :: All , true ) ?;
978+ let Some ( emit) = self . emit ( EmitTo :: All , true ) ? else {
979+ return Ok ( ( ) ) ;
980+ } ;
967981 let sorted = sort_batch ( & emit, self . spill_state . spill_expr . as_ref ( ) , None ) ?;
968982 let spillfile = self . runtime . disk_manager . create_tmp_file ( "HashAggSpill" ) ?;
969983 // TODO: slice large `sorted` and write to multiple files in parallel
@@ -1008,8 +1022,9 @@ impl GroupedHashAggregateStream {
10081022 {
10091023 assert_eq ! ( self . mode, AggregateMode :: Partial ) ;
10101024 let n = self . group_values . len ( ) / self . batch_size * self . batch_size ;
1011- let batch = self . emit ( EmitTo :: First ( n) , false ) ?;
1012- self . exec_state = ExecutionState :: ProducingOutput ( batch) ;
1025+ if let Some ( batch) = self . emit ( EmitTo :: First ( n) , false ) ? {
1026+ self . exec_state = ExecutionState :: ProducingOutput ( batch) ;
1027+ } ;
10131028 }
10141029 Ok ( ( ) )
10151030 }
@@ -1019,7 +1034,9 @@ impl GroupedHashAggregateStream {
10191034 /// Conduct a streaming merge sort between the batch and spilled data. Since the stream is fully
10201035 /// sorted, set `self.group_ordering` to Full, then later we can read with [`EmitTo::First`].
10211036 fn update_merged_stream ( & mut self ) -> Result < ( ) > {
1022- let batch = self . emit ( EmitTo :: All , true ) ?;
1037+ let Some ( batch) = self . emit ( EmitTo :: All , true ) ? else {
1038+ return Ok ( ( ) ) ;
1039+ } ;
10231040 // clear up memory for streaming_merge
10241041 self . clear_all ( ) ;
10251042 self . update_memory_reservation ( ) ?;
@@ -1067,7 +1084,7 @@ impl GroupedHashAggregateStream {
10671084 let timer = elapsed_compute. timer ( ) ;
10681085 self . exec_state = if self . spill_state . spills . is_empty ( ) {
10691086 let batch = self . emit ( EmitTo :: All , false ) ?;
1070- ExecutionState :: ProducingOutput ( batch )
1087+ batch . map_or ( ExecutionState :: Done , ExecutionState :: ProducingOutput )
10711088 } else {
10721089 // If spill files exist, stream-merge them.
10731090 self . update_merged_stream ( ) ?;
@@ -1096,8 +1113,9 @@ impl GroupedHashAggregateStream {
10961113 fn switch_to_skip_aggregation ( & mut self ) -> Result < ( ) > {
10971114 if let Some ( probe) = self . skip_aggregation_probe . as_mut ( ) {
10981115 if probe. should_skip ( ) {
1099- let batch = self . emit ( EmitTo :: All , false ) ?;
1100- self . exec_state = ExecutionState :: ProducingOutput ( batch) ;
1116+ if let Some ( batch) = self . emit ( EmitTo :: All , false ) ? {
1117+ self . exec_state = ExecutionState :: ProducingOutput ( batch) ;
1118+ } ;
11011119 }
11021120 }
11031121
0 commit comments