1717
1818//! `ARRAY_AGG` aggregate implementation: [`ArrayAgg`]
1919
20- use arrow:: array:: { new_empty_array, Array , ArrayRef , AsArray , ListArray , StructArray } ;
21- use arrow:: compute:: SortOptions ;
22- use arrow:: datatypes:: DataType ;
20+ use arrow:: array:: {
21+ new_empty_array, Array , ArrayRef , AsArray , BooleanArray , ListArray , StructArray ,
22+ } ;
23+ use arrow:: compute:: { filter, SortOptions } ;
24+ use arrow:: datatypes:: { DataType , Field , Fields } ;
2325
24- use arrow_schema:: { Field , Fields } ;
2526use datafusion_common:: cast:: as_list_array;
2627use datafusion_common:: utils:: { get_row_at_idx, SingleRowListArrayBuilder } ;
2728use datafusion_common:: { exec_err, ScalarValue } ;
@@ -141,6 +142,8 @@ impl AggregateUDFImpl for ArrayAgg {
141142
142143 fn accumulator ( & self , acc_args : AccumulatorArgs ) -> Result < Box < dyn Accumulator > > {
143144 let data_type = acc_args. exprs [ 0 ] . data_type ( acc_args. schema ) ?;
145+ let ignore_nulls =
146+ acc_args. ignore_nulls && acc_args. exprs [ 0 ] . nullable ( acc_args. schema ) ?;
144147
145148 if acc_args. is_distinct {
146149 // Limitation similar to Postgres. The aggregation function can only mix
@@ -167,14 +170,19 @@ impl AggregateUDFImpl for ArrayAgg {
167170 }
168171 sort_option = Some ( order. options )
169172 }
173+
170174 return Ok ( Box :: new ( DistinctArrayAggAccumulator :: try_new (
171175 & data_type,
172176 sort_option,
177+ ignore_nulls,
173178 ) ?) ) ;
174179 }
175180
176181 if acc_args. ordering_req . is_empty ( ) {
177- return Ok ( Box :: new ( ArrayAggAccumulator :: try_new ( & data_type) ?) ) ;
182+ return Ok ( Box :: new ( ArrayAggAccumulator :: try_new (
183+ & data_type,
184+ ignore_nulls,
185+ ) ?) ) ;
178186 }
179187
180188 let ordering_dtypes = acc_args
@@ -188,6 +196,7 @@ impl AggregateUDFImpl for ArrayAgg {
188196 & ordering_dtypes,
189197 acc_args. ordering_req . clone ( ) ,
190198 acc_args. is_reversed ,
199+ ignore_nulls,
191200 )
192201 . map ( |acc| Box :: new ( acc) as _ )
193202 }
@@ -205,18 +214,20 @@ impl AggregateUDFImpl for ArrayAgg {
205214pub struct ArrayAggAccumulator {
206215 values : Vec < ArrayRef > ,
207216 datatype : DataType ,
217+ ignore_nulls : bool ,
208218}
209219
210220impl ArrayAggAccumulator {
211221 /// new array_agg accumulator based on given item data type
212- pub fn try_new ( datatype : & DataType ) -> Result < Self > {
222+ pub fn try_new ( datatype : & DataType , ignore_nulls : bool ) -> Result < Self > {
213223 Ok ( Self {
214224 values : vec ! [ ] ,
215225 datatype : datatype. clone ( ) ,
226+ ignore_nulls,
216227 } )
217228 }
218229
219- /// This function will return the underlying list array values if all valid values are consecutive without gaps (i.e. no null value point to a non empty list)
230 /// This function will return the underlying list array values if all valid values are consecutive without gaps (i.e. no null value point to a non-empty list)
220231 /// If there are gaps but only in the end of the list array, the function will return the values without the null values in the end
221232 fn get_optional_values_to_merge_as_is ( list_array : & ListArray ) -> Option < ArrayRef > {
222233 let offsets = list_array. value_offsets ( ) ;
@@ -240,15 +251,15 @@ impl ArrayAggAccumulator {
240251 return Some ( list_array. values ( ) . slice ( 0 , 0 ) ) ;
241252 }
242253
243- // According to the Arrow spec, null values can point to non empty lists
254 // According to the Arrow spec, null values can point to non-empty lists
244255 // So this will check if all null values starting from the first valid value to the last one point to a 0 length list so we can just slice the underlying value
245256
246257 // Unwrapping is safe as we just checked if there is a null value
247258 let nulls = list_array. nulls ( ) . unwrap ( ) ;
248259
249260 let mut valid_slices_iter = nulls. valid_slices ( ) ;
250261
251- // This is safe as we validated that that are at least 1 valid value in the array
262+ // This is safe as we validated that there is at least 1 valid value in the array
252263 let ( start, end) = valid_slices_iter. next ( ) . unwrap ( ) ;
253264
254265 let start_offset = offsets[ start] ;
@@ -258,7 +269,7 @@ impl ArrayAggAccumulator {
258269 let mut end_offset_of_last_valid_value = offsets[ end] ;
259270
260271 for ( start, end) in valid_slices_iter {
261- // If there is a null value that point to a non empty list than the start offset of the valid value
272 // If there is a null value that point to a non-empty list than the start offset of the valid value
262273 // will be different that the end offset of the last valid value
263274 if offsets[ start] != end_offset_of_last_valid_value {
264275 return None ;
@@ -289,10 +300,23 @@ impl Accumulator for ArrayAggAccumulator {
289300 return internal_err ! ( "expects single batch" ) ;
290301 }
291302
292- let val = Arc :: clone ( & values[ 0 ] ) ;
293- if val. len ( ) > 0 {
303+ let val = & values[ 0 ] ;
304+ let nulls = if self . ignore_nulls {
305+ val. logical_nulls ( )
306+ } else {
307+ None
308+ } ;
309+
310+ let val = match nulls {
311+ Some ( nulls) if nulls. null_count ( ) >= val. len ( ) => return Ok ( ( ) ) ,
312+ Some ( nulls) => filter ( val, & BooleanArray :: new ( nulls. inner ( ) . clone ( ) , None ) ) ?,
313+ None => Arc :: clone ( val) ,
314+ } ;
315+
316+ if !val. is_empty ( ) {
294317 self . values . push ( val) ;
295318 }
319+
296320 Ok ( ( ) )
297321 }
298322
@@ -361,17 +385,20 @@ struct DistinctArrayAggAccumulator {
361385 values : HashSet < ScalarValue > ,
362386 datatype : DataType ,
363387 sort_options : Option < SortOptions > ,
388+ ignore_nulls : bool ,
364389}
365390
366391impl DistinctArrayAggAccumulator {
367392 pub fn try_new (
368393 datatype : & DataType ,
369394 sort_options : Option < SortOptions > ,
395+ ignore_nulls : bool ,
370396 ) -> Result < Self > {
371397 Ok ( Self {
372398 values : HashSet :: new ( ) ,
373399 datatype : datatype. clone ( ) ,
374400 sort_options,
401+ ignore_nulls,
375402 } )
376403 }
377404}
@@ -386,11 +413,20 @@ impl Accumulator for DistinctArrayAggAccumulator {
386413 return Ok ( ( ) ) ;
387414 }
388415
389- let array = & values[ 0 ] ;
416+ let val = & values[ 0 ] ;
417+ let nulls = if self . ignore_nulls {
418+ val. logical_nulls ( )
419+ } else {
420+ None
421+ } ;
390422
391- for i in 0 ..array. len ( ) {
392- let scalar = ScalarValue :: try_from_array ( & array, i) ?;
393- self . values . insert ( scalar) ;
423+ let nulls = nulls. as_ref ( ) ;
424+ if nulls. is_none ( ) || nulls. unwrap ( ) . null_count ( ) < val. len ( ) {
425+ for i in 0 ..val. len ( ) {
426+ if nulls. is_none ( ) || nulls. unwrap ( ) . is_valid ( i) {
427+ self . values . insert ( ScalarValue :: try_from_array ( val, i) ?) ;
428+ }
429+ }
394430 }
395431
396432 Ok ( ( ) )
@@ -472,6 +508,8 @@ pub(crate) struct OrderSensitiveArrayAggAccumulator {
472508 ordering_req : LexOrdering ,
473509 /// Whether the aggregation is running in reverse.
474510 reverse : bool ,
511+ /// Whether the aggregation should ignore null values.
512+ ignore_nulls : bool ,
475513}
476514
477515impl OrderSensitiveArrayAggAccumulator {
@@ -482,6 +520,7 @@ impl OrderSensitiveArrayAggAccumulator {
482520 ordering_dtypes : & [ DataType ] ,
483521 ordering_req : LexOrdering ,
484522 reverse : bool ,
523+ ignore_nulls : bool ,
485524 ) -> Result < Self > {
486525 let mut datatypes = vec ! [ datatype. clone( ) ] ;
487526 datatypes. extend ( ordering_dtypes. iter ( ) . cloned ( ) ) ;
@@ -491,6 +530,7 @@ impl OrderSensitiveArrayAggAccumulator {
491530 datatypes,
492531 ordering_req,
493532 reverse,
533+ ignore_nulls,
494534 } )
495535 }
496536}
@@ -501,11 +541,22 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
501541 return Ok ( ( ) ) ;
502542 }
503543
504- let n_row = values[ 0 ] . len ( ) ;
505- for index in 0 ..n_row {
506- let row = get_row_at_idx ( values, index) ?;
507- self . values . push ( row[ 0 ] . clone ( ) ) ;
508- self . ordering_values . push ( row[ 1 ..] . to_vec ( ) ) ;
544+ let val = & values[ 0 ] ;
545+ let ord = & values[ 1 ..] ;
546+ let nulls = if self . ignore_nulls {
547+ val. logical_nulls ( )
548+ } else {
549+ None
550+ } ;
551+
552+ let nulls = nulls. as_ref ( ) ;
553+ if nulls. is_none ( ) || nulls. unwrap ( ) . null_count ( ) < val. len ( ) {
554+ for i in 0 ..val. len ( ) {
555+ if nulls. is_none ( ) || nulls. unwrap ( ) . is_valid ( i) {
556+ self . values . push ( ScalarValue :: try_from_array ( val, i) ?) ;
557+ self . ordering_values . push ( get_row_at_idx ( ord, i) ?)
558+ }
559+ }
509560 }
510561
511562 Ok ( ( ) )
@@ -666,7 +717,7 @@ impl OrderSensitiveArrayAggAccumulator {
666717#[ cfg( test) ]
667718mod tests {
668719 use super :: * ;
669- use arrow:: datatypes:: { FieldRef , Schema } ;
720+ use arrow:: datatypes:: Schema ;
670721 use datafusion_common:: cast:: as_generic_string_array;
671722 use datafusion_common:: internal_err;
672723 use datafusion_physical_expr:: expressions:: Column ;
@@ -947,14 +998,12 @@ mod tests {
947998 fn new ( data_type : DataType ) -> Self {
948999 Self {
9491000 data_type : data_type. clone ( ) ,
950- distinct : Default :: default ( ) ,
1001+ distinct : false ,
9511002 ordering : Default :: default ( ) ,
9521003 schema : Schema {
9531004 fields : Fields :: from ( vec ! [ Field :: new(
9541005 "col" ,
955- DataType :: List ( FieldRef :: new( Field :: new(
956- "item" , data_type, true ,
957- ) ) ) ,
1006+ DataType :: new_list( data_type, true ) ,
9581007 true ,
9591008 ) ] ) ,
9601009 metadata : Default :: default ( ) ,
0 commit comments