2 changes: 1 addition & 1 deletion datafusion/functions-aggregate/benches/array_agg.rs
@@ -43,7 +43,7 @@ fn merge_batch_bench(c: &mut Criterion, name: &str, values: ArrayRef) {
b.iter(|| {
#[allow(clippy::unit_arg)]
black_box(
ArrayAggAccumulator::try_new(&list_item_data_type)
ArrayAggAccumulator::try_new(&list_item_data_type, false)
.unwrap()
.merge_batch(&[values.clone()])
.unwrap(),
96 changes: 73 additions & 23 deletions datafusion/functions-aggregate/src/array_agg.rs
@@ -17,8 +17,10 @@

//! `ARRAY_AGG` aggregate implementation: [`ArrayAgg`]

use arrow::array::{new_empty_array, Array, ArrayRef, AsArray, ListArray, StructArray};
use arrow::compute::SortOptions;
use arrow::array::{
new_empty_array, Array, ArrayRef, AsArray, BooleanArray, ListArray, StructArray,
};
use arrow::compute::{filter, SortOptions};
use arrow::datatypes::{DataType, Field, Fields};

use datafusion_common::cast::as_list_array;
@@ -140,6 +142,8 @@ impl AggregateUDFImpl for ArrayAgg {

fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
let data_type = acc_args.exprs[0].data_type(acc_args.schema)?;
let ignore_nulls =
acc_args.ignore_nulls && acc_args.exprs[0].nullable(acc_args.schema)?;

if acc_args.is_distinct {
// Limitation similar to Postgres. The aggregation function can only mix
@@ -166,14 +170,19 @@
}
sort_option = Some(order.options)
}

return Ok(Box::new(DistinctArrayAggAccumulator::try_new(
&data_type,
sort_option,
ignore_nulls,
)?));
}

if acc_args.ordering_req.is_empty() {
return Ok(Box::new(ArrayAggAccumulator::try_new(&data_type)?));
return Ok(Box::new(ArrayAggAccumulator::try_new(
&data_type,
ignore_nulls,
)?));
}

let ordering_dtypes = acc_args
@@ -187,6 +196,7 @@
&ordering_dtypes,
acc_args.ordering_req.clone(),
acc_args.is_reversed,
ignore_nulls,
)
.map(|acc| Box::new(acc) as _)
}
@@ -204,18 +214,20 @@
pub struct ArrayAggAccumulator {
values: Vec<ArrayRef>,
datatype: DataType,
ignore_nulls: bool,
}

impl ArrayAggAccumulator {
/// new array_agg accumulator based on given item data type
pub fn try_new(datatype: &DataType) -> Result<Self> {
pub fn try_new(datatype: &DataType, ignore_nulls: bool) -> Result<Self> {
Ok(Self {
values: vec![],
datatype: datatype.clone(),
ignore_nulls,
})
}

/// This function will return the underlying list array values if all valid values are consecutive without gaps (i.e. no null value point to a non empty list)
/// This function will return the underlying list array values if all valid values are consecutive without gaps (i.e. no null value points to a non-empty list).
/// If there are gaps, but only at the end of the list array, the function will return the values without the trailing nulls.
fn get_optional_values_to_merge_as_is(list_array: &ListArray) -> Option<ArrayRef> {
let offsets = list_array.value_offsets();
@@ -239,15 +251,15 @@
return Some(list_array.values().slice(0, 0));
}

// According to the Arrow spec, null values can point to non empty lists
// According to the Arrow spec, null values can point to non-empty lists
// So this will check that all null values, from the first valid value to the last one, point to a 0-length list, so we can just slice the underlying values

// Unwrapping is safe as we just checked if there is a null value
let nulls = list_array.nulls().unwrap();

let mut valid_slices_iter = nulls.valid_slices();

// This is safe as we validated that that are at least 1 valid value in the array
// This is safe as we validated that there is at least 1 valid value in the array
let (start, end) = valid_slices_iter.next().unwrap();

let start_offset = offsets[start];
@@ -257,7 +269,7 @@
let mut end_offset_of_last_valid_value = offsets[end];

for (start, end) in valid_slices_iter {
// If there is a null value that point to a non empty list than the start offset of the valid value
// If there is a null value that points to a non-empty list, then the start offset of the valid value
// will be different from the end offset of the last valid value
if offsets[start] != end_offset_of_last_valid_value {
return None;
@@ -288,10 +300,23 @@ impl Accumulator for ArrayAggAccumulator {
return internal_err!("expects single batch");
}

let val = Arc::clone(&values[0]);
let val = &values[0];
let nulls = if self.ignore_nulls {
val.logical_nulls()
} else {
None
};

let val = match nulls {
Some(nulls) if nulls.null_count() >= val.len() => return Ok(()),
> Contributor: how could the null count be greater than the length? or is the >= a defensive coding mechanism?
> Contributor Author: It shouldn't be - I guess we can use == but does it make a difference?
Some(nulls) => filter(val, &BooleanArray::new(nulls.inner().clone(), None))?,
None => Arc::clone(val),
};

if !val.is_empty() {
self.values.push(val);
}

Ok(())
}
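A note on the filtering approach above: an Arrow validity bitmap is itself a boolean mask, so wrapping it in a `BooleanArray` and handing it to the `filter` kernel keeps exactly the non-null slots. Also, a `NullBuffer`'s `null_count` can never exceed the array's length, so the `>=` discussed in the review thread is purely defensive; `==` behaves identically. A minimal standalone sketch of the same technique — the helper name `drop_nulls` is ours, not the PR's:

```rust
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, BooleanArray, Int64Array};
use arrow::compute::filter;

/// Hypothetical helper (not in the PR): drop the null slots of an array.
fn drop_nulls(values: &ArrayRef) -> arrow::error::Result<ArrayRef> {
    match values.logical_nulls() {
        // Every slot is null: nothing survives, return an empty slice.
        Some(nulls) if nulls.null_count() == values.len() => Ok(values.slice(0, 0)),
        // Some slots are null: the validity bitmap (true = valid) doubles
        // as the filter predicate, keeping exactly the non-null slots.
        Some(nulls) => filter(values, &BooleanArray::new(nulls.inner().clone(), None)),
        // No null buffer: the array is already null-free.
        None => Ok(Arc::clone(values)),
    }
}

fn main() -> arrow::error::Result<()> {
    let values: ArrayRef = Arc::new(Int64Array::from(vec![Some(1), None, Some(3)]));
    assert_eq!(drop_nulls(&values)?.len(), 2);
    Ok(())
}
```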

@@ -360,17 +385,20 @@ struct DistinctArrayAggAccumulator {
values: HashSet<ScalarValue>,
datatype: DataType,
sort_options: Option<SortOptions>,
ignore_nulls: bool,
}

impl DistinctArrayAggAccumulator {
pub fn try_new(
datatype: &DataType,
sort_options: Option<SortOptions>,
ignore_nulls: bool,
) -> Result<Self> {
Ok(Self {
values: HashSet::new(),
datatype: datatype.clone(),
sort_options,
ignore_nulls,
})
}
}
@@ -385,11 +413,20 @@ impl Accumulator for DistinctArrayAggAccumulator {
return Ok(());
}

let array = &values[0];
let val = &values[0];
let nulls = if self.ignore_nulls {
val.logical_nulls()
} else {
None
};

for i in 0..array.len() {
let scalar = ScalarValue::try_from_array(&array, i)?;
self.values.insert(scalar);
let nulls = nulls.as_ref();
if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) {
for i in 0..val.len() {
if nulls.is_none_or(|nulls| nulls.is_valid(i)) {
self.values.insert(ScalarValue::try_from_array(val, i)?);
}
}
}

Ok(())
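The distinct path cannot reuse the `filter` kernel's output directly — it needs one `ScalarValue` per row for the hash set — so it scans the array and skips invalid slots instead. A compilable sketch of that skip pattern (the helper name `valid_indices` is hypothetical; `Option::is_none_or` is stable since Rust 1.82):

```rust
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, StringArray};

/// Hypothetical helper: indices of the non-null slots of an array.
fn valid_indices(values: &ArrayRef) -> Vec<usize> {
    let nulls = values.logical_nulls();
    let nulls = nulls.as_ref();
    // An all-null batch contributes nothing, so skip the scan entirely.
    if nulls.is_some_and(|n| n.null_count() == values.len()) {
        return vec![];
    }
    (0..values.len())
        // A missing null buffer means every slot is valid.
        .filter(|&i| nulls.is_none_or(|n| n.is_valid(i)))
        .collect()
}

fn main() {
    let values: ArrayRef = Arc::new(StringArray::from(vec![Some("a"), None, Some("b")]));
    assert_eq!(valid_indices(&values), vec![0, 2]);
}
```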
@@ -471,6 +508,8 @@ pub(crate) struct OrderSensitiveArrayAggAccumulator {
ordering_req: LexOrdering,
/// Whether the aggregation is running in reverse.
reverse: bool,
/// Whether the aggregation should ignore null values.
ignore_nulls: bool,
}

impl OrderSensitiveArrayAggAccumulator {
@@ -481,6 +520,7 @@
ordering_dtypes: &[DataType],
ordering_req: LexOrdering,
reverse: bool,
ignore_nulls: bool,
) -> Result<Self> {
let mut datatypes = vec![datatype.clone()];
datatypes.extend(ordering_dtypes.iter().cloned());
@@ -490,6 +530,7 @@
datatypes,
ordering_req,
reverse,
ignore_nulls,
})
}
}
@@ -500,11 +541,22 @@ impl Accumulator for OrderSensitiveArrayAggAccumulator {
return Ok(());
}

let n_row = values[0].len();
for index in 0..n_row {
let row = get_row_at_idx(values, index)?;
self.values.push(row[0].clone());
self.ordering_values.push(row[1..].to_vec());
let val = &values[0];
let ord = &values[1..];
let nulls = if self.ignore_nulls {
val.logical_nulls()
} else {
None
};

let nulls = nulls.as_ref();
if nulls.is_none_or(|nulls| nulls.null_count() < val.len()) {
for i in 0..val.len() {
if nulls.is_none_or(|nulls| nulls.is_valid(i)) {
self.values.push(ScalarValue::try_from_array(val, i)?);
self.ordering_values.push(get_row_at_idx(ord, i)?)
}
}
}

Ok(())
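Unlike the flat case, the order-sensitive accumulator must keep each retained value paired with its ORDER BY key; filtering the value column alone would desynchronize the two, which is presumably why it walks rows instead of calling `filter`. A sketch of that pairing under the same assumptions (the helper `keep_valid_rows` is hypothetical, and `get_row_at_idx` is approximated by an inline map over the ordering arrays):

```rust
use std::sync::Arc;

use arrow::array::{Array, ArrayRef, Int64Array, StringArray};
use datafusion_common::{Result, ScalarValue};

/// Hypothetical helper: `(value, ordering_key)` pairs for the non-null
/// rows of `val`, where `ord` holds one array per ORDER BY expression.
fn keep_valid_rows(
    val: &ArrayRef,
    ord: &[ArrayRef],
) -> Result<Vec<(ScalarValue, Vec<ScalarValue>)>> {
    let nulls = val.logical_nulls();
    let nulls = nulls.as_ref();
    let mut out = Vec::new();
    for i in 0..val.len() {
        // Keep the ordering key only for rows whose value is non-null,
        // so values and keys stay index-aligned.
        if nulls.is_none_or(|n| n.is_valid(i)) {
            let key = ord
                .iter()
                .map(|a| ScalarValue::try_from_array(a, i))
                .collect::<Result<Vec<_>>>()?;
            out.push((ScalarValue::try_from_array(val, i)?, key));
        }
    }
    Ok(out)
}

fn main() -> Result<()> {
    let val: ArrayRef = Arc::new(StringArray::from(vec![Some("x"), None]));
    let ord: Vec<ArrayRef> = vec![Arc::new(Int64Array::from(vec![2, 1]))];
    assert_eq!(keep_valid_rows(&val, &ord)?.len(), 1);
    Ok(())
}
```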
@@ -665,7 +717,7 @@ impl OrderSensitiveArrayAggAccumulator {
#[cfg(test)]
mod tests {
use super::*;
use arrow::datatypes::{FieldRef, Schema};
use arrow::datatypes::Schema;
use datafusion_common::cast::as_generic_string_array;
use datafusion_common::internal_err;
use datafusion_physical_expr::expressions::Column;
@@ -946,14 +998,12 @@ mod tests {
fn new(data_type: DataType) -> Self {
Self {
data_type: data_type.clone(),
distinct: Default::default(),
distinct: false,
ordering: Default::default(),
schema: Schema {
fields: Fields::from(vec![Field::new(
"col",
DataType::List(FieldRef::new(Field::new(
"item", data_type, true,
))),
DataType::new_list(data_type, true),
true,
)]),
metadata: Default::default(),
35 changes: 32 additions & 3 deletions datafusion/sqllogictest/test_files/aggregate.slt
@@ -289,17 +289,19 @@ CREATE TABLE array_agg_distinct_list_table AS VALUES
('b', [1,0]),
('b', [1,0]),
('b', [1,0]),
('b', [0,1])
('b', [0,1]),
(NULL, [0,1]),
('b', NULL)
;

# Apply array_sort to get a deterministic result. Higher-dimension nested arrays also work, but array_sort does not support them,
# so they are covered in `datafusion/functions-aggregate/src/array_agg.rs`
query ??
select array_sort(c1), array_sort(c2) from (
select array_agg(distinct column1) as c1, array_agg(distinct column2) as c2 from array_agg_distinct_list_table
select array_agg(distinct column1) as c1, array_agg(distinct column2) ignore nulls as c2 from array_agg_distinct_list_table
);
----
[b, w] [[0, 1], [1, 0]]
[NULL, b, w] [[0, 1], [1, 0]]

statement ok
drop table array_agg_distinct_list_table;
@@ -3194,6 +3196,33 @@ select array_agg(column1) from t;
statement ok
drop table t;

# array_agg_ignore_nulls
statement ok
create table t as values (NULL, ''), (1, 'c'), (2, 'a'), (NULL, 'b'), (4, NULL), (NULL, NULL), (5, 'a');

query ?
select array_agg(column1) ignore nulls as c1 from t;
----
[1, 2, 4, 5]

query II
select count(*), array_length(array_agg(distinct column2) ignore nulls) from t;
----
7 4

query ?
select array_agg(column2 order by column1) ignore nulls from t;
----
[c, a, a, , b]

query ?
select array_agg(DISTINCT column2 order by column2) ignore nulls from t;
----
[, a, b, c]

statement ok
drop table t;

# variance_single_value
query RRRR
select var(sq.column1), var_pop(sq.column1), stddev(sq.column1), stddev_pop(sq.column1) from (values (1.0)) as sq;