-
Notifications
You must be signed in to change notification settings - Fork 1.8k
refactor: remove unused type_coercion/aggregate.rs functions
#18091
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -16,31 +16,12 @@ | |
| // under the License. | ||
|
|
||
| use crate::signature::TypeSignature; | ||
| use arrow::datatypes::{ | ||
| DataType, FieldRef, TimeUnit, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE, | ||
| DECIMAL256_MAX_PRECISION, DECIMAL256_MAX_SCALE, DECIMAL32_MAX_PRECISION, | ||
| DECIMAL32_MAX_SCALE, DECIMAL64_MAX_PRECISION, DECIMAL64_MAX_SCALE, | ||
| }; | ||
| use arrow::datatypes::{DataType, FieldRef}; | ||
|
|
||
| use datafusion_common::{internal_err, plan_err, Result}; | ||
|
|
||
| pub static STRINGS: &[DataType] = | ||
| &[DataType::Utf8, DataType::LargeUtf8, DataType::Utf8View]; | ||
|
|
||
| pub static SIGNED_INTEGERS: &[DataType] = &[ | ||
| DataType::Int8, | ||
| DataType::Int16, | ||
| DataType::Int32, | ||
| DataType::Int64, | ||
| ]; | ||
|
|
||
| pub static UNSIGNED_INTEGERS: &[DataType] = &[ | ||
| DataType::UInt8, | ||
| DataType::UInt16, | ||
| DataType::UInt32, | ||
| DataType::UInt64, | ||
| ]; | ||
|
|
||
| // TODO: remove usage of these (INTEGERS and NUMERICS) in favour of signatures | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I'll raise an issue for this
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Please update the comment with a reference to the GH issue so it would be easier to pick up if anyone wants to contribute
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done 👍 |
||
| // see https://github.com/apache/datafusion/issues/18092 | ||
| pub static INTEGERS: &[DataType] = &[ | ||
| DataType::Int8, | ||
| DataType::Int16, | ||
|
|
@@ -65,24 +46,6 @@ pub static NUMERICS: &[DataType] = &[ | |
| DataType::Float64, | ||
| ]; | ||
|
|
||
| pub static TIMESTAMPS: &[DataType] = &[ | ||
| DataType::Timestamp(TimeUnit::Second, None), | ||
| DataType::Timestamp(TimeUnit::Millisecond, None), | ||
| DataType::Timestamp(TimeUnit::Microsecond, None), | ||
| DataType::Timestamp(TimeUnit::Nanosecond, None), | ||
| ]; | ||
|
|
||
| pub static DATES: &[DataType] = &[DataType::Date32, DataType::Date64]; | ||
|
|
||
| pub static BINARYS: &[DataType] = &[DataType::Binary, DataType::LargeBinary]; | ||
|
|
||
| pub static TIMES: &[DataType] = &[ | ||
| DataType::Time32(TimeUnit::Second), | ||
| DataType::Time32(TimeUnit::Millisecond), | ||
| DataType::Time64(TimeUnit::Microsecond), | ||
| DataType::Time64(TimeUnit::Nanosecond), | ||
| ]; | ||
|
|
||
| /// Validate the length of `input_fields` matches the `signature` for `agg_fun`. | ||
| /// | ||
| /// This method DOES NOT validate the argument fields - only that (at least one, | ||
|
|
@@ -144,260 +107,3 @@ pub fn check_arg_count( | |
| } | ||
| Ok(()) | ||
| } | ||
|
|
||
| /// Function return type of a sum | ||
| pub fn sum_return_type(arg_type: &DataType) -> Result<DataType> { | ||
|
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. All of these functions except the avg ones were unused in our code, and I don't think it makes sense to have them available for users anyway
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If users want them, they can copy them into their own code on upgrade |
||
| match arg_type { | ||
| DataType::Int64 => Ok(DataType::Int64), | ||
| DataType::UInt64 => Ok(DataType::UInt64), | ||
| DataType::Float64 => Ok(DataType::Float64), | ||
| DataType::Decimal32(precision, scale) => { | ||
| // in the spark, the result type is DECIMAL(min(38,precision+10), s) | ||
| // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 | ||
| let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10); | ||
| Ok(DataType::Decimal32(new_precision, *scale)) | ||
| } | ||
| DataType::Decimal64(precision, scale) => { | ||
| // in the spark, the result type is DECIMAL(min(38,precision+10), s) | ||
| // ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 | ||
| let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10); | ||
| Ok(DataType::Decimal64(new_precision, *scale)) | ||
| } | ||
| DataType::Decimal128(precision, scale) => { | ||
| // In the spark, the result type is DECIMAL(min(38,precision+10), s) | ||
| // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 | ||
| let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10); | ||
| Ok(DataType::Decimal128(new_precision, *scale)) | ||
| } | ||
| DataType::Decimal256(precision, scale) => { | ||
| // In the spark, the result type is DECIMAL(min(38,precision+10), s) | ||
| // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Sum.scala#L66 | ||
| let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10); | ||
| Ok(DataType::Decimal256(new_precision, *scale)) | ||
| } | ||
| other => plan_err!("SUM does not support type \"{other:?}\""), | ||
| } | ||
| } | ||
|
|
||
| /// Function return type of variance | ||
| pub fn variance_return_type(arg_type: &DataType) -> Result<DataType> { | ||
| if NUMERICS.contains(arg_type) { | ||
| Ok(DataType::Float64) | ||
| } else { | ||
| plan_err!("VAR does not support {arg_type}") | ||
| } | ||
| } | ||
|
|
||
| /// Function return type of covariance | ||
| pub fn covariance_return_type(arg_type: &DataType) -> Result<DataType> { | ||
| if NUMERICS.contains(arg_type) { | ||
| Ok(DataType::Float64) | ||
| } else { | ||
| plan_err!("COVAR does not support {arg_type}") | ||
| } | ||
| } | ||
|
|
||
| /// Function return type of correlation | ||
| pub fn correlation_return_type(arg_type: &DataType) -> Result<DataType> { | ||
| if NUMERICS.contains(arg_type) { | ||
| Ok(DataType::Float64) | ||
| } else { | ||
| plan_err!("CORR does not support {arg_type}") | ||
| } | ||
| } | ||
|
|
||
| /// Function return type of an average | ||
| pub fn avg_return_type(func_name: &str, arg_type: &DataType) -> Result<DataType> { | ||
| match arg_type { | ||
| DataType::Decimal32(precision, scale) => { | ||
| // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). | ||
| // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 | ||
| let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 4); | ||
| let new_scale = DECIMAL32_MAX_SCALE.min(*scale + 4); | ||
| Ok(DataType::Decimal32(new_precision, new_scale)) | ||
| } | ||
| DataType::Decimal64(precision, scale) => { | ||
| // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). | ||
| // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 | ||
| let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 4); | ||
| let new_scale = DECIMAL64_MAX_SCALE.min(*scale + 4); | ||
| Ok(DataType::Decimal64(new_precision, new_scale)) | ||
| } | ||
| DataType::Decimal128(precision, scale) => { | ||
| // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). | ||
| // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 | ||
| let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 4); | ||
| let new_scale = DECIMAL128_MAX_SCALE.min(*scale + 4); | ||
| Ok(DataType::Decimal128(new_precision, new_scale)) | ||
| } | ||
| DataType::Decimal256(precision, scale) => { | ||
| // In the spark, the result type is DECIMAL(min(38,precision+4), min(38,scale+4)). | ||
| // Ref: https://github.com/apache/spark/blob/fcf636d9eb8d645c24be3db2d599aba2d7e2955a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Average.scala#L66 | ||
| let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 4); | ||
| let new_scale = DECIMAL256_MAX_SCALE.min(*scale + 4); | ||
| Ok(DataType::Decimal256(new_precision, new_scale)) | ||
| } | ||
| DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)), | ||
| arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64), | ||
| DataType::Dictionary(_, dict_value_type) => { | ||
| avg_return_type(func_name, dict_value_type.as_ref()) | ||
| } | ||
| other => plan_err!("{func_name} does not support {other:?}"), | ||
| } | ||
| } | ||
|
|
||
| /// Internal sum type of an average | ||
| pub fn avg_sum_type(arg_type: &DataType) -> Result<DataType> { | ||
| match arg_type { | ||
| DataType::Decimal32(precision, scale) => { | ||
| // In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s) | ||
| let new_precision = DECIMAL32_MAX_PRECISION.min(*precision + 10); | ||
| Ok(DataType::Decimal32(new_precision, *scale)) | ||
| } | ||
| DataType::Decimal64(precision, scale) => { | ||
| // In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s) | ||
| let new_precision = DECIMAL64_MAX_PRECISION.min(*precision + 10); | ||
| Ok(DataType::Decimal64(new_precision, *scale)) | ||
| } | ||
| DataType::Decimal128(precision, scale) => { | ||
| // In the spark, the sum type of avg is DECIMAL(min(38,precision+10), s) | ||
| let new_precision = DECIMAL128_MAX_PRECISION.min(*precision + 10); | ||
| Ok(DataType::Decimal128(new_precision, *scale)) | ||
| } | ||
| DataType::Decimal256(precision, scale) => { | ||
| // In Spark the sum type of avg is DECIMAL(min(38,precision+10), s) | ||
| let new_precision = DECIMAL256_MAX_PRECISION.min(*precision + 10); | ||
| Ok(DataType::Decimal256(new_precision, *scale)) | ||
| } | ||
| DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)), | ||
| arg_type if NUMERICS.contains(arg_type) => Ok(DataType::Float64), | ||
| DataType::Dictionary(_, dict_value_type) => { | ||
| avg_sum_type(dict_value_type.as_ref()) | ||
| } | ||
| other => plan_err!("AVG does not support {other:?}"), | ||
| } | ||
| } | ||
|
|
||
| pub fn is_sum_support_arg_type(arg_type: &DataType) -> bool { | ||
| match arg_type { | ||
| DataType::Dictionary(_, dict_value_type) => { | ||
| is_sum_support_arg_type(dict_value_type.as_ref()) | ||
| } | ||
| _ => matches!( | ||
| arg_type, | ||
| arg_type if NUMERICS.contains(arg_type) | ||
| || matches!(arg_type, DataType::Decimal32(_, _) | DataType::Decimal64(_, _) |DataType::Decimal128(_, _) | DataType::Decimal256(_, _)) | ||
| ), | ||
| } | ||
| } | ||
|
|
||
| pub fn is_avg_support_arg_type(arg_type: &DataType) -> bool { | ||
| match arg_type { | ||
| DataType::Dictionary(_, dict_value_type) => { | ||
| is_avg_support_arg_type(dict_value_type.as_ref()) | ||
| } | ||
| _ => matches!( | ||
| arg_type, | ||
| arg_type if NUMERICS.contains(arg_type) | ||
| || matches!(arg_type, DataType::Decimal32(_, _) | DataType::Decimal64(_, _) |DataType::Decimal128(_, _) | DataType::Decimal256(_, _)) | ||
| ), | ||
| } | ||
| } | ||
|
|
||
| pub fn is_variance_support_arg_type(arg_type: &DataType) -> bool { | ||
| matches!( | ||
| arg_type, | ||
| arg_type if NUMERICS.contains(arg_type) | ||
| ) | ||
| } | ||
|
|
||
| pub fn is_covariance_support_arg_type(arg_type: &DataType) -> bool { | ||
| matches!( | ||
| arg_type, | ||
| arg_type if NUMERICS.contains(arg_type) | ||
| ) | ||
| } | ||
|
|
||
| pub fn is_correlation_support_arg_type(arg_type: &DataType) -> bool { | ||
| matches!( | ||
| arg_type, | ||
| arg_type if NUMERICS.contains(arg_type) | ||
| ) | ||
| } | ||
|
|
||
| pub fn is_integer_arg_type(arg_type: &DataType) -> bool { | ||
| arg_type.is_integer() | ||
| } | ||
|
|
||
| pub fn coerce_avg_type(func_name: &str, arg_types: &[DataType]) -> Result<Vec<DataType>> { | ||
| // Supported types smallint, int, bigint, real, double precision, decimal, or interval | ||
| // Refer to https://www.postgresql.org/docs/8.2/functions-aggregate.html doc | ||
| fn coerced_type(func_name: &str, data_type: &DataType) -> Result<DataType> { | ||
| match &data_type { | ||
| DataType::Decimal32(p, s) => Ok(DataType::Decimal32(*p, *s)), | ||
| DataType::Decimal64(p, s) => Ok(DataType::Decimal64(*p, *s)), | ||
| DataType::Decimal128(p, s) => Ok(DataType::Decimal128(*p, *s)), | ||
| DataType::Decimal256(p, s) => Ok(DataType::Decimal256(*p, *s)), | ||
| d if d.is_numeric() => Ok(DataType::Float64), | ||
| DataType::Duration(time_unit) => Ok(DataType::Duration(*time_unit)), | ||
| DataType::Dictionary(_, v) => coerced_type(func_name, v.as_ref()), | ||
| _ => { | ||
| plan_err!( | ||
| "The function {:?} does not support inputs of type {}.", | ||
| func_name, | ||
| data_type | ||
| ) | ||
| } | ||
| } | ||
| } | ||
| Ok(vec![coerced_type(func_name, &arg_types[0])?]) | ||
| } | ||
| #[cfg(test)] | ||
| mod tests { | ||
| use super::*; | ||
|
|
||
| #[test] | ||
| fn test_variance_return_data_type() -> Result<()> { | ||
| let data_type = DataType::Float64; | ||
| let result_type = variance_return_type(&data_type)?; | ||
| assert_eq!(DataType::Float64, result_type); | ||
|
|
||
| let data_type = DataType::Decimal128(36, 10); | ||
| assert!(variance_return_type(&data_type).is_err()); | ||
| Ok(()) | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_sum_return_data_type() -> Result<()> { | ||
| let data_type = DataType::Decimal128(10, 5); | ||
| let result_type = sum_return_type(&data_type)?; | ||
| assert_eq!(DataType::Decimal128(20, 5), result_type); | ||
|
|
||
| let data_type = DataType::Decimal128(36, 10); | ||
| let result_type = sum_return_type(&data_type)?; | ||
| assert_eq!(DataType::Decimal128(38, 10), result_type); | ||
| Ok(()) | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_covariance_return_data_type() -> Result<()> { | ||
| let data_type = DataType::Float64; | ||
| let result_type = covariance_return_type(&data_type)?; | ||
| assert_eq!(DataType::Float64, result_type); | ||
|
|
||
| let data_type = DataType::Decimal128(36, 10); | ||
| assert!(covariance_return_type(&data_type).is_err()); | ||
| Ok(()) | ||
| } | ||
|
|
||
| #[test] | ||
| fn test_correlation_return_data_type() -> Result<()> { | ||
| let data_type = DataType::Float64; | ||
| let result_type = correlation_return_type(&data_type)?; | ||
| assert_eq!(DataType::Float64, result_type); | ||
|
|
||
| let data_type = DataType::Decimal128(36, 10); | ||
| assert!(correlation_return_type(&data_type).is_err()); | ||
| Ok(()) | ||
| } | ||
| } | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I assume these were leftover from before moving to the Signature API