Skip to content

Commit 4eefebe

Browse files
joroKr21 authored and
cipherstakes committed
Respect ignore_nulls in array_agg (apache#15544)
* Respect ignore_nulls in array_agg
* Reduce code duplication
* Add another test
1 parent 9f2ef4d commit 4eefebe

File tree

3 files changed

+106
-27
lines changed

3 files changed

+106
-27
lines changed

datafusion/functions-aggregate/benches/array_agg.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ fn merge_batch_bench(c: &mut Criterion, name: &str, values: ArrayRef) {
4343
b.iter(|| {
4444
#[allow(clippy::unit_arg)]
4545
black_box(
46-
ArrayAggAccumulator::try_new(&list_item_data_type)
46+
ArrayAggAccumulator::try_new(&list_item_data_type, false)
4747
.unwrap()
4848
.merge_batch(&[values.clone()])
4949
.unwrap(),

datafusion/functions-aggregate/src/array_agg.rs

Lines changed: 73 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,10 @@
1717

1818
//! `ARRAY_AGG` aggregate implementation: [`ArrayAgg`]
1919
20-
use arrow::array::{new_empty_array, Array, ArrayRef, AsArray, ListArray, StructArray};
21-
use arrow::compute::SortOptions;
20+
use arrow::array::{
21+
new_empty_array, Array, ArrayRef, AsArray, BooleanArray, ListArray, StructArray,
22+
};
23+
use arrow::compute::{filter, SortOptions};
2224
use arrow::datatypes::{DataType, Field, Fields};
2325

2426
use datafusion_common::cast::as_list_array;
@@ -140,6 +142,8 @@ impl AggregateUDFImpl for ArrayAgg {
140142

141143
fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
142144
let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
145+
let ignore_nulls =
146+
acc_args.ignore_nulls && acc_args.exprs[0].nullable(acc_args.schema)?;
143147

144148
if acc_args.is_distinct {
145149
// Limitation similar to Postgres. The aggregation function can only mix
@@ -166,14 +170,19 @@ impl AggregateUDFImpl for ArrayAgg {
166170
}
167171
sort_option = Some(order.options)
168172
}
173+
169174
return Ok(Box::new(DistinctArrayAggAccumulator::try_new(
170175
&data_type,
171176
sort_option,
177+
ignore_nulls,
172178
)?));
173179
}
174180

175181
if acc_args.ordering_req.is_empty() {
176-
return Ok(Box::new(ArrayAggAccumulator::try_new(&data_type)?));
182+
return Ok(Box::new(ArrayAggAccumulator::try_new(
183+
&data_type,
184+
ignore_nulls,
185+
)?));
177186
}
178187

179188
let ordering_dtypes = acc_args
@@ -187,6 +196,7 @@ impl AggregateUDFImpl for ArrayAgg {
187196
&ordering_dtypes,
188197
acc_args.ordering_req.clone(),
189198
acc_args.is_reversed,
199+
ignore_nulls,
190200
)
191201
.map(|acc| Box::new(acc) as _)
192202
}
@@ -204,18 +214,20 @@ impl AggregateUDFImpl for ArrayAgg {
204214
pub struct ArrayAggAccumulator {
205215
values: Vec<ArrayRef>,
206216
datatype: DataType,
217+
ignore_nulls: bool,
207218
}
208219

209220
impl ArrayAggAccumulator {
210221
/// new array_agg accumulator based on given item data type
211-
pub fn try_new(datatype: &DataType) -> Result<Self> {
222+
pub fn try_new(datatype: &DataType, ignore_nulls: bool) -> Result<Self> {
212223
Ok(Self {
213224
values: vec![],
214225
datatype: datatype.clone(),
226+
ignore_nulls,
215227
})
216228
}
217229

218-
/// This function will return the underlying list array values if all valid values are consecutive without gaps (i.e. no null value point to a non empty list)
230+
/// This function will return the underlying list array values if all valid values are consecutive without gaps (i.e. no null value point to a non-empty list)
219231
/// If there are gaps but only in the end of the list array, the function will return the values without the null values in the end
220232
fn get_optional_values_to_merge_as_is(list_array: &ListArray) -> Option<ArrayRef> {
221233
let offsets = list_array.value_offsets();
@@ -239,15 +251,15 @@ impl ArrayAggAccumulator {
239251
return Some(list_array.values().slice(0, 0));
240252
}
241253

242-
// According to the Arrow spec, null values can point to non empty lists
254+
// According to the Arrow spec, null values can point to non-empty lists
243255
// So this will check if all null values starting from the first valid value to the last one point to a 0 length list so we can just slice the underlying value
244256

245257
// Unwrapping is safe as we just checked if there is a null value
246258
let nulls = list_array.nulls().unwrap();
247259

248260
let mut valid_slices_iter = nulls.valid_slices();
249261

250-
// This is safe as we validated that that are at least 1 valid value in the array
262+
// This is safe as we validated that there is at least 1 valid value in the array
251263
let (start, end) = valid_slices_iter.next().unwrap();
252264

253265
let start_offset = offsets[start];
@@ -257,7 +269,7 @@ impl ArrayAggAccumulator {
257269
let mut end_offset_of_last_valid_value = offsets[end];
258270

259271
for (start, end) in valid_slices_iter {
260-
// If there is a null value that point to a non empty list than the start offset of the valid value
272+
// If there is a null value that points to a non-empty list, then the start offset of the valid value
261273
// will be different from the end offset of the last valid value
262274
if offsets[start] != end_offset_of_last_valid_value {
263275
return None;
@@ -288,10 +300,23 @@ impl Accumulator for ArrayAggAccumulator {
288300
return internal_err!("expects single batch");
289301
}
290302

291-
let val = Arc::clone(&values[0]);
303+
let val = &values[0];
304+
let nulls = if self.ignore_nulls {
305+
val.logical_nulls()
306+
} else {
307+
None
308+
};
309+
310+
let val = match nulls {
311+
Some(nulls) if nulls.null_count() >= val.len() => return Ok(()),
312+
Some(nulls) => filter(val, &BooleanArray::new(nulls.inner().clone(), None))?,
313+
None => Arc::clone(val),
314+
};
315+
292316
if !val.is_empty() {
293317
self.values.push(val);
294318
}
319+
295320
Ok(())
296321
}
297322

@@ -360,17 +385,20 @@ struct DistinctArrayAggAccumulator {
360385
values: HashSet<ScalarValue>,
361386
datatype: DataType,
362387
sort_options: Option<SortOptions>,
388+
ignore_nulls: bool,
363389
}
364390

365391
impl DistinctArrayAggAccumulator {
366392
pub fn try_new(
367393
datatype: &DataType,
368394
sort_options: Option<SortOptions>,
395+
ignore_nulls: bool,
369396
) -> Result<Self> {
370397
Ok(Self {
371398
values: HashSet::new(),
372399
datatype: datatype.clone(),
373400
sort_options,
401+
ignore_nulls,
374402
})
375403
}
376404
}
@@ -385,11 +413,20 @@ impl Accumulator for DistinctArrayAggAccumulator {
385413
return Ok(());
386414
}
387415

388-
let array = &values[0];
416+
let val = &values[0];
417+
let nulls = if self.ignore_nulls {
418+
val.logical_nulls()
419+
} else {
420+
None
421+
};
389422

390-
for i in 0..array.len() {
391-
let scalar = ScalarValue::try_from_array(&array, i)?;
392-
self.values.insert(scalar);
423+
let nulls = nulls.as_ref();
424+
if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) {
425+
for i in 0..val.len() {
426+
if nulls.is_none_or(|nulls| nulls.is_valid(i)) {
427+
self.values.insert(ScalarValue::try_from_array(val, i)?);
428+
}
429+
}
393430
}
394431

395432
Ok(())
@@ -471,6 +508,8 @@ pub(crate) struct OrderSensitiveArrayAggAccumulator {
471508
ordering_req: LexOrdering,
472509
/// Whether the aggregation is running in reverse.
473510
reverse: bool,
511+
/// Whether the aggregation should ignore null values.
512+
ignore_nulls: bool,
474513
}
475514

476515
impl OrderSensitiveArrayAggAccumulator {
@@ -481,6 +520,7 @@ impl OrderSensitiveArrayAggAccumulator {
481520
ordering_dtypes: &[DataType],
482521
ordering_req: LexOrdering,
483522
reverse: bool,
523+
ignore_nulls: bool,
484524
) -> Result<Self> {
485525
let mut datatypes = vec![datatype.clone()];
486526
datatypes.extend(ordering_dtypes.iter().cloned());
@@ -490,6 +530,7 @@ impl OrderSensitiveArrayAggAccumulator {
490530
datatypes,
491531
ordering_req,
492532
reverse,
533+
ignore_nulls,
493534
})
494535
}
495536
}
@@ -500,11 +541,22 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
500541
return Ok(());
501542
}
502543

503-
let n_row = values[0].len();
504-
for index in 0..n_row {
505-
let row = get_row_at_idx(values, index)?;
506-
self.values.push(row[0].clone());
507-
self.ordering_values.push(row[1..].to_vec());
544+
let val = &values[0];
545+
let ord = &values[1..];
546+
let nulls = if self.ignore_nulls {
547+
val.logical_nulls()
548+
} else {
549+
None
550+
};
551+
552+
let nulls = nulls.as_ref();
553+
if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) {
554+
for i in 0..val.len() {
555+
if nulls.is_none_or(|nulls| nulls.is_valid(i)) {
556+
self.values.push(ScalarValue::try_from_array(val, i)?);
557+
self.ordering_values.push(get_row_at_idx(ord, i)?)
558+
}
559+
}
508560
}
509561

510562
Ok(())
@@ -665,7 +717,7 @@ impl OrderSensitiveArrayAggAccumulator {
665717
#[cfg(test)]
666718
mod tests {
667719
use super::*;
668-
use arrow::datatypes::{FieldRef, Schema};
720+
use arrow::datatypes::Schema;
669721
use datafusion_common::cast::as_generic_string_array;
670722
use datafusion_common::internal_err;
671723
use datafusion_physical_expr::expressions::Column;
@@ -946,14 +998,12 @@ mod tests {
946998
fn new(data_type: DataType) -> Self {
947999
Self {
9481000
data_type: data_type.clone(),
949-
distinct: Default::default(),
1001+
distinct: false,
9501002
ordering: Default::default(),
9511003
schema: Schema {
9521004
fields: Fields::from(vec![Field::new(
9531005
"col",
954-
DataType::List(FieldRef::new(Field::new(
955-
"item", data_type, true,
956-
))),
1006+
DataType::new_list(data_type, true),
9571007
true,
9581008
)]),
9591009
metadata: Default::default(),

datafusion/sqllogictest/test_files/aggregate.slt

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -303,17 +303,19 @@ CREATE TABLE array_agg_distinct_list_table AS VALUES
303303
('b', [1,0]),
304304
('b', [1,0]),
305305
('b', [1,0]),
306-
('b', [0,1])
306+
('b', [0,1]),
307+
(NULL, [0,1]),
308+
('b', NULL)
307309
;
308310

309311
# Apply array_sort to have deterministic result, higher dimension nested array also works but not for array sort,
310312
# so they are covered in `datafusion/functions-aggregate/src/array_agg.rs`
311313
query ??
312314
select array_sort(c1), array_sort(c2) from (
313-
select array_agg(distinct column1) as c1, array_agg(distinct column2) as c2 from array_agg_distinct_list_table
315+
select array_agg(distinct column1) as c1, array_agg(distinct column2) ignore nulls as c2 from array_agg_distinct_list_table
314316
);
315317
----
316-
[b, w] [[0, 1], [1, 0]]
318+
[NULL, b, w] [[0, 1], [1, 0]]
317319

318320
statement ok
319321
drop table array_agg_distinct_list_table;
@@ -3226,6 +3228,33 @@ select array_agg(column1) from t;
32263228
statement ok
32273229
drop table t;
32283230

3231+
# array_agg_ignore_nulls
3232+
statement ok
3233+
create table t as values (NULL, ''), (1, 'c'), (2, 'a'), (NULL, 'b'), (4, NULL), (NULL, NULL), (5, 'a');
3234+
3235+
query ?
3236+
select array_agg(column1) ignore nulls as c1 from t;
3237+
----
3238+
[1, 2, 4, 5]
3239+
3240+
query II
3241+
select count(*), array_length(array_agg(distinct column2) ignore nulls) from t;
3242+
----
3243+
7 4
3244+
3245+
query ?
3246+
select array_agg(column2 order by column1) ignore nulls from t;
3247+
----
3248+
[c, a, a, , b]
3249+
3250+
query ?
3251+
select array_agg(DISTINCT column2 order by column2) ignore nulls from t;
3252+
----
3253+
[, a, b, c]
3254+
3255+
statement ok
3256+
drop table t;
3257+
32293258
# variance_single_value
32303259
query RRRR
32313260
select var(sq.column1), var_pop(sq.column1), stddev(sq.column1), stddev_pop(sq.column1) from (values (1.0)) as sq;

0 commit comments

Comments
 (0)