Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for order sensitive NTH_VALUE aggregation, make reverse ARRAY_AGG more efficient #8841

Merged
merged 32 commits into from
Jan 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
a99de23
Initial commit
mustafasrepo Jan 10, 2024
da9adee
minor changes
mustafasrepo Jan 10, 2024
4c033ca
Parse index argument
mustafasrepo Jan 10, 2024
97019c0
Move nth_value to array_agg
mustafasrepo Jan 10, 2024
b5672fa
Initial implementation (with redundant data)
mustafasrepo Jan 10, 2024
73ec593
Add new test
mustafasrepo Jan 10, 2024
b04fdef
Add reverse support
mustafasrepo Jan 10, 2024
151f3d3
Add new slt tests
mustafasrepo Jan 10, 2024
80f5c41
Add multi partition support
mustafasrepo Jan 10, 2024
ec1847f
Minor changes
mustafasrepo Jan 10, 2024
2b9ab95
Minor changes
mustafasrepo Jan 10, 2024
863d4bb
Add new aggregator to the proto
mustafasrepo Jan 10, 2024
e22f456
Remove redundant tests
mustafasrepo Jan 10, 2024
22ff28e
Keep n entries in the state for nth value
mustafasrepo Jan 11, 2024
d542b36
Change implementation
mustafasrepo Jan 11, 2024
5cedec5
Move nth value to its own file
mustafasrepo Jan 11, 2024
24a0225
Minor changes
mustafasrepo Jan 11, 2024
3ee2b55
minor changes
mustafasrepo Jan 11, 2024
a136973
Review
ozankabak Jan 12, 2024
6a23dd2
Update comments
mustafasrepo Jan 12, 2024
fa70997
Use drain method to remove from the beginning.
mustafasrepo Jan 12, 2024
2275582
Add reverse support, convert buffer to vecdeque
mustafasrepo Jan 15, 2024
7138d7f
Minor changes
mustafasrepo Jan 15, 2024
067a295
Merge branch 'apache_main' into feature/nth_value_agg
mustafasrepo Jan 15, 2024
ebf5c2c
Minor changes
mustafasrepo Jan 15, 2024
4cab3db
Review Part 2
ozankabak Jan 15, 2024
3ce91e2
Review Part 3
ozankabak Jan 15, 2024
d1250bf
Add new_list from iter
mustafasrepo Jan 16, 2024
7684107
Convert API to receive vecdeque
mustafasrepo Jan 16, 2024
137c21e
Receive mutable argument
mustafasrepo Jan 16, 2024
864ee9f
Refactor merge implementation
mustafasrepo Jan 16, 2024
1fad8f9
Fix doctest
mustafasrepo Jan 16, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 65 additions & 18 deletions datafusion/common/src/scalar.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@

use std::borrow::Borrow;
use std::cmp::Ordering;
use std::collections::HashSet;
use std::convert::{Infallible, TryInto};
use std::collections::{HashSet, VecDeque};
use std::convert::{Infallible, TryFrom, TryInto};
use std::fmt;
use std::hash::Hash;
use std::iter::repeat;
use std::str::FromStr;
use std::{convert::TryFrom, fmt, iter::repeat, sync::Arc};
use std::sync::Arc;

use crate::arrow_datafusion_err;
use crate::cast::{
Expand All @@ -33,23 +35,22 @@ use crate::cast::{
use crate::error::{DataFusionError, Result, _internal_err, _not_impl_err};
use crate::hash_utils::create_hashes;
use crate::utils::{array_into_large_list_array, array_into_list_array};

use arrow::compute::kernels::numeric::*;
use arrow::datatypes::{i256, Fields, SchemaBuilder};
use arrow::util::display::{ArrayFormatter, FormatOptions};
use arrow::{
array::*,
compute::kernels::cast::{cast_with_options, CastOptions},
datatypes::{
ArrowDictionaryKeyType, ArrowNativeType, DataType, Field, Float32Type, Int16Type,
Int32Type, Int64Type, Int8Type, IntervalDayTimeType, IntervalMonthDayNanoType,
IntervalUnit, IntervalYearMonthType, TimeUnit, TimestampMicrosecondType,
i256, ArrowDictionaryKeyType, ArrowNativeType, ArrowTimestampType, DataType,
Field, Fields, Float32Type, Int16Type, Int32Type, Int64Type, Int8Type,
IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit,
IntervalYearMonthType, SchemaBuilder, TimeUnit, TimestampMicrosecondType,
TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
UInt16Type, UInt32Type, UInt64Type, UInt8Type, DECIMAL128_MAX_PRECISION,
},
};
use arrow_array::cast::as_list_array;
use arrow_array::types::ArrowTimestampType;
use arrow_array::{ArrowNativeTypeOp, Scalar};

/// A dynamically typed, nullable single value (the single-valued counterpart
/// to arrow's [`Array`])
Expand Down Expand Up @@ -1729,6 +1730,43 @@ impl ScalarValue {
Arc::new(array_into_list_array(values))
}

/// Builds a [`ListArray`] from an exact-size iterator of [`ScalarValue`]s
/// whose elements all have the type described by `data_type`.
///
/// When the iterator is non-empty the child array is produced by
/// [`Self::iter_to_array`]; `data_type` is only consulted to create an empty
/// child array when the iterator yields no elements.
///
/// # Example
/// ```
/// use datafusion_common::ScalarValue;
/// use arrow::array::ListArray;
/// use arrow::datatypes::{DataType, Int32Type};
///
/// let scalars = vec![
///     ScalarValue::Int32(Some(1)),
///     ScalarValue::Int32(None),
///     ScalarValue::Int32(Some(2))
/// ];
///
/// let result = ScalarValue::new_list_from_iter(scalars.into_iter(), &DataType::Int32);
///
/// let expected = ListArray::from_iter_primitive::<Int32Type, _, _>(
///     vec![
///         Some(vec![Some(1), None, Some(2)])
///     ]);
///
/// assert_eq!(*result, expected);
/// ```
///
/// # Panics
///
/// Panics if [`Self::iter_to_array`] returns an error for a non-empty input
/// (presumably when the elements do not share one type -- confirm against
/// `iter_to_array`).
pub fn new_list_from_iter(
    values: impl IntoIterator<Item = ScalarValue> + ExactSizeIterator,
    data_type: &DataType,
) -> Arc<ListArray> {
    // Read the length up front: `values` is consumed by `iter_to_array` below.
    let element_count = values.len();
    let child_array = if element_count > 0 {
        Self::iter_to_array(values).unwrap()
    } else {
        // Nothing to convert, so fall back to an empty array of `data_type`.
        new_empty_array(data_type)
    };
    Arc::new(array_into_list_array(child_array))
}

/// Converts `Vec<ScalarValue>` where each element has type corresponding to
/// `data_type`, to a [`LargeListArray`].
///
Expand Down Expand Up @@ -2627,6 +2665,18 @@ impl ScalarValue {
.sum::<usize>()
}

/// Estimates the [size](Self::size) in bytes of a [`VecDeque`] of scalars.
///
/// The estimate is the sum of three terms: the [`VecDeque`] container
/// itself, one `ScalarValue`-sized slot per unit of capacity, and for every
/// stored element whatever [`Self::size`] reports beyond the element's own
/// in-place size.
pub fn size_of_vec_deque(vec_deque: &VecDeque<Self>) -> usize {
    // Size of the `VecDeque` value itself.
    let container_bytes = std::mem::size_of_val(vec_deque);
    // One slot per unit of capacity, whether or not it is occupied.
    let slot_bytes = std::mem::size_of::<ScalarValue>() * vec_deque.capacity();
    // The in-place part of each element is already covered by `slot_bytes`,
    // so only the remainder of `size()` is added here.
    let extra_bytes: usize = vec_deque
        .iter()
        .map(|sv| sv.size() - std::mem::size_of_val(sv))
        .sum();
    container_bytes + slot_bytes + extra_bytes
}

/// Estimates [size](Self::size) of [`HashSet`] in bytes.
///
/// Includes the size of the [`HashSet`] container itself.
Expand Down Expand Up @@ -3152,22 +3202,19 @@ impl ScalarType<i64> for TimestampNanosecondType {

#[cfg(test)]
mod tests {
use super::*;

use std::cmp::Ordering;
use std::sync::Arc;

use chrono::NaiveDate;
use rand::Rng;
use super::*;
use crate::cast::{as_string_array, as_uint32_array, as_uint64_array};

use arrow::buffer::OffsetBuffer;
use arrow::compute::kernels;
use arrow::compute::{concat, is_null};
use arrow::datatypes::ArrowPrimitiveType;
use arrow::compute::{concat, is_null, kernels};
use arrow::datatypes::{ArrowNumericType, ArrowPrimitiveType};
use arrow::util::pretty::pretty_format_columns;
use arrow_array::ArrowNumericType;

use crate::cast::{as_string_array, as_uint32_array, as_uint64_array};
use chrono::NaiveDate;
use rand::Rng;

#[test]
fn test_to_array_of_size_for_list() {
Expand Down
41 changes: 25 additions & 16 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,39 +17,44 @@

//! Aggregate function module contains all built-in aggregate functions definitions

use std::sync::Arc;
use std::{fmt, str::FromStr};

use crate::utils;
use crate::{type_coercion::aggregates::*, Signature, TypeSignature, Volatility};

use arrow::datatypes::{DataType, Field};
use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result};
use std::sync::Arc;
use std::{fmt, str::FromStr};

use strum_macros::EnumIter;

/// Enum of all built-in aggregate functions
// Contributor's guide for adding new aggregate functions
// https://arrow.apache.org/datafusion/contributor-guide/index.html#how-to-add-a-new-aggregate-function
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash, EnumIter)]
pub enum AggregateFunction {
/// count
/// Count
Count,
/// sum
/// Sum
Sum,
/// min
/// Minimum
Min,
/// max
/// Maximum
Max,
/// avg
/// Average
Avg,
/// median
/// Median
Median,
/// Approximate aggregate function
/// Approximate distinct function
ApproxDistinct,
/// array_agg
/// Aggregation into an array
ArrayAgg,
/// first_value
/// First value in a group according to some ordering
FirstValue,
/// last_value
/// Last value in a group according to some ordering
LastValue,
/// N'th value in a group according to some ordering
NthValue,
/// Variance (Sample)
Variance,
/// Variance (Population)
Expand Down Expand Up @@ -100,7 +105,7 @@ pub enum AggregateFunction {
BoolAnd,
/// Bool Or
BoolOr,
/// string_agg
/// String aggregation
StringAgg,
}

Expand All @@ -118,6 +123,7 @@ impl AggregateFunction {
ArrayAgg => "ARRAY_AGG",
FirstValue => "FIRST_VALUE",
LastValue => "LAST_VALUE",
NthValue => "NTH_VALUE",
Variance => "VAR",
VariancePop => "VAR_POP",
Stddev => "STDDEV",
Expand Down Expand Up @@ -174,6 +180,7 @@ impl FromStr for AggregateFunction {
"array_agg" => AggregateFunction::ArrayAgg,
"first_value" => AggregateFunction::FirstValue,
"last_value" => AggregateFunction::LastValue,
"nth_value" => AggregateFunction::NthValue,
"string_agg" => AggregateFunction::StringAgg,
// statistical
"corr" => AggregateFunction::Correlation,
Expand Down Expand Up @@ -300,9 +307,9 @@ impl AggregateFunction {
Ok(coerced_data_types[0].clone())
}
AggregateFunction::Grouping => Ok(DataType::Int32),
AggregateFunction::FirstValue | AggregateFunction::LastValue => {
Ok(coerced_data_types[0].clone())
}
AggregateFunction::FirstValue
| AggregateFunction::LastValue
| AggregateFunction::NthValue => Ok(coerced_data_types[0].clone()),
AggregateFunction::StringAgg => Ok(DataType::LargeUtf8),
}
}
Expand Down Expand Up @@ -371,6 +378,7 @@ impl AggregateFunction {
| AggregateFunction::LastValue => {
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable),
AggregateFunction::Covariance
| AggregateFunction::CovariancePop
| AggregateFunction::Correlation
Expand Down Expand Up @@ -428,6 +436,7 @@ impl AggregateFunction {
#[cfg(test)]
mod tests {
use super::*;

use strum::IntoEnumIterator;

#[test]
Expand Down
13 changes: 7 additions & 6 deletions datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,16 @@
// specific language governing permissions and limitations
// under the License.

use std::ops::Deref;

use super::functions::can_coerce_from;
use crate::{AggregateFunction, Signature, TypeSignature};

use arrow::datatypes::{
DataType, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE,
};

use datafusion_common::{internal_err, plan_err, DataFusionError, Result};
use std::ops::Deref;

use crate::{AggregateFunction, Signature, TypeSignature};

use super::functions::can_coerce_from;

pub static STRINGS: &[DataType] = &[DataType::Utf8, DataType::LargeUtf8];

Expand Down Expand Up @@ -297,6 +296,7 @@ pub fn coerce_types(
AggregateFunction::Median
| AggregateFunction::FirstValue
| AggregateFunction::LastValue => Ok(input_types.to_vec()),
AggregateFunction::NthValue => Ok(input_types.to_vec()),
AggregateFunction::Grouping => Ok(vec![input_types[0].clone()]),
AggregateFunction::StringAgg => {
if !is_string_agg_supported_arg_type(&input_types[0]) {
Expand Down Expand Up @@ -584,6 +584,7 @@ pub fn is_string_agg_supported_arg_type(arg_type: &DataType) -> bool {
#[cfg(test)]
mod tests {
use super::*;

use arrow::datatypes::DataType;

#[test]
Expand Down
Loading