
Commit 1ed2a05

joroKr21 authored and findepi committed
Respect ignore_nulls in array_agg (apache#15544)
* Respect ignore_nulls in array_agg
* Reduce code duplication
* Add another test

(cherry picked from commit 5bb0a98)
1 parent a96af27 commit 1ed2a05

File tree

3 files changed: +108 -30 lines

datafusion/functions-aggregate/benches/array_agg.rs

Lines changed: 1 addition & 1 deletion
@@ -36,7 +36,7 @@ fn merge_batch_bench(c: &mut Criterion, name: &str, values: ArrayRef) {
         b.iter(|| {
             #[allow(clippy::unit_arg)]
             black_box(
-                ArrayAggAccumulator::try_new(&list_item_data_type)
+                ArrayAggAccumulator::try_new(&list_item_data_type, false)
                     .unwrap()
                     .merge_batch(&[values.clone()])
                     .unwrap(),
datafusion/functions-aggregate/src/array_agg.rs

Lines changed: 75 additions & 26 deletions
@@ -17,11 +17,12 @@
 
 //! `ARRAY_AGG` aggregate implementation: [`ArrayAgg`]
 
-use arrow::array::{new_empty_array, Array, ArrayRef, AsArray, ListArray, StructArray};
-use arrow::compute::SortOptions;
-use arrow::datatypes::DataType;
+use arrow::array::{
+    new_empty_array, Array, ArrayRef, AsArray, BooleanArray, ListArray, StructArray,
+};
+use arrow::compute::{filter, SortOptions};
+use arrow::datatypes::{DataType, Field, Fields};
 
-use arrow_schema::{Field, Fields};
 use datafusion_common::cast::as_list_array;
 use datafusion_common::utils::{get_row_at_idx, SingleRowListArrayBuilder};
 use datafusion_common::{exec_err, ScalarValue};
@@ -141,6 +142,8 @@ impl AggregateUDFImpl for ArrayAgg {
 
     fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
         let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
+        let ignore_nulls =
+            acc_args.ignore_nulls && acc_args.exprs[0].nullable(acc_args.schema)?;
 
         if acc_args.is_distinct {
             // Limitation similar to Postgres. The aggregation function can only mix
@@ -167,14 +170,19 @@ impl AggregateUDFImpl for ArrayAgg {
                 }
                 sort_option = Some(order.options)
             }
+
             return Ok(Box::new(DistinctArrayAggAccumulator::try_new(
                 &data_type,
                 sort_option,
+                ignore_nulls,
             )?));
         }
 
         if acc_args.ordering_req.is_empty() {
-            return Ok(Box::new(ArrayAggAccumulator::try_new(&data_type)?));
+            return Ok(Box::new(ArrayAggAccumulator::try_new(
+                &data_type,
+                ignore_nulls,
+            )?));
         }
 
         let ordering_dtypes = acc_args
@@ -188,6 +196,7 @@ impl AggregateUDFImpl for ArrayAgg {
             &ordering_dtypes,
             acc_args.ordering_req.clone(),
             acc_args.is_reversed,
+            ignore_nulls,
         )
         .map(|acc| Box::new(acc) as _)
     }
@@ -205,18 +214,20 @@ impl AggregateUDFImpl for ArrayAgg {
 pub struct ArrayAggAccumulator {
     values: Vec<ArrayRef>,
     datatype: DataType,
+    ignore_nulls: bool,
 }
 
 impl ArrayAggAccumulator {
     /// new array_agg accumulator based on given item data type
-    pub fn try_new(datatype: &DataType) -> Result<Self> {
+    pub fn try_new(datatype: &DataType, ignore_nulls: bool) -> Result<Self> {
         Ok(Self {
             values: vec![],
             datatype: datatype.clone(),
+            ignore_nulls,
         })
     }
 
-    /// This function will return the underlying list array values if all valid values are consecutive without gaps (i.e. no null value point to a non empty list)
+    /// This function will return the underlying list array values if all valid values are consecutive without gaps (i.e. no null value point to a non-empty list)
     /// If there are gaps but only in the end of the list array, the function will return the values without the null values in the end
     fn get_optional_values_to_merge_as_is(list_array: &ListArray) -> Option<ArrayRef> {
         let offsets = list_array.value_offsets();
@@ -240,15 +251,15 @@ impl ArrayAggAccumulator {
             return Some(list_array.values().slice(0, 0));
         }
 
-        // According to the Arrow spec, null values can point to non empty lists
+        // According to the Arrow spec, null values can point to non-empty lists
        // So this will check if all null values starting from the first valid value to the last one point to a 0 length list so we can just slice the underlying value
 
         // Unwrapping is safe as we just checked if there is a null value
         let nulls = list_array.nulls().unwrap();
 
         let mut valid_slices_iter = nulls.valid_slices();
 
-        // This is safe as we validated that that are at least 1 valid value in the array
+        // This is safe as we validated that there is at least 1 valid value in the array
         let (start, end) = valid_slices_iter.next().unwrap();
 
         let start_offset = offsets[start];
@@ -258,7 +269,7 @@
         let mut end_offset_of_last_valid_value = offsets[end];
 
         for (start, end) in valid_slices_iter {
-            // If there is a null value that point to a non empty list than the start offset of the valid value
+            // If there is a null value that point to a non-empty list than the start offset of the valid value
             // will be different that the end offset of the last valid value
             if offsets[start] != end_offset_of_last_valid_value {
                 return None;
@@ -289,10 +300,23 @@ impl Accumulator for ArrayAggAccumulator {
             return internal_err!("expects single batch");
         }
 
-        let val = Arc::clone(&values[0]);
-        if val.len() > 0 {
+        let val = &values[0];
+        let nulls = if self.ignore_nulls {
+            val.logical_nulls()
+        } else {
+            None
+        };
+
+        let val = match nulls {
+            Some(nulls) if nulls.null_count() >= val.len() => return Ok(()),
+            Some(nulls) => filter(val, &BooleanArray::new(nulls.inner().clone(), None))?,
+            None => Arc::clone(val),
+        };
+
+        if !val.is_empty() {
             self.values.push(val);
         }
+
         Ok(())
     }
 
@@ -361,17 +385,20 @@ struct DistinctArrayAggAccumulator {
     values: HashSet<ScalarValue>,
     datatype: DataType,
     sort_options: Option<SortOptions>,
+    ignore_nulls: bool,
 }
 
 impl DistinctArrayAggAccumulator {
     pub fn try_new(
         datatype: &DataType,
         sort_options: Option<SortOptions>,
+        ignore_nulls: bool,
     ) -> Result<Self> {
         Ok(Self {
            values: HashSet::new(),
             datatype: datatype.clone(),
             sort_options,
+            ignore_nulls,
         })
     }
 }
@@ -386,11 +413,20 @@ impl Accumulator for DistinctArrayAggAccumulator {
             return Ok(());
         }
 
-        let array = &values[0];
+        let val = &values[0];
+        let nulls = if self.ignore_nulls {
+            val.logical_nulls()
+        } else {
+            None
+        };
 
-        for i in 0..array.len() {
-            let scalar = ScalarValue::try_from_array(&array, i)?;
-            self.values.insert(scalar);
+        let nulls = nulls.as_ref();
+        if nulls.is_none() || nulls.unwrap().null_count() < val.len() {
+            for i in 0..val.len() {
+                if nulls.is_none() || nulls.unwrap().is_valid(i) {
+                    self.values.insert(ScalarValue::try_from_array(val, i)?);
+                }
+            }
         }
 
         Ok(())
@@ -472,6 +508,8 @@ pub(crate) struct OrderSensitiveArrayAggAccumulator {
     ordering_req: LexOrdering,
     /// Whether the aggregation is running in reverse.
     reverse: bool,
+    /// Whether the aggregation should ignore null values.
+    ignore_nulls: bool,
 }
 
 impl OrderSensitiveArrayAggAccumulator {
@@ -482,6 +520,7 @@ impl OrderSensitiveArrayAggAccumulator {
         ordering_dtypes: &[DataType],
         ordering_req: LexOrdering,
         reverse: bool,
+        ignore_nulls: bool,
     ) -> Result<Self> {
         let mut datatypes = vec![datatype.clone()];
         datatypes.extend(ordering_dtypes.iter().cloned());
@@ -491,6 +530,7 @@ impl OrderSensitiveArrayAggAccumulator {
             datatypes,
             ordering_req,
             reverse,
+            ignore_nulls,
         })
     }
 }
@@ -501,11 +541,22 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
             return Ok(());
         }
 
-        let n_row = values[0].len();
-        for index in 0..n_row {
-            let row = get_row_at_idx(values, index)?;
-            self.values.push(row[0].clone());
-            self.ordering_values.push(row[1..].to_vec());
+        let val = &values[0];
+        let ord = &values[1..];
+        let nulls = if self.ignore_nulls {
+            val.logical_nulls()
+        } else {
+            None
+        };
+
+        let nulls = nulls.as_ref();
+        if nulls.is_none() || nulls.unwrap().null_count() < val.len() {
+            for i in 0..val.len() {
+                if nulls.is_none() || nulls.unwrap().is_valid(i) {
+                    self.values.push(ScalarValue::try_from_array(val, i)?);
+                    self.ordering_values.push(get_row_at_idx(ord, i)?)
+                }
+            }
         }
 
         Ok(())
@@ -666,7 +717,7 @@ impl OrderSensitiveArrayAggAccumulator {
 #[cfg(test)]
 mod tests {
     use super::*;
-    use arrow::datatypes::{FieldRef, Schema};
+    use arrow::datatypes::Schema;
     use datafusion_common::cast::as_generic_string_array;
     use datafusion_common::internal_err;
     use datafusion_physical_expr::expressions::Column;
@@ -947,14 +998,12 @@ mod tests {
         fn new(data_type: DataType) -> Self {
            Self {
                 data_type: data_type.clone(),
-                distinct: Default::default(),
+                distinct: false,
                 ordering: Default::default(),
                 schema: Schema {
                     fields: Fields::from(vec![Field::new(
                         "col",
-                        DataType::List(FieldRef::new(Field::new(
-                            "item", data_type, true,
-                        ))),
+                        DataType::new_list(data_type, true),
                         true,
                     )]),
                     metadata: Default::default(),

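All three accumulators now follow the same pattern in `update_batch`: when `ignore_nulls` is set, consult the batch's logical null mask, return early if every row is null, and otherwise either filter the array (plain `ARRAY_AGG`) or skip null slots while iterating (the distinct and order-sensitive variants). Below is a minimal standalone sketch of the filtering branch using only the arrow calls visible in the diff above; the helper name `drop_nulls` is invented for illustration.

use std::sync::Arc;

use arrow::array::{Array, ArrayRef, BooleanArray, Int32Array};
use arrow::compute::filter;
use arrow::error::ArrowError;

/// Returns None when every row is null, otherwise the array with null rows removed.
fn drop_nulls(val: &ArrayRef) -> Result<Option<ArrayRef>, ArrowError> {
    match val.logical_nulls() {
        // Every row is null: nothing worth accumulating.
        Some(nulls) if nulls.null_count() >= val.len() => Ok(None),
        // The validity buffer (1 = valid) doubles as the filter predicate.
        Some(nulls) => {
            let keep = BooleanArray::new(nulls.inner().clone(), None);
            Ok(Some(filter(val, &keep)?))
        }
        // No null buffer at all: keep the array as-is.
        None => Ok(Some(Arc::clone(val))),
    }
}

fn main() -> Result<(), ArrowError> {
    let input: ArrayRef = Arc::new(Int32Array::from(vec![Some(1), None, Some(3)]));
    let filtered = drop_nulls(&input)?.expect("not all null");
    assert_eq!(filtered.len(), 2); // the null row was dropped
    Ok(())
}

The distinct and order-sensitive accumulators take the per-row branch of the same idea instead: they keep the `NullBuffer` and check `is_valid(i)` while iterating, so no filtered copy of the array is materialized.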
datafusion/sqllogictest/test_files/aggregate.slt

Lines changed: 32 additions & 3 deletions
@@ -289,17 +289,19 @@ CREATE TABLE array_agg_distinct_list_table AS VALUES
   ('b', [1,0]),
   ('b', [1,0]),
   ('b', [1,0]),
-  ('b', [0,1])
+  ('b', [0,1]),
+  (NULL, [0,1]),
+  ('b', NULL)
 ;
 
 # Apply array_sort to have deterministic result, higher dimension nested array also works but not for array sort,
 # so they are covered in `datafusion/functions-aggregate/src/array_agg.rs`
 query ??
 select array_sort(c1), array_sort(c2) from (
-  select array_agg(distinct column1) as c1, array_agg(distinct column2) as c2 from array_agg_distinct_list_table
+  select array_agg(distinct column1) as c1, array_agg(distinct column2) ignore nulls as c2 from array_agg_distinct_list_table
 );
 ----
-[b, w] [[0, 1], [1, 0]]
+[NULL, b, w] [[0, 1], [1, 0]]
 
 statement ok
 drop table array_agg_distinct_list_table;
@@ -3194,6 +3196,33 @@ select array_agg(column1) from t;
 statement ok
 drop table t;
 
+# array_agg_ignore_nulls
+statement ok
+create table t as values (NULL, ''), (1, 'c'), (2, 'a'), (NULL, 'b'), (4, NULL), (NULL, NULL), (5, 'a');
+
+query ?
+select array_agg(column1) ignore nulls as c1 from t;
+----
+[1, 2, 4, 5]
+
+query II
+select count(*), array_length(array_agg(distinct column2) ignore nulls) from t;
+----
+7 4
+
+query ?
+select array_agg(column2 order by column1) ignore nulls from t;
+----
+[c, a, a, , b]
+
+query ?
+select array_agg(DISTINCT column2 order by column2) ignore nulls from t;
+----
+[, a, b, c]
+
+statement ok
+drop table t;
+
 # variance_single_value
 query RRRR
 select var(sq.column1), var_pop(sq.column1), stddev(sq.column1), stddev_pop(sq.column1) from (values (1.0)) as sq;
