@@ -20,7 +20,11 @@ use std::fmt::{Debug, Formatter};
2020use std:: mem:: { size_of, size_of_val} ;
2121use std:: sync:: Arc ;
2222
23- use arrow:: array:: { downcast_integer, ArrowNumericType } ;
23+ use arrow:: array:: {
24+ downcast_integer, ArrowNumericType , BooleanArray , ListArray , PrimitiveArray ,
25+ PrimitiveBuilder ,
26+ } ;
27+ use arrow:: buffer:: { OffsetBuffer , ScalarBuffer } ;
2428use arrow:: {
2529 array:: { ArrayRef , AsArray } ,
2630 datatypes:: {
@@ -33,12 +37,17 @@ use arrow::array::Array;
3337use arrow:: array:: ArrowNativeTypeOp ;
3438use arrow:: datatypes:: { ArrowNativeType , ArrowPrimitiveType } ;
3539
36- use datafusion_common:: { DataFusionError , HashSet , Result , ScalarValue } ;
40+ use datafusion_common:: {
41+ internal_datafusion_err, internal_err, DataFusionError , HashSet , Result , ScalarValue ,
42+ } ;
3743use datafusion_expr:: function:: StateFieldsArgs ;
3844use datafusion_expr:: {
3945 function:: AccumulatorArgs , utils:: format_state_name, Accumulator , AggregateUDFImpl ,
4046 Documentation , Signature , Volatility ,
4147} ;
48+ use datafusion_expr:: { EmitTo , GroupsAccumulator } ;
49+ use datafusion_functions_aggregate_common:: aggregate:: groups_accumulator:: accumulate:: accumulate;
50+ use datafusion_functions_aggregate_common:: aggregate:: groups_accumulator:: nulls:: filtered_null_mask;
4251use datafusion_functions_aggregate_common:: utils:: Hashable ;
4352use datafusion_macros:: user_doc;
4453
@@ -165,6 +174,45 @@ impl AggregateUDFImpl for Median {
165174 }
166175 }
167176
177+ fn groups_accumulator_supported ( & self , args : AccumulatorArgs ) -> bool {
178+ !args. is_distinct
179+ }
180+
181+ fn create_groups_accumulator (
182+ & self ,
183+ args : AccumulatorArgs ,
184+ ) -> Result < Box < dyn GroupsAccumulator > > {
185+ let num_args = args. exprs . len ( ) ;
186+ if num_args != 1 {
187+ return internal_err ! (
188+ "median should only have 1 arg, but found num args:{}" ,
189+ args. exprs. len( )
190+ ) ;
191+ }
192+
193+ let dt = args. exprs [ 0 ] . data_type ( args. schema ) ?;
194+
195+ macro_rules! helper {
196+ ( $t: ty, $dt: expr) => {
197+ Ok ( Box :: new( MedianGroupsAccumulator :: <$t>:: new( $dt) ) )
198+ } ;
199+ }
200+
201+ downcast_integer ! {
202+ dt => ( helper, dt) ,
203+ DataType :: Float16 => helper!( Float16Type , dt) ,
204+ DataType :: Float32 => helper!( Float32Type , dt) ,
205+ DataType :: Float64 => helper!( Float64Type , dt) ,
206+ DataType :: Decimal128 ( _, _) => helper!( Decimal128Type , dt) ,
207+ DataType :: Decimal256 ( _, _) => helper!( Decimal256Type , dt) ,
208+ _ => Err ( DataFusionError :: NotImplemented ( format!(
209+ "MedianGroupsAccumulator not supported for {} with {}" ,
210+ args. name,
211+ dt,
212+ ) ) ) ,
213+ }
214+ }
215+
168216 fn aliases ( & self ) -> & [ String ] {
169217 & [ ]
170218 }
@@ -230,6 +278,216 @@ impl<T: ArrowNumericType> Accumulator for MedianAccumulator<T> {
230278 }
231279}
232280
281+ /// The median groups accumulator accumulates the raw input values
282+ ///
283+ /// For calculating the accurate medians of groups, we need to store all values
284+ /// of groups before final evaluation.
285+ /// So values in each group will be stored in a `Vec<T>`, and the total group values
286+ /// will be actually organized as a `Vec<Vec<T>>`.
287+ ///
288+ #[ derive( Debug ) ]
289+ struct MedianGroupsAccumulator < T : ArrowNumericType + Send > {
290+ data_type : DataType ,
291+ group_values : Vec < Vec < T :: Native > > ,
292+ }
293+
294+ impl < T : ArrowNumericType + Send > MedianGroupsAccumulator < T > {
295+ pub fn new ( data_type : DataType ) -> Self {
296+ Self {
297+ data_type,
298+ group_values : Vec :: new ( ) ,
299+ }
300+ }
301+ }
302+
303+ impl < T : ArrowNumericType + Send > GroupsAccumulator for MedianGroupsAccumulator < T > {
304+ fn update_batch (
305+ & mut self ,
306+ values : & [ ArrayRef ] ,
307+ group_indices : & [ usize ] ,
308+ opt_filter : Option < & BooleanArray > ,
309+ total_num_groups : usize ,
310+ ) -> Result < ( ) > {
311+ assert_eq ! ( values. len( ) , 1 , "single argument to update_batch" ) ;
312+ let values = values[ 0 ] . as_primitive :: < T > ( ) ;
313+
314+ // Push the `not nulls + not filtered` row into its group
315+ self . group_values . resize ( total_num_groups, Vec :: new ( ) ) ;
316+ accumulate (
317+ group_indices,
318+ values,
319+ opt_filter,
320+ |group_index, new_value| {
321+ self . group_values [ group_index] . push ( new_value) ;
322+ } ,
323+ ) ;
324+
325+ Ok ( ( ) )
326+ }
327+
328+ fn merge_batch (
329+ & mut self ,
330+ values : & [ ArrayRef ] ,
331+ group_indices : & [ usize ] ,
332+ // Since aggregate filter should be applied in partial stage, in final stage there should be no filter
333+ _opt_filter : Option < & BooleanArray > ,
334+ total_num_groups : usize ,
335+ ) -> Result < ( ) > {
336+ assert_eq ! ( values. len( ) , 1 , "one argument to merge_batch" ) ;
337+
338+ // The merged values should be organized like as a `ListArray` which is nullable
339+ // (input with nulls usually generated from `convert_to_state`), but `inner array` of
340+ // `ListArray` is `non-nullable`.
341+ //
342+ // Following is the possible and impossible input `values`:
343+ //
344+ // # Possible values
345+ // ```text
346+ // group 0: [1, 2, 3]
347+ // group 1: null (list array is nullable)
348+ // group 2: [6, 7, 8]
349+ // ...
350+ // group n: [...]
351+ // ```
352+ //
353+ // # Impossible values
354+ // ```text
355+ // group x: [1, 2, null] (values in list array is non-nullable)
356+ // ```
357+ //
358+ let input_group_values = values[ 0 ] . as_list :: < i32 > ( ) ;
359+
360+ // Ensure group values big enough
361+ self . group_values . resize ( total_num_groups, Vec :: new ( ) ) ;
362+
363+ // Extend values to related groups
364+ // TODO: avoid using iterator of the `ListArray`, this will lead to
365+ // many calls of `slice` of its ``inner array`, and `slice` is not
366+ // so efficient(due to the calculation of `null_count` for each `slice`).
367+ group_indices
368+ . iter ( )
369+ . zip ( input_group_values. iter ( ) )
370+ . for_each ( |( & group_index, values_opt) | {
371+ if let Some ( values) = values_opt {
372+ let values = values. as_primitive :: < T > ( ) ;
373+ self . group_values [ group_index] . extend ( values. values ( ) . iter ( ) ) ;
374+ }
375+ } ) ;
376+
377+ Ok ( ( ) )
378+ }
379+
380+ fn state ( & mut self , emit_to : EmitTo ) -> Result < Vec < ArrayRef > > {
381+ // Emit values
382+ let emit_group_values = emit_to. take_needed ( & mut self . group_values ) ;
383+
384+ // Build offsets
385+ let mut offsets = Vec :: with_capacity ( self . group_values . len ( ) + 1 ) ;
386+ offsets. push ( 0 ) ;
387+ let mut cur_len = 0_i32 ;
388+ for group_value in & emit_group_values {
389+ cur_len += group_value. len ( ) as i32 ;
390+ offsets. push ( cur_len) ;
391+ }
392+ // TODO: maybe we can use `OffsetBuffer::new_unchecked` like what in `convert_to_state`,
393+ // but safety should be considered more carefully here(and I am not sure if it can get
394+ // performance improvement when we introduce checks to keep the safety...).
395+ //
396+ // Can see more details in:
397+ // https://github.com/apache/datafusion/pull/13681#discussion_r1931209791
398+ //
399+ let offsets = OffsetBuffer :: new ( ScalarBuffer :: from ( offsets) ) ;
400+
401+ // Build inner array
402+ let flatten_group_values =
403+ emit_group_values. into_iter ( ) . flatten ( ) . collect :: < Vec < _ > > ( ) ;
404+ let group_values_array =
405+ PrimitiveArray :: < T > :: new ( ScalarBuffer :: from ( flatten_group_values) , None )
406+ . with_data_type ( self . data_type . clone ( ) ) ;
407+
408+ // Build the result list array
409+ let result_list_array = ListArray :: new (
410+ Arc :: new ( Field :: new_list_field ( self . data_type . clone ( ) , true ) ) ,
411+ offsets,
412+ Arc :: new ( group_values_array) ,
413+ None ,
414+ ) ;
415+
416+ Ok ( vec ! [ Arc :: new( result_list_array) ] )
417+ }
418+
419+ fn evaluate ( & mut self , emit_to : EmitTo ) -> Result < ArrayRef > {
420+ // Emit values
421+ let emit_group_values = emit_to. take_needed ( & mut self . group_values ) ;
422+
423+ // Calculate median for each group
424+ let mut evaluate_result_builder =
425+ PrimitiveBuilder :: < T > :: new ( ) . with_data_type ( self . data_type . clone ( ) ) ;
426+ for values in emit_group_values {
427+ let median = calculate_median :: < T > ( values) ;
428+ evaluate_result_builder. append_option ( median) ;
429+ }
430+
431+ Ok ( Arc :: new ( evaluate_result_builder. finish ( ) ) )
432+ }
433+
434+ fn convert_to_state (
435+ & self ,
436+ values : & [ ArrayRef ] ,
437+ opt_filter : Option < & BooleanArray > ,
438+ ) -> Result < Vec < ArrayRef > > {
439+ assert_eq ! ( values. len( ) , 1 , "one argument to merge_batch" ) ;
440+
441+ let input_array = values[ 0 ] . as_primitive :: < T > ( ) ;
442+
443+ // Directly convert the input array to states, each row will be
444+ // seen as a respective group.
445+ // For detail, the `input_array` will be converted to a `ListArray`.
446+ // And if row is `not null + not filtered`, it will be converted to a list
447+ // with only one element; otherwise, this row in `ListArray` will be set
448+ // to null.
449+
450+ // Reuse values buffer in `input_array` to build `values` in `ListArray`
451+ let values = PrimitiveArray :: < T > :: new ( input_array. values ( ) . clone ( ) , None )
452+ . with_data_type ( self . data_type . clone ( ) ) ;
453+
454+ // `offsets` in `ListArray`, each row as a list element
455+ let offset_end = i32:: try_from ( input_array. len ( ) ) . map_err ( |e| {
456+ internal_datafusion_err ! (
457+ "cast array_len to i32 failed in convert_to_state of group median, err:{e:?}"
458+ )
459+ } ) ?;
460+ let offsets = ( 0 ..=offset_end) . collect :: < Vec < _ > > ( ) ;
461+ // Safety: all checks in `OffsetBuffer::new` are ensured to pass
462+ let offsets = unsafe { OffsetBuffer :: new_unchecked ( ScalarBuffer :: from ( offsets) ) } ;
463+
464+ // `nulls` for converted `ListArray`
465+ let nulls = filtered_null_mask ( opt_filter, input_array) ;
466+
467+ let converted_list_array = ListArray :: new (
468+ Arc :: new ( Field :: new_list_field ( self . data_type . clone ( ) , true ) ) ,
469+ offsets,
470+ Arc :: new ( values) ,
471+ nulls,
472+ ) ;
473+
474+ Ok ( vec ! [ Arc :: new( converted_list_array) ] )
475+ }
476+
477+ fn supports_convert_to_state ( & self ) -> bool {
478+ true
479+ }
480+
481+ fn size ( & self ) -> usize {
482+ self . group_values
483+ . iter ( )
484+ . map ( |values| values. capacity ( ) * size_of :: < T > ( ) )
485+ . sum :: < usize > ( )
486+ // account for size of self.grou_values too
487+ + self . group_values . capacity ( ) * size_of :: < Vec < T > > ( )
488+ }
489+ }
490+
233491/// The distinct median accumulator accumulates the raw input values
234492/// as `ScalarValue`s
235493///
0 commit comments