1717
1818//! `ARRAY_AGG` aggregate implementation: [`ArrayAgg`]
1919
20- use arrow:: array:: { new_empty_array, Array , ArrayRef , AsArray , ListArray , StructArray } ;
21- use arrow:: compute:: SortOptions ;
20+ use arrow:: array:: {
21+ new_empty_array, Array , ArrayRef , AsArray , BooleanArray , ListArray , StructArray ,
22+ } ;
23+ use arrow:: compute:: { filter, SortOptions } ;
2224use arrow:: datatypes:: { DataType , Field , Fields } ;
2325
2426use datafusion_common:: cast:: as_list_array;
@@ -140,6 +142,8 @@ impl AggregateUDFImpl for ArrayAgg {
140142
141143 fn accumulator ( & self , acc_args : AccumulatorArgs ) -> Result < Box < dyn Accumulator > > {
142144 let data_type = acc_args. exprs [ 0 ] . data_type ( acc_args. schema ) ?;
145+ let ignore_nulls =
146+ acc_args. ignore_nulls && acc_args. exprs [ 0 ] . nullable ( acc_args. schema ) ?;
143147
144148 if acc_args. is_distinct {
145149 // Limitation similar to Postgres. The aggregation function can only mix
@@ -166,14 +170,19 @@ impl AggregateUDFImpl for ArrayAgg {
166170 }
167171 sort_option = Some ( order. options )
168172 }
173+
169174 return Ok ( Box :: new ( DistinctArrayAggAccumulator :: try_new (
170175 & data_type,
171176 sort_option,
177+ ignore_nulls,
172178 ) ?) ) ;
173179 }
174180
175181 if acc_args. ordering_req . is_empty ( ) {
176- return Ok ( Box :: new ( ArrayAggAccumulator :: try_new ( & data_type) ?) ) ;
182+ return Ok ( Box :: new ( ArrayAggAccumulator :: try_new (
183+ & data_type,
184+ ignore_nulls,
185+ ) ?) ) ;
177186 }
178187
179188 let ordering_dtypes = acc_args
@@ -187,6 +196,7 @@ impl AggregateUDFImpl for ArrayAgg {
187196 & ordering_dtypes,
188197 acc_args. ordering_req . clone ( ) ,
189198 acc_args. is_reversed ,
199+ ignore_nulls,
190200 )
191201 . map ( |acc| Box :: new ( acc) as _ )
192202 }
@@ -204,18 +214,20 @@ impl AggregateUDFImpl for ArrayAgg {
204214pub struct ArrayAggAccumulator {
205215 values : Vec < ArrayRef > ,
206216 datatype : DataType ,
217+ ignore_nulls : bool ,
207218}
208219
209220impl ArrayAggAccumulator {
210221 /// new array_agg accumulator based on given item data type
211- pub fn try_new ( datatype : & DataType ) -> Result < Self > {
222+ pub fn try_new ( datatype : & DataType , ignore_nulls : bool ) -> Result < Self > {
212223 Ok ( Self {
213224 values : vec ! [ ] ,
214225 datatype : datatype. clone ( ) ,
226+ ignore_nulls,
215227 } )
216228 }
217229
218- /// This function will return the underlying list array values if all valid values are consecutive without gaps (i.e. no null value point to a non empty list)
230+ /// This function will return the underlying list array values if all valid values are consecutive without gaps (i.e. no null value point to a non- empty list)
219231 /// If there are gaps but only in the end of the list array, the function will return the values without the null values in the end
220232 fn get_optional_values_to_merge_as_is ( list_array : & ListArray ) -> Option < ArrayRef > {
221233 let offsets = list_array. value_offsets ( ) ;
@@ -239,15 +251,15 @@ impl ArrayAggAccumulator {
239251 return Some ( list_array. values ( ) . slice ( 0 , 0 ) ) ;
240252 }
241253
242- // According to the Arrow spec, null values can point to non empty lists
254+ // According to the Arrow spec, null values can point to non- empty lists
243255 // So this will check if all null values starting from the first valid value to the last one point to a 0 length list so we can just slice the underlying value
244256
245257 // Unwrapping is safe as we just checked if there is a null value
246258 let nulls = list_array. nulls ( ) . unwrap ( ) ;
247259
248260 let mut valid_slices_iter = nulls. valid_slices ( ) ;
249261
250- // This is safe as we validated that that are at least 1 valid value in the array
262+ // This is safe as we validated that there is at least 1 valid value in the array
251263 let ( start, end) = valid_slices_iter. next ( ) . unwrap ( ) ;
252264
253265 let start_offset = offsets[ start] ;
@@ -257,7 +269,7 @@ impl ArrayAggAccumulator {
257269 let mut end_offset_of_last_valid_value = offsets[ end] ;
258270
259271 for ( start, end) in valid_slices_iter {
260- // If there is a null value that point to a non empty list than the start offset of the valid value
272+ // If there is a null value that point to a non- empty list than the start offset of the valid value
261273 // will be different that the end offset of the last valid value
262274 if offsets[ start] != end_offset_of_last_valid_value {
263275 return None ;
@@ -288,10 +300,23 @@ impl Accumulator for ArrayAggAccumulator {
288300 return internal_err ! ( "expects single batch" ) ;
289301 }
290302
291- let val = Arc :: clone ( & values[ 0 ] ) ;
303+ let val = & values[ 0 ] ;
304+ let nulls = if self . ignore_nulls {
305+ val. logical_nulls ( )
306+ } else {
307+ None
308+ } ;
309+
310+ let val = match nulls {
311+ Some ( nulls) if nulls. null_count ( ) >= val. len ( ) => return Ok ( ( ) ) ,
312+ Some ( nulls) => filter ( val, & BooleanArray :: new ( nulls. inner ( ) . clone ( ) , None ) ) ?,
313+ None => Arc :: clone ( val) ,
314+ } ;
315+
292316 if !val. is_empty ( ) {
293317 self . values . push ( val) ;
294318 }
319+
295320 Ok ( ( ) )
296321 }
297322
@@ -360,17 +385,20 @@ struct DistinctArrayAggAccumulator {
360385 values : HashSet < ScalarValue > ,
361386 datatype : DataType ,
362387 sort_options : Option < SortOptions > ,
388+ ignore_nulls : bool ,
363389}
364390
365391impl DistinctArrayAggAccumulator {
366392 pub fn try_new (
367393 datatype : & DataType ,
368394 sort_options : Option < SortOptions > ,
395+ ignore_nulls : bool ,
369396 ) -> Result < Self > {
370397 Ok ( Self {
371398 values : HashSet :: new ( ) ,
372399 datatype : datatype. clone ( ) ,
373400 sort_options,
401+ ignore_nulls,
374402 } )
375403 }
376404}
@@ -385,11 +413,20 @@ impl Accumulator for DistinctArrayAggAccumulator {
385413 return Ok ( ( ) ) ;
386414 }
387415
388- let array = & values[ 0 ] ;
416+ let val = & values[ 0 ] ;
417+ let nulls = if self . ignore_nulls {
418+ val. logical_nulls ( )
419+ } else {
420+ None
421+ } ;
389422
390- for i in 0 ..array. len ( ) {
391- let scalar = ScalarValue :: try_from_array ( & array, i) ?;
392- self . values . insert ( scalar) ;
423+ let nulls = nulls. as_ref ( ) ;
424+ if nulls. is_none_or ( |nulls| nulls. null_count ( ) < val. len ( ) ) {
425+ for i in 0 ..val. len ( ) {
426+ if nulls. is_none_or ( |nulls| nulls. is_valid ( i) ) {
427+ self . values . insert ( ScalarValue :: try_from_array ( val, i) ?) ;
428+ }
429+ }
393430 }
394431
395432 Ok ( ( ) )
@@ -471,6 +508,8 @@ pub(crate) struct OrderSensitiveArrayAggAccumulator {
471508 ordering_req : LexOrdering ,
472509 /// Whether the aggregation is running in reverse.
473510 reverse : bool ,
511+ /// Whether the aggregation should ignore null values.
512+ ignore_nulls : bool ,
474513}
475514
476515impl OrderSensitiveArrayAggAccumulator {
@@ -481,6 +520,7 @@ impl OrderSensitiveArrayAggAccumulator {
481520 ordering_dtypes : & [ DataType ] ,
482521 ordering_req : LexOrdering ,
483522 reverse : bool ,
523+ ignore_nulls : bool ,
484524 ) -> Result < Self > {
485525 let mut datatypes = vec ! [ datatype. clone( ) ] ;
486526 datatypes. extend ( ordering_dtypes. iter ( ) . cloned ( ) ) ;
@@ -490,6 +530,7 @@ impl OrderSensitiveArrayAggAccumulator {
490530 datatypes,
491531 ordering_req,
492532 reverse,
533+ ignore_nulls,
493534 } )
494535 }
495536}
@@ -500,11 +541,22 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
500541 return Ok ( ( ) ) ;
501542 }
502543
503- let n_row = values[ 0 ] . len ( ) ;
504- for index in 0 ..n_row {
505- let row = get_row_at_idx ( values, index) ?;
506- self . values . push ( row[ 0 ] . clone ( ) ) ;
507- self . ordering_values . push ( row[ 1 ..] . to_vec ( ) ) ;
544+ let val = & values[ 0 ] ;
545+ let ord = & values[ 1 ..] ;
546+ let nulls = if self . ignore_nulls {
547+ val. logical_nulls ( )
548+ } else {
549+ None
550+ } ;
551+
552+ let nulls = nulls. as_ref ( ) ;
553+ if nulls. is_none_or ( |nulls| nulls. null_count ( ) < val. len ( ) ) {
554+ for i in 0 ..val. len ( ) {
555+ if nulls. is_none_or ( |nulls| nulls. is_valid ( i) ) {
556+ self . values . push ( ScalarValue :: try_from_array ( val, i) ?) ;
557+ self . ordering_values . push ( get_row_at_idx ( ord, i) ?)
558+ }
559+ }
508560 }
509561
510562 Ok ( ( ) )
@@ -665,7 +717,7 @@ impl OrderSensitiveArrayAggAccumulator {
665717#[ cfg( test) ]
666718mod tests {
667719 use super :: * ;
668- use arrow:: datatypes:: { FieldRef , Schema } ;
720+ use arrow:: datatypes:: Schema ;
669721 use datafusion_common:: cast:: as_generic_string_array;
670722 use datafusion_common:: internal_err;
671723 use datafusion_physical_expr:: expressions:: Column ;
@@ -946,14 +998,12 @@ mod tests {
946998 fn new ( data_type : DataType ) -> Self {
947999 Self {
9481000 data_type : data_type. clone ( ) ,
949- distinct : Default :: default ( ) ,
1001+ distinct : false ,
9501002 ordering : Default :: default ( ) ,
9511003 schema : Schema {
9521004 fields : Fields :: from ( vec ! [ Field :: new(
9531005 "col" ,
954- DataType :: List ( FieldRef :: new( Field :: new(
955- "item" , data_type, true ,
956- ) ) ) ,
1006+ DataType :: new_list( data_type, true ) ,
9571007 true ,
9581008 ) ] ) ,
9591009 metadata : Default :: default ( ) ,
0 commit comments