Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions datafusion/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1027,8 +1027,8 @@ doc_comment::doctest!(

#[cfg(doctest)]
doc_comment::doctest!(
"../../../docs/source/user-guide/sql/write_options.md",
user_guide_sql_write_options
"../../../docs/source/user-guide/sql/format_options.md",
user_guide_sql_format_options
);

#[cfg(doctest)]
Expand Down
4 changes: 2 additions & 2 deletions datafusion/core/tests/expr_api/simplification.rs
Original file line number Diff line number Diff line change
Expand Up @@ -547,9 +547,9 @@ fn test_simplify_with_cycle_count(
};
let simplifier = ExprSimplifier::new(info);
let (simplified_expr, count) = simplifier
.simplify_with_cycle_count(input_expr.clone())
.simplify_with_cycle_count_transformed(input_expr.clone())
.expect("successfully evaluated");

let simplified_expr = simplified_expr.data;
assert_eq!(
simplified_expr, expected_expr,
"Mismatch evaluating {input_expr}\n Expected:{expected_expr}\n Got:{simplified_expr}"
Expand Down
3 changes: 3 additions & 0 deletions datafusion/expr-common/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ pub fn avg_return_type(func_name: &str, arg_type: &DataType) -> Result<DataType>
let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4);
Ok(DataType::Decimal256(new_precision, new_scale))
}
DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64),
DataType::Dictionary(_, dict_value_type) => {
avg_return_type(func_name, dict_value_type.as_ref())
Expand All @@ -231,6 +232,7 @@ pub fn avg_sum_type(arg_type: &DataType) -> Result<DataType> {
let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
Ok(DataType::Decimal256(new_precision, *scale))
}
DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64),
DataType::Dictionary(_, dict_value_type) => {
avg_sum_type(dict_value_type.as_ref())
Expand Down Expand Up @@ -298,6 +300,7 @@ pub fn coerce_avg_type(func_name: &str, arg_types: &[DataType]) -> Result<Vec<Da
DataType::Decimal128(p, s) => Ok(DataType::Decimal128(*p, *s)),
DataType::Decimal256(p, s) => Ok(DataType::Decimal256(*p, *s)),
d if d.is_numeric() => Ok(DataType::Float64),
DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
DataType::Dictionary(_, v) => coerced_type(func_name, v.as_ref()),
_ => {
plan_err!(
Expand Down
114 changes: 112 additions & 2 deletions datafusion/functions-aggregate/src/average.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,9 @@ use arrow::array::{

use arrow::compute::sum;
use arrow::datatypes::{
i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType, Field,
Float64Type, UInt64Type,
i256, ArrowNativeType, DataType, Decimal128Type, Decimal256Type, DecimalType,
DurationMicrosecondType, DurationMillisecondType, DurationNanosecondType,
DurationSecondType, Field, Float64Type, TimeUnit, UInt64Type,
};
use datafusion_common::{
exec_err, not_impl_err, utils::take_function_args, Result, ScalarValue,
Expand Down Expand Up @@ -145,6 +146,16 @@ impl AggregateUDFImpl for Avg {
target_precision: *target_precision,
target_scale: *target_scale,
})),

(Duration(time_unit), Duration(result_unit)) => {
Ok(Box::new(DurationAvgAccumulator {
sum: None,
count: 0,
time_unit: *time_unit,
result_unit: *result_unit,
}))
}

_ => exec_err!(
"AvgAccumulator for ({} --> {})",
&data_type,
Expand Down Expand Up @@ -399,6 +410,105 @@ impl<T: DecimalType + ArrowNumericType + Debug> Accumulator for DecimalAvgAccumu
}
}

/// An accumulator to compute the average for duration values
#[derive(Debug)]
struct DurationAvgAccumulator {
sum: Option<i64>,
count: u64,
time_unit: TimeUnit,
result_unit: TimeUnit,
}

impl Accumulator for DurationAvgAccumulator {
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let array = &values[0];
self.count += (array.len() - array.null_count()) as u64;

let sum_value = match self.time_unit {
TimeUnit::Second => sum(array.as_primitive::<DurationSecondType>()),
TimeUnit::Millisecond => sum(array.as_primitive::<DurationMillisecondType>()),
TimeUnit::Microsecond => sum(array.as_primitive::<DurationMicrosecondType>()),
TimeUnit::Nanosecond => sum(array.as_primitive::<DurationNanosecondType>()),
};

if let Some(x) = sum_value {
let v = self.sum.get_or_insert(0);
*v += x;
}
Ok(())
}

fn evaluate(&mut self) -> Result<ScalarValue> {
let avg = self.sum.map(|sum| sum / self.count as i64);

match self.result_unit {
TimeUnit::Second => Ok(ScalarValue::DurationSecond(avg)),
TimeUnit::Millisecond => Ok(ScalarValue::DurationMillisecond(avg)),
TimeUnit::Microsecond => Ok(ScalarValue::DurationMicrosecond(avg)),
TimeUnit::Nanosecond => Ok(ScalarValue::DurationNanosecond(avg)),
}
}

fn size(&self) -> usize {
size_of_val(self)
}

fn state(&mut self) -> Result<Vec<ScalarValue>> {
let duration_value = match self.time_unit {
TimeUnit::Second => ScalarValue::DurationSecond(self.sum),
TimeUnit::Millisecond => ScalarValue::DurationMillisecond(self.sum),
TimeUnit::Microsecond => ScalarValue::DurationMicrosecond(self.sum),
TimeUnit::Nanosecond => ScalarValue::DurationNanosecond(self.sum),
};

Ok(vec![ScalarValue::from(self.count), duration_value])
}

fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
self.count += sum(states[0].as_primitive::<UInt64Type>()).unwrap_or_default();

let sum_value = match self.time_unit {
TimeUnit::Second => sum(states[1].as_primitive::<DurationSecondType>()),
TimeUnit::Millisecond => {
sum(states[1].as_primitive::<DurationMillisecondType>())
}
TimeUnit::Microsecond => {
sum(states[1].as_primitive::<DurationMicrosecondType>())
}
TimeUnit::Nanosecond => {
sum(states[1].as_primitive::<DurationNanosecondType>())
}
};

if let Some(x) = sum_value {
let v = self.sum.get_or_insert(0);
*v += x;
}
Ok(())
}

fn retract_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
let array = &values[0];
self.count -= (array.len() - array.null_count()) as u64;

let sum_value = match self.time_unit {
TimeUnit::Second => sum(array.as_primitive::<DurationSecondType>()),
TimeUnit::Millisecond => sum(array.as_primitive::<DurationMillisecondType>()),
TimeUnit::Microsecond => sum(array.as_primitive::<DurationMicrosecondType>()),
TimeUnit::Nanosecond => sum(array.as_primitive::<DurationNanosecondType>()),
};

if let Some(x) = sum_value {
self.sum = Some(self.sum.unwrap() - x);
}
Ok(())
}

fn supports_retract_batch(&self) -> bool {
true
}
}

/// An accumulator to compute the average of `[PrimitiveArray<T>]`.
/// Stores values as native types, and does overflow checking
///
Expand Down
48 changes: 41 additions & 7 deletions datafusion/optimizer/src/simplify_expressions/expr_simplifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,7 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
/// assert_eq!(expr, b_lt_2);
/// ```
pub fn simplify(&self, expr: Expr) -> Result<Expr> {
Ok(self.simplify_with_cycle_count(expr)?.0)
Ok(self.simplify_with_cycle_count_transformed(expr)?.0.data)
}

/// Like [Self::simplify], simplifies this [`Expr`] as much as possible, evaluating
Expand All @@ -198,7 +198,34 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
///
/// See [Self::simplify] for details and usage examples.
///
#[deprecated(
since = "48.0.0",
note = "Use `simplify_with_cycle_count_transformed` instead"
)]
#[allow(unused_mut)]
pub fn simplify_with_cycle_count(&self, mut expr: Expr) -> Result<(Expr, u32)> {
let (transformed, cycle_count) =
self.simplify_with_cycle_count_transformed(expr)?;
Ok((transformed.data, cycle_count))
}

/// Like [Self::simplify], simplifies this [`Expr`] as much as possible, evaluating
/// constants and applying algebraic simplifications. Additionally returns a `u32`
/// representing the number of simplification cycles performed, which can be useful for testing
/// optimizations.
///
/// # Returns
///
/// A tuple containing:
/// - The simplified expression wrapped in a `Transformed<Expr>` indicating if changes were made
/// - The number of simplification cycles that were performed
///
/// See [Self::simplify] for details and usage examples.
///
pub fn simplify_with_cycle_count_transformed(
&self,
mut expr: Expr,
) -> Result<(Transformed<Expr>, u32)> {
let mut simplifier = Simplifier::new(&self.info);
let mut const_evaluator = ConstEvaluator::try_new(self.info.execution_props())?;
let mut shorten_in_list_simplifier = ShortenInListSimplifier::new();
Expand All @@ -212,6 +239,7 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
// simplifications can enable new constant evaluation
// see `Self::with_max_cycles`
let mut num_cycles = 0;
let mut has_transformed = false;
loop {
let Transformed {
data, transformed, ..
Expand All @@ -221,13 +249,18 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
.transform_data(|expr| expr.rewrite(&mut guarantee_rewriter))?;
expr = data;
num_cycles += 1;
// Track if any transformation occurred
has_transformed = has_transformed || transformed;
if !transformed || num_cycles >= self.max_simplifier_cycles {
break;
}
}
// shorten inlist should be started after other inlist rules are applied
expr = expr.rewrite(&mut shorten_in_list_simplifier).data()?;
Ok((expr, num_cycles))
Ok((
Transformed::new_transformed(expr, has_transformed),
num_cycles,
))
}

/// Apply type coercion to an [`Expr`] so that it can be
Expand Down Expand Up @@ -392,15 +425,15 @@ impl<S: SimplifyInfo> ExprSimplifier<S> {
/// let expr = col("a").is_not_null();
///
/// // When using default maximum cycles, 2 cycles will be performed.
/// let (simplified_expr, count) = simplifier.simplify_with_cycle_count(expr.clone()).unwrap();
/// assert_eq!(simplified_expr, lit(true));
/// let (simplified_expr, count) = simplifier.simplify_with_cycle_count_transformed(expr.clone()).unwrap();
/// assert_eq!(simplified_expr.data, lit(true));
/// // 2 cycles were executed, but only 1 was needed
/// assert_eq!(count, 2);
///
/// // Only 1 simplification pass is necessary here, so we can set the maximum cycles to 1.
/// let (simplified_expr, count) = simplifier.with_max_cycles(1).simplify_with_cycle_count(expr.clone()).unwrap();
/// let (simplified_expr, count) = simplifier.with_max_cycles(1).simplify_with_cycle_count_transformed(expr.clone()).unwrap();
/// // Expression has been rewritten to: (c = a AND b = 1)
/// assert_eq!(simplified_expr, lit(true));
/// assert_eq!(simplified_expr.data, lit(true));
/// // Only 1 cycle was executed
/// assert_eq!(count, 1);
///
Expand Down Expand Up @@ -3329,7 +3362,8 @@ mod tests {
let simplifier = ExprSimplifier::new(
SimplifyContext::new(&execution_props).with_schema(schema),
);
simplifier.simplify_with_cycle_count(expr)
let (expr, count) = simplifier.simplify_with_cycle_count_transformed(expr)?;
Ok((expr.data, count))
}

fn simplify_with_cycle_count(expr: Expr) -> (Expr, u32) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,10 +123,11 @@ impl SimplifyExpressions {
let name_preserver = NamePreserver::new(&plan);
let mut rewrite_expr = |expr: Expr| {
let name = name_preserver.save(&expr);
let expr = simplifier.simplify(expr)?;
// TODO it would be nice to have a way to know if the expression was simplified
// or not. For now conservatively return Transformed::yes
Ok(Transformed::yes(name.restore(expr)))
let expr = simplifier.simplify_with_cycle_count_transformed(expr)?.0;
Ok(Transformed::new_transformed(
name.restore(expr.data),
expr.transformed,
))
};

plan.map_expressions(|expr| {
Expand Down
66 changes: 66 additions & 0 deletions datafusion/sqllogictest/test_files/aggregate.slt
Original file line number Diff line number Diff line change
Expand Up @@ -4969,6 +4969,72 @@ select count(distinct column1), count(distinct column2) from dict_test group by
statement ok
drop table dict_test;

# avg_duration

statement ok
create table d as values
(arrow_cast(1, 'Duration(Second)'), arrow_cast(2, 'Duration(Millisecond)'), arrow_cast(3, 'Duration(Microsecond)'), arrow_cast(4, 'Duration(Nanosecond)'), 1),
(arrow_cast(11, 'Duration(Second)'), arrow_cast(22, 'Duration(Millisecond)'), arrow_cast(33, 'Duration(Microsecond)'), arrow_cast(44, 'Duration(Nanosecond)'), 1);

query ????
SELECT avg(column1), avg(column2), avg(column3), avg(column4) FROM d;
----
0 days 0 hours 0 mins 6 secs 0 days 0 hours 0 mins 0.012 secs 0 days 0 hours 0 mins 0.000018 secs 0 days 0 hours 0 mins 0.000000024 secs

query ????I
SELECT avg(column1), avg(column2), avg(column3), avg(column4), column5 FROM d GROUP BY column5;
----
0 days 0 hours 0 mins 6 secs 0 days 0 hours 0 mins 0.012 secs 0 days 0 hours 0 mins 0.000018 secs 0 days 0 hours 0 mins 0.000000024 secs 1

statement ok
drop table d;

statement ok
create table d as values
(arrow_cast(1, 'Duration(Second)'), arrow_cast(2, 'Duration(Millisecond)'), arrow_cast(3, 'Duration(Microsecond)'), arrow_cast(4, 'Duration(Nanosecond)'), 1),
(arrow_cast(11, 'Duration(Second)'), arrow_cast(22, 'Duration(Millisecond)'), arrow_cast(33, 'Duration(Microsecond)'), arrow_cast(44, 'Duration(Nanosecond)'), 1),
(arrow_cast(5, 'Duration(Second)'), arrow_cast(10, 'Duration(Millisecond)'), arrow_cast(15, 'Duration(Microsecond)'), arrow_cast(20, 'Duration(Nanosecond)'), 2),
(arrow_cast(25, 'Duration(Second)'), arrow_cast(50, 'Duration(Millisecond)'), arrow_cast(75, 'Duration(Microsecond)'), arrow_cast(100, 'Duration(Nanosecond)'), 2),
(NULL, NULL, NULL, NULL, 1),
(NULL, NULL, NULL, NULL, 2);

query I?
SELECT column5, avg(column1) FROM d GROUP BY column5;
----
2 0 days 0 hours 0 mins 15 secs
1 0 days 0 hours 0 mins 6 secs

query I??
SELECT column5, column1, avg(column1) OVER (PARTITION BY column5 ORDER BY column1 ROWS BETWEEN 1 PRECEDING AND CURRENT ROW) as window_avg
FROM d WHERE column1 IS NOT NULL;
----
2 0 days 0 hours 0 mins 5 secs 0 days 0 hours 0 mins 5 secs
2 0 days 0 hours 0 mins 25 secs 0 days 0 hours 0 mins 15 secs
1 0 days 0 hours 0 mins 1 secs 0 days 0 hours 0 mins 1 secs
1 0 days 0 hours 0 mins 11 secs 0 days 0 hours 0 mins 6 secs

# Cumulative average window function
query I??
SELECT column5, column1, avg(column1) OVER (ORDER BY column5, column1 ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW) as cumulative_avg
FROM d WHERE column1 IS NOT NULL;
----
1 0 days 0 hours 0 mins 1 secs 0 days 0 hours 0 mins 1 secs
1 0 days 0 hours 0 mins 11 secs 0 days 0 hours 0 mins 6 secs
2 0 days 0 hours 0 mins 5 secs 0 days 0 hours 0 mins 5 secs
2 0 days 0 hours 0 mins 25 secs 0 days 0 hours 0 mins 10 secs

# Centered average window function
query I??
SELECT column5, column1, avg(column1) OVER (ORDER BY column5 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) as centered_avg
FROM d WHERE column1 IS NOT NULL;
----
1 0 days 0 hours 0 mins 1 secs 0 days 0 hours 0 mins 6 secs
1 0 days 0 hours 0 mins 11 secs 0 days 0 hours 0 mins 5 secs
2 0 days 0 hours 0 mins 5 secs 0 days 0 hours 0 mins 13 secs
2 0 days 0 hours 0 mins 25 secs 0 days 0 hours 0 mins 15 secs

statement ok
drop table d;

# Prepare the table with dictionary values for testing
statement ok
Expand Down
2 changes: 1 addition & 1 deletion docs/source/user-guide/sql/ddl.md
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ LOCATION <literal>
<key_value_list> := (<literal> <literal, <literal> <literal>, ...)
```

For a detailed list of write related options which can be passed in the OPTIONS key_value_list, see [Write Options](write_options).
For a comprehensive list of format-specific options that can be specified in the `OPTIONS` clause, see [Format Options](format_options.md).

`file_type` is one of `CSV`, `ARROW`, `PARQUET`, `AVRO` or `JSON`

Expand Down
2 changes: 1 addition & 1 deletion docs/source/user-guide/sql/dml.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ The output format is determined by the first match of the following rules:
1. Value of `STORED AS`
2. Filename extension (e.g. `foo.parquet` implies `PARQUET` format)

For a detailed list of valid OPTIONS, see [Write Options](write_options).
For a detailed list of valid OPTIONS, see [Format Options](format_options.md).

### Examples

Expand Down
Loading
Loading