Skip to content

Support Accumulator for avg duration #15468

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Apr 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions datafusion/expr-common/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ pub fn avg_return_type(func_name: &str, arg_type: &DataType) -> Result<DataType>
let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4);
Ok(DataType::Decimal256(new_precision, new_scale))
}
DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64),
DataType::Dictionary(_, dict_value_type) => {
avg_return_type(func_name, dict_value_type.as_ref())
Expand All @@ -231,6 +232,7 @@ pub fn avg_sum_type(arg_type: &DataType) -> Result<DataType> {
let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
Ok(DataType::Decimal256(new_precision, *scale))
}
DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64),
DataType::Dictionary(_, dict_value_type) => {
avg_sum_type(dict_value_type.as_ref())
Expand Down Expand Up @@ -298,6 +300,7 @@ pub fn coerce_avg_type(func_name: &str, arg_types: &[DataType]) -> Result<Vec<Da
DataType::Decimal128(p, s) => Ok(DataType::Decimal128(*p, *s)),
DataType::Decimal256(p, s) => Ok(DataType::Decimal256(*p, *s)),
d if d.is_numeric() => Ok(DataType::Float64),
DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
DataType::Dictionary(_, v) => coerced_type(func_name, v.as_ref()),
_ => {
plan_err!(
Expand Down
114 changes: 112 additions & 2 deletions datafusion/functions-aggregate/src/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ use arrow::array::{

use arrow::compute::sum;
use arrow::datatypes::{
i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType, Field,
Float64Type, UInt64Type,
i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType,
DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType,
DurationSecondType, Field, Float64Type, TimeUnit, UInt64Type,
};
use datafusion_common::{
exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue,
Expand Down Expand Up @@ -145,6 +146,16 @@ impl AggregateUDFImpl for Avg {
target_precision: *target_precision,
target_scale: *target_scale,
})),

(Duration(time_unit), Duration(result_unit)) => {
Ok(Box::new(DurationAvgAccumulator {
sum: None,
count: 0,
time_unit: *time_unit,
result_unit: *result_unit,
}))
}

_ => exec_err!(
"AvgAccumulator for ({} --> {})",
&data_type,
Expand Down Expand Up @@ -399,6 +410,105 @@ impl<T: DecimalType + ArrowNumericType + Debug> Accumulator for DecimalAvgAccumu
}
}

/// An accumulator to compute the average for duration values
#[derive(Debug)]
struct DurationAvgAccumulator {
sum: Option<i64>,
count: u64,
time_unit: TimeUnit,
result_unit: TimeUnit,
}

impl Accumulator for DurationAvgAccumulator {
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let array = &values[0];
self.count += (array.len() - array.null_count()) as u64;

let sum_value = match self.time_unit {
TimeUnit::Second => sum(array.as_primitive::<DurationSecondType>()),
TimeUnit::Millisecond => sum(array.as_primitive::<DurationMillisecondType>()),
TimeUnit::Microsecond => sum(array.as_primitive::<DurationMicrosecondType>()),
TimeUnit::Nanosecond => sum(array.as_primitive::<DurationNanosecondType>()),
};

if let Some(x) = sum_value {
let v = self.sum.get_or_insert(0);
*v += x;
}
Ok(())
}

fn evaluate(&mut self) -> Result<ScalarValue> {
let avg = self.sum.map(|sum| sum / self.count as i64);

match self.result_unit {
TimeUnit::Second => Ok(ScalarValue::DurationSecond(avg)),
TimeUnit::Millisecond => Ok(ScalarValue::DurationMillisecond(avg)),
TimeUnit::Microsecond => Ok(ScalarValue::DurationMicrosecond(avg)),
TimeUnit::Nanosecond => Ok(ScalarValue::DurationNanosecond(avg)),
}
}

fn size(&self) -> usize {
size_of_val(self)
}

fn state(&mut self) -> Result<Vec<ScalarValue>> {
let duration_value = match self.time_unit {
TimeUnit::Second => ScalarValue::DurationSecond(self.sum),
TimeUnit::Millisecond => ScalarValue::DurationMillisecond(self.sum),
TimeUnit::Microsecond => ScalarValue::DurationMicrosecond(self.sum),
TimeUnit::Nanosecond => ScalarValue::DurationNanosecond(self.sum),
};

Ok(vec![ScalarValue::from(self.count), duration_value])
}

fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();

let sum_value = match self.time_unit {
TimeUnit::Second => sum(states[1].as_primitive::<DurationSecondType>()),
TimeUnit::Millisecond => {
sum(states[1].as_primitive::<DurationMillisecondType>())
}
TimeUnit::Microsecond => {
sum(states[1].as_primitive::<DurationMicrosecondType>())
}
TimeUnit::Nanosecond => {
sum(states[1].as_primitive::<DurationNanosecondType>())
}
};

if let Some(x) = sum_value {
let v = self.sum.get_or_insert(0);
*v += x;
}
Ok(())
}

fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let array = &values[0];
self.count -= (array.len() - array.null_count()) as u64;

let sum_value = match self.time_unit {
TimeUnit::Second => sum(array.as_primitive::<DurationSecondType>()),
TimeUnit::Millisecond => sum(array.as_primitive::<DurationMillisecondType>()),
TimeUnit::Microsecond => sum(array.as_primitive::<DurationMicrosecondType>()),
TimeUnit::Nanosecond => sum(array.as_primitive::<DurationNanosecondType>()),
};

if let Some(x) = sum_value {
self.sum = Some(self.sum.unwrap() - x);
}
Ok(())
}

fn supports_retract_batch(&self) -> bool {
true
}
}

/// An accumulator to compute the average of `[PrimitiveArray<T>]`.
/// Stores values as native types, and does overflow checking
///
Expand Down
66 changes: 66 additions & 0 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -4969,6 +4969,72 @@ select count(distinct column1), count(distinct column2) from dict_test group by
statement ok
drop table dict_test;

# avg_duration

statement ok
create table d as values
(arrow_cast(1, 'Duration(Second)'), arrow_cast(2, 'Duration(Millisecond)'), arrow_cast(3, 'Duration(Microsecond)'), arrow_cast(4, 'Duration(Nanosecond)'), 1),
(arrow_cast(11, 'Duration(Second)'), arrow_cast(22, 'Duration(Millisecond)'), arrow_cast(33, 'Duration(Microsecond)'), arrow_cast(44, 'Duration(Nanosecond)'), 1);

query ????
SELECT avg(column1), avg(column2), avg(column3), avg(column4) FROM d;
----
0 days 0 hours 0 mins 6 secs 0 days 0 hours 0 mins 0.012 secs 0 days 0 hours 0 mins 0.000018 secs 0 days 0 hours 0 mins 0.000000024 secs

query ????I
SELECT avg(column1), avg(column2), avg(column3), avg(column4), column5 FROM d GROUP BY column5;
----
0 days 0 hours 0 mins 6 secs 0 days 0 hours 0 mins 0.012 secs 0 days 0 hours 0 mins 0.000018 secs 0 days 0 hours 0 mins 0.000000024 secs 1

statement ok
drop table d;

statement ok
create table d as values
(arrow_cast(1, 'Duration(Second)'), arrow_cast(2, 'Duration(Millisecond)'), arrow_cast(3, 'Duration(Microsecond)'), arrow_cast(4, 'Duration(Nanosecond)'), 1),
(arrow_cast(11, 'Duration(Second)'), arrow_cast(22, 'Duration(Millisecond)'), arrow_cast(33, 'Duration(Microsecond)'), arrow_cast(44, 'Duration(Nanosecond)'), 1),
(arrow_cast(5, 'Duration(Second)'), arrow_cast(10, 'Duration(Millisecond)'), arrow_cast(15, 'Duration(Microsecond)'), arrow_cast(20, 'Duration(Nanosecond)'), 2),
(arrow_cast(25, 'Duration(Second)'), arrow_cast(50, 'Duration(Millisecond)'), arrow_cast(75, 'Duration(Microsecond)'), arrow_cast(100, 'Duration(Nanosecond)'), 2),
(NULL, NULL, NULL, NULL, 1),
(NULL, NULL, NULL, NULL, 2);

query I?
SELECT column5, avg(column1) FROM d GROUP BY column5;
----
2 0 days 0 hours 0 mins 15 secs
1 0 days 0 hours 0 mins 6 secs

query I??
SELECT column5, column1, avg(column1) OVER (PARTITION BY column5 ORDER BY column1 ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) as window_avg
FROM d WHERE column1 IS NOT NULL;
----
2 0 days 0 hours 0 mins 5 secs 0 days 0 hours 0 mins 5 secs
2 0 days 0 hours 0 mins 25 secs 0 days 0 hours 0 mins 15 secs
1 0 days 0 hours 0 mins 1 secs 0 days 0 hours 0 mins 1 secs
1 0 days 0 hours 0 mins 11 secs 0 days 0 hours 0 mins 6 secs

# Cumulative average window function
query I??
SELECT column5, column1, avg(column1) OVER (ORDER BY column5, column1 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as cumulative_avg
FROM d WHERE column1 IS NOT NULL;
----
1 0 days 0 hours 0 mins 1 secs 0 days 0 hours 0 mins 1 secs
1 0 days 0 hours 0 mins 11 secs 0 days 0 hours 0 mins 6 secs
2 0 days 0 hours 0 mins 5 secs 0 days 0 hours 0 mins 5 secs
2 0 days 0 hours 0 mins 25 secs 0 days 0 hours 0 mins 10 secs

# Centered average window function
query I??
SELECT column5, column1, avg(column1) OVER (ORDER BY column5 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as centered_avg
FROM d WHERE column1 IS NOT NULL;
----
1 0 days 0 hours 0 mins 1 secs 0 days 0 hours 0 mins 6 secs
1 0 days 0 hours 0 mins 11 secs 0 days 0 hours 0 mins 5 secs
2 0 days 0 hours 0 mins 5 secs 0 days 0 hours 0 mins 13 secs
2 0 days 0 hours 0 mins 25 secs 0 days 0 hours 0 mins 15 secs

statement ok
drop table d;

# Prepare the table with dictionary values for testing
statement ok
Expand Down