@@ -21,10 +21,6 @@ use std::any::Any;
2121use std:: sync:: Arc ;
2222
2323use super :: { DisplayAs , ExecutionPlanProperties , PlanProperties } ;
24- use crate :: aggregates:: {
25- no_grouping:: AggregateStream , row_hash:: GroupedHashAggregateStream ,
26- topk_stream:: GroupedTopKAggregateStream ,
27- } ;
2824use crate :: execution_plan:: { CardinalityEffect , EmissionType } ;
2925use crate :: metrics:: { ExecutionPlanMetricsSet , MetricsSet } ;
3026use crate :: windows:: get_ordered_partition_by_indices;
@@ -358,21 +354,10 @@ impl PartialEq for PhysicalGroupBy {
358354 }
359355}
360356
361- #[ allow( clippy:: large_enum_variant) ]
362357enum StreamType {
363- AggregateStream ( AggregateStream ) ,
364- GroupedHash ( GroupedHashAggregateStream ) ,
365- GroupedPriorityQueue ( GroupedTopKAggregateStream ) ,
366- }
367-
368- impl From < StreamType > for SendableRecordBatchStream {
369- fn from ( stream : StreamType ) -> Self {
370- match stream {
371- StreamType :: AggregateStream ( stream) => Box :: pin ( stream) ,
372- StreamType :: GroupedHash ( stream) => Box :: pin ( stream) ,
373- StreamType :: GroupedPriorityQueue ( stream) => Box :: pin ( stream) ,
374- }
375- }
358+ AggregateStream ( SendableRecordBatchStream ) ,
359+ GroupedHash ( SendableRecordBatchStream ) ,
360+ GroupedPriorityQueue ( SendableRecordBatchStream ) ,
376361}
377362
378363/// Hash aggregate execution plan
@@ -608,7 +593,7 @@ impl AggregateExec {
608593 ) -> Result < StreamType > {
609594 // no group by at all
610595 if self . group_by . expr . is_empty ( ) {
611- return Ok ( StreamType :: AggregateStream ( AggregateStream :: new (
596+ return Ok ( StreamType :: AggregateStream ( no_grouping :: aggregate_stream (
612597 self , context, partition,
613598 ) ?) ) ;
614599 }
@@ -617,13 +602,13 @@ impl AggregateExec {
617602 if let Some ( limit) = self . limit {
618603 if !self . is_unordered_unfiltered_group_by_distinct ( ) {
619604 return Ok ( StreamType :: GroupedPriorityQueue (
620- GroupedTopKAggregateStream :: new ( self , context, partition, limit) ?,
605+ topk_stream :: aggregate_stream ( self , context, partition, limit) ?,
621606 ) ) ;
622607 }
623608 }
624609
625610 // grouping by something else and we need to just materialize all results
626- Ok ( StreamType :: GroupedHash ( GroupedHashAggregateStream :: new (
611+ Ok ( StreamType :: GroupedHash ( row_hash :: aggregate_stream (
627612 self , context, partition,
628613 ) ?) )
629614 }
@@ -998,8 +983,11 @@ impl ExecutionPlan for AggregateExec {
998983 partition : usize ,
999984 context : Arc < TaskContext > ,
1000985 ) -> Result < SendableRecordBatchStream > {
1001- self . execute_typed ( partition, context)
1002- . map ( |stream| stream. into ( ) )
986+ match self . execute_typed ( partition, context) ? {
987+ StreamType :: AggregateStream ( s) => Ok ( s) ,
988+ StreamType :: GroupedHash ( s) => Ok ( s) ,
989+ StreamType :: GroupedPriorityQueue ( s) => Ok ( s) ,
990+ }
1003991 }
1004992
1005993 fn metrics ( & self ) -> Option < MetricsSet > {
@@ -1274,7 +1262,7 @@ pub fn create_accumulators(
12741262/// final value (mode = Final, FinalPartitioned and Single) or states (mode = Partial)
12751263pub fn finalize_aggregation (
12761264 accumulators : & mut [ AccumulatorItem ] ,
1277- mode : & AggregateMode ,
1265+ mode : AggregateMode ,
12781266) -> Result < Vec < ArrayRef > > {
12791267 match mode {
12801268 AggregateMode :: Partial => {
@@ -2105,20 +2093,20 @@ mod tests {
21052093 let stream = partial_aggregate. execute_typed ( 0 , Arc :: clone ( & task_ctx) ) ?;
21062094
21072095 // ensure that we really got the version we wanted
2108- match version {
2109- 0 => {
2110- assert ! ( matches!( stream, StreamType :: AggregateStream ( _) ) ) ;
2096+ let stream = match stream {
2097+ StreamType :: AggregateStream ( s) => {
2098+ assert_eq ! ( version, 0 ) ;
2099+ s
21112100 }
2112- 1 => {
2113- assert ! ( matches!( stream, StreamType :: GroupedHash ( _) ) ) ;
2101+ StreamType :: GroupedHash ( s) => {
2102+ assert ! ( version == 1 || version == 2 ) ;
2103+ s
21142104 }
2115- 2 => {
2116- assert ! ( matches! ( stream, StreamType :: GroupedHash ( _ ) ) ) ;
2105+ StreamType :: GroupedPriorityQueue ( _ ) => {
2106+ panic ! ( "Unexpected GroupedPriorityQueue stream type" ) ;
21172107 }
2118- _ => panic ! ( "Unknown version: {version}" ) ,
2119- }
2108+ } ;
21202109
2121- let stream: SendableRecordBatchStream = stream. into ( ) ;
21222110 let err = collect ( stream) . await . unwrap_err ( ) ;
21232111
21242112 // error root cause traversal is a bit complicated, see #4172.
0 commit comments