300 changes: 3 additions & 297 deletions datafusion/expr-common/src/type_coercion/aggregates.rs
@@ -16,31 +16,12 @@
// under the License.

use crate::signature::TypeSignature;
use arrow::datatypes::{
DataType, FieldRef, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DECIMAL32_MAX_PRECISION,
DECIMAL32_MAX_SCALE, DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE,
};
use arrow::datatypes::{DataType, FieldRef};

use datafusion_common::{internal_err, plan_err, Result};

pub static STRINGS: &[DataType] =
    &[DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View];

Contributor Author: I assume these were leftover from before moving to the Signature API

pub static SIGNED_INTEGERS: &[DataType] = &[
DataType::Int8,
DataType::Int16,
DataType::Int32,
DataType::Int64,
];

pub static UNSIGNED_INTEGERS: &[DataType] = &[
DataType::UInt8,
DataType::UInt16,
DataType::UInt32,
DataType::UInt64,
];

// TODO: remove usage of these (INTEGERS and NUMERICS) in favour of signatures

Contributor Author: I'll raise an issue for this

Contributor: Please update the comment with a reference to the GH issue so it would be easier to pick up if anyone wants to contribute

Contributor Author: Done 👍

// see https://github.com/apache/datafusion/issues/18092
pub static INTEGERS: &[DataType] = &[
DataType::Int8,
DataType::Int16,
@@ -65,24 +46,6 @@ pub static NUMERICS: &[DataType] = &[
DataType::Float64,
];
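
As a rough sketch of the signature-based direction the TODO above points at (illustrative only, not part of this diff; it assumes the Signature::numeric and Signature::uniform constructors from datafusion_expr):

use arrow::datatypes::DataType;
use datafusion_expr::{Signature, Volatility};

fn example_signatures() -> (Signature, Signature) {
    // Accept any numeric type for a single argument, instead of probing
    // a slice like NUMERICS by hand (Signature::numeric is assumed here).
    let numeric = Signature::numeric(1, Volatility::Immutable);
    // Or enumerate the accepted types explicitly, which is effectively
    // what the removed static slices encoded.
    let integers = Signature::uniform(
        1,
        vec![DataType::Int8, DataType::Int16, DataType::Int32, DataType::Int64],
        Volatility::Immutable,
    );
    (numeric, integers)
}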

pub static TIMESTAMPS: &[DataType] = &[
DataType::Timestamp(TimeUnit::Second, None),
DataType::Timestamp(TimeUnit::Millisecond, None),
DataType::Timestamp(TimeUnit::Microsecond, None),
DataType::Timestamp(TimeUnit::Nanosecond, None),
];

pub static DATES: &[DataType] = &[DataType::Date32, DataType::Date64];

pub static BINARYS: &[DataType] = &[DataType::Binary, DataType::LargeBinary];

pub static TIMES: &[DataType] = &[
DataType::Time32(TimeUnit::Second),
DataType::Time32(TimeUnit::Millisecond),
DataType::Time64(TimeUnit::Microsecond),
DataType::Time64(TimeUnit::Nanosecond),
];

/// Validate the length of `input_fields` matches the `signature` for `agg_fun`.
///
/// This method DOES NOT validate the argument fields - only that (at least one,
@@ -144,260 +107,3 @@ pub fn check_arg_count(
}
Ok(())
}

/// Function return type of a sum
pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> {

Contributor Author: All of these functions except the avg ones were unused in our code, and I don't think it makes sense to have them available for users anyway

Contributor: If users want them, they can copy them into their own code on upgrade

match arg_type {
DataType::Int64 => Ok(DataType::Int64),
DataType::UInt64 => Ok(DataType::UInt64),
DataType::Float64 => Ok(DataType::Float64),
DataType::Decimal32(precision, scale) => {
// in the spark, the result type is DECIMAL(min(38,precision+10), s)
// ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10);
Ok(DataType::Decimal32(new_precision, *scale))
}
DataType::Decimal64(precision, scale) => {
// in the spark, the result type is DECIMAL(min(38,precision+10), s)
// ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10);
Ok(DataType::Decimal64(new_precision, *scale))
}
DataType::Decimal128(precision, scale) => {
// In the spark, the result type is DECIMAL(min(38,precision+10), s)
// Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
Ok(DataType::Decimal128(new_precision, *scale))
}
DataType::Decimal256(precision, scale) => {
// In the spark, the result type is DECIMAL(min(38,precision+10), s)
// Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66
let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
Ok(DataType::Decimal256(new_precision, *scale))
}
other => plan_err!("SUM does not support type \"{other:?}\""),
}
}
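
Per the reviewer exchange above, a downstream crate that still needs this helper can keep a local copy on upgrade. A minimal, self-contained sketch covering only the Decimal128 case (hypothetical helper name, plain String error in place of DataFusion's plan_err!):

use arrow::datatypes::{DataType, DECIMAL128_MAX_PRECISION};

// Local stand-in for the removed sum_return_type, trimmed to the types
// this hypothetical downstream crate actually uses.
fn my_sum_return_type(arg_type: &DataType) -> Result<DataType, String> {
    match arg_type {
        DataType::Int64 | DataType::UInt64 | DataType::Float64 => Ok(arg_type.clone()),
        // Spark-style widening: DECIMAL(min(38, precision + 10), scale)
        DataType::Decimal128(precision, scale) => Ok(DataType::Decimal128(
            DECIMAL128_MAX_PRECISION.min(*precision + 10),
            *scale,
        )),
        other => Err(format!("SUM does not support type {other:?}")),
    }
}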

/// Function return type of variance
pub fn variance_return_type(arg_type: &DataType) -> Result<DataType> {
if NUMERICS.contains(arg_type) {
Ok(DataType::Float64)
} else {
plan_err!("VAR does not support {arg_type}")
}
}

/// Function return type of covariance
pub fn covariance_return_type(arg_type: &DataType) -> Result<DataType> {
if NUMERICS.contains(arg_type) {
Ok(DataType::Float64)
} else {
plan_err!("COVAR does not support {arg_type}")
}
}

/// Function return type of correlation
pub fn correlation_return_type(arg_type: &DataType) -> Result<DataType> {
if NUMERICS.contains(arg_type) {
Ok(DataType::Float64)
} else {
plan_err!("CORR does not support {arg_type}")
}
}

/// Function return type of an average
pub fn avg_return_type(func_name: &str, arg_type: &DataType) -> Result<DataType> {
match arg_type {
DataType::Decimal32(precision, scale) => {
// In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
// Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 4);
let new_scale = DECIMAL32_MAX_SCALE.min(*scale + 4);
Ok(DataType::Decimal32(new_precision, new_scale))
}
DataType::Decimal64(precision, scale) => {
// In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
// Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 4);
let new_scale = DECIMAL64_MAX_SCALE.min(*scale + 4);
Ok(DataType::Decimal64(new_precision, new_scale))
}
DataType::Decimal128(precision, scale) => {
// In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
// Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 4);
let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4);
Ok(DataType::Decimal128(new_precision, new_scale))
}
DataType::Decimal256(precision, scale) => {
// In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)).
// Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66
let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 4);
let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4);
Ok(DataType::Decimal256(new_precision, new_scale))
}
DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64),
DataType::Dictionary(_, dict_value_type) => {
avg_return_type(func_name, dict_value_type.as_ref())
}
other => plan_err!("{func_name} does not support {other:?}"),
}
}
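
A quick worked example of the precision+4/scale+4 widening above (illustrative assertions against the function as shown, not part of this diff):

#[test]
fn avg_decimal_widening_examples() {
    // AVG over DECIMAL(10, 5) widens both precision and scale by 4.
    assert_eq!(
        avg_return_type("avg", &DataType::Decimal128(10, 5)).unwrap(),
        DataType::Decimal128(14, 9)
    );
    // Near the cap, both are clamped to the Decimal128 maximum of 38.
    assert_eq!(
        avg_return_type("avg", &DataType::Decimal128(36, 35)).unwrap(),
        DataType::Decimal128(38, 38)
    );
}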

/// Internal sum type of an average
pub fn avg_sum_type(arg_type: &DataType) -> Result<DataType> {
match arg_type {
DataType::Decimal32(precision, scale) => {
// In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s)
let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10);
Ok(DataType::Decimal32(new_precision, *scale))
}
DataType::Decimal64(precision, scale) => {
// In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s)
let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10);
Ok(DataType::Decimal64(new_precision, *scale))
}
DataType::Decimal128(precision, scale) => {
// In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s)
let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10);
Ok(DataType::Decimal128(new_precision, *scale))
}
DataType::Decimal256(precision, scale) => {
// In Spark the sum type of avg is DECIMAL(min(38,precision+10), s)
let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10);
Ok(DataType::Decimal256(new_precision, *scale))
}
DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64),
DataType::Dictionary(_, dict_value_type) => {
avg_sum_type(dict_value_type.as_ref())
}
other => plan_err!("AVG does not support {other:?}"),
}
}

pub fn is_sum_support_arg_type(arg_type: &DataType) -> bool {
match arg_type {
DataType::Dictionary(_, dict_value_type) => {
is_sum_support_arg_type(dict_value_type.as_ref())
}
_ => matches!(
arg_type,
arg_type if NUMERICS.contains(arg_type)
|| matches!(arg_type, DataType::Decimal32(_, _) | DataType::Decimal64(_, _) |DataType::Decimal128(_, _) | DataType::Decimal256(_, _))
),
}
}

pub fn is_avg_support_arg_type(arg_type: &DataType) -> bool {
match arg_type {
DataType::Dictionary(_, dict_value_type) => {
is_avg_support_arg_type(dict_value_type.as_ref())
}
_ => matches!(
arg_type,
arg_type if NUMERICS.contains(arg_type)
|| matches!(arg_type, DataType::Decimal32(_, _) | DataType::Decimal64(_, _) |DataType::Decimal128(_, _) | DataType::Decimal256(_, _))
),
}
}
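
The Dictionary arms above mean the support checks look through a dictionary to its value type; for instance (illustrative, not part of this diff):

#[test]
fn dictionary_support_example() {
    // A Dictionary(Int32 -> Decimal128) column is accepted because the
    // checks recurse into the dictionary's value type.
    let dict = DataType::Dictionary(
        Box::new(DataType::Int32),
        Box::new(DataType::Decimal128(10, 2)),
    );
    assert!(is_sum_support_arg_type(&dict));
    assert!(is_avg_support_arg_type(&dict));
}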

pub fn is_variance_support_arg_type(arg_type: &DataType) -> bool {
matches!(
arg_type,
arg_type if NUMERICS.contains(arg_type)
)
}

pub fn is_covariance_support_arg_type(arg_type: &DataType) -> bool {
matches!(
arg_type,
arg_type if NUMERICS.contains(arg_type)
)
}

pub fn is_correlation_support_arg_type(arg_type: &DataType) -> bool {
matches!(
arg_type,
arg_type if NUMERICS.contains(arg_type)
)
}

pub fn is_integer_arg_type(arg_type: &DataType) -> bool {
arg_type.is_integer()
}

pub fn coerce_avg_type(func_name: &str, arg_types: &[DataType]) -> Result<Vec<DataType>> {
// Supported types smallint, int, bigint, real, double precision, decimal, or interval
// Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc
fn coerced_type(func_name: &str, data_type: &DataType) -> Result<DataType> {
match &data_type {
DataType::Decimal32(p, s) => Ok(DataType::Decimal32(*p, *s)),
DataType::Decimal64(p, s) => Ok(DataType::Decimal64(*p, *s)),
DataType::Decimal128(p, s) => Ok(DataType::Decimal128(*p, *s)),
DataType::Decimal256(p, s) => Ok(DataType::Decimal256(*p, *s)),
d if d.is_numeric() => Ok(DataType::Float64),
DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)),
DataType::Dictionary(_, v) => coerced_type(func_name, v.as_ref()),
_ => {
plan_err!(
"The function {:?} does not support inputs of type {}.",
func_name,
data_type
)
}
}
}
Ok(vec![coerced_type(func_name, &arg_types[0])?])
}
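
Illustrative use of the coercion above (not part of this diff): plain numeric inputs coerce to Float64, while decimal inputs pass through unchanged.

#[test]
fn coerce_avg_type_examples() {
    // Integer inputs are widened to Float64...
    assert_eq!(
        coerce_avg_type("avg", &[DataType::Int32]).unwrap(),
        vec![DataType::Float64]
    );
    // ...while decimal inputs keep their precision and scale.
    assert_eq!(
        coerce_avg_type("avg", &[DataType::Decimal128(10, 2)]).unwrap(),
        vec![DataType::Decimal128(10, 2)]
    );
}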
#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_variance_return_data_type() -> Result<()> {
let data_type = DataType::Float64;
let result_type = variance_return_type(&data_type)?;
assert_eq!(DataType::Float64, result_type);

let data_type = DataType::Decimal128(36, 10);
assert!(variance_return_type(&data_type).is_err());
Ok(())
}

#[test]
fn test_sum_return_data_type() -> Result<()> {
let data_type = DataType::Decimal128(10, 5);
let result_type = sum_return_type(&data_type)?;
assert_eq!(DataType::Decimal128(20, 5), result_type);

let data_type = DataType::Decimal128(36, 10);
let result_type = sum_return_type(&data_type)?;
assert_eq!(DataType::Decimal128(38, 10), result_type);
Ok(())
}

#[test]
fn test_covariance_return_data_type() -> Result<()> {
let data_type = DataType::Float64;
let result_type = covariance_return_type(&data_type)?;
assert_eq!(DataType::Float64, result_type);

let data_type = DataType::Decimal128(36, 10);
assert!(covariance_return_type(&data_type).is_err());
Ok(())
}

#[test]
fn test_correlation_return_data_type() -> Result<()> {
let data_type = DataType::Float64;
let result_type = correlation_return_type(&data_type)?;
assert_eq!(DataType::Float64, result_type);

let data_type = DataType::Decimal128(36, 10);
assert!(correlation_return_type(&data_type).is_err());
Ok(())
}
}