Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move Regr_* functions to use UDAF #10898

Merged
merged 5 commits into from
Jun 13, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 1 addition & 55 deletions datafusion/expr/src/aggregate_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,24 +47,6 @@ pub enum AggregateFunction {
NthValue,
/// Correlation
Correlation,
/// Slope from linear regression
RegrSlope,
/// Intercept from linear regression
RegrIntercept,
/// Number of input rows in which both expressions are not null
RegrCount,
/// R-squared value from linear regression
RegrR2,
/// Average of the independent variable
RegrAvgx,
/// Average of the dependent variable
RegrAvgy,
/// Sum of squares of the independent variable
RegrSXX,
/// Sum of squares of the dependent variable
RegrSYY,
/// Sum of products of pairs of numbers
RegrSXY,
/// Approximate continuous percentile function
ApproxPercentileCont,
/// Approximate continuous percentile function with weight
Expand Down Expand Up @@ -96,15 +78,6 @@ impl AggregateFunction {
ArrayAgg => "ARRAY_AGG",
NthValue => "NTH_VALUE",
Correlation => "CORR",
RegrSlope => "REGR_SLOPE",
RegrIntercept => "REGR_INTERCEPT",
RegrCount => "REGR_COUNT",
RegrR2 => "REGR_R2",
RegrAvgx => "REGR_AVGX",
RegrAvgy => "REGR_AVGY",
RegrSXX => "REGR_SXX",
RegrSYY => "REGR_SYY",
RegrSXY => "REGR_SXY",
ApproxPercentileCont => "APPROX_PERCENTILE_CONT",
ApproxPercentileContWithWeight => "APPROX_PERCENTILE_CONT_WITH_WEIGHT",
Grouping => "GROUPING",
Expand Down Expand Up @@ -144,15 +117,6 @@ impl FromStr for AggregateFunction {
"string_agg" => AggregateFunction::StringAgg,
// statistical
"corr" => AggregateFunction::Correlation,
"regr_slope" => AggregateFunction::RegrSlope,
"regr_intercept" => AggregateFunction::RegrIntercept,
"regr_count" => AggregateFunction::RegrCount,
"regr_r2" => AggregateFunction::RegrR2,
"regr_avgx" => AggregateFunction::RegrAvgx,
"regr_avgy" => AggregateFunction::RegrAvgy,
"regr_sxx" => AggregateFunction::RegrSXX,
"regr_syy" => AggregateFunction::RegrSYY,
"regr_sxy" => AggregateFunction::RegrSXY,
// approximate
"approx_percentile_cont" => AggregateFunction::ApproxPercentileCont,
"approx_percentile_cont_with_weight" => {
Expand Down Expand Up @@ -205,15 +169,6 @@ impl AggregateFunction {
AggregateFunction::Correlation => {
correlation_return_type(&coerced_data_types[0])
}
AggregateFunction::RegrSlope
| AggregateFunction::RegrIntercept
| AggregateFunction::RegrCount
| AggregateFunction::RegrR2
| AggregateFunction::RegrAvgx
| AggregateFunction::RegrAvgy
| AggregateFunction::RegrSXX
| AggregateFunction::RegrSYY
| AggregateFunction::RegrSXY => Ok(DataType::Float64),
AggregateFunction::Avg => avg_return_type(&coerced_data_types[0]),
AggregateFunction::ArrayAgg => Ok(DataType::List(Arc::new(Field::new(
"item",
Expand Down Expand Up @@ -278,16 +233,7 @@ impl AggregateFunction {
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::NthValue => Signature::any(2, Volatility::Immutable),
AggregateFunction::Correlation
| AggregateFunction::RegrSlope
| AggregateFunction::RegrIntercept
| AggregateFunction::RegrCount
| AggregateFunction::RegrR2
| AggregateFunction::RegrAvgx
| AggregateFunction::RegrAvgy
| AggregateFunction::RegrSXX
| AggregateFunction::RegrSYY
| AggregateFunction::RegrSXY => {
AggregateFunction::Correlation => {
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
}
AggregateFunction::ApproxPercentileCont => {
Expand Down
21 changes: 0 additions & 21 deletions datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,27 +159,6 @@ pub fn coerce_types(
}
Ok(vec![Float64, Float64])
}
AggregateFunction::RegrSlope
| AggregateFunction::RegrIntercept
| AggregateFunction::RegrCount
| AggregateFunction::RegrR2
| AggregateFunction::RegrAvgx
| AggregateFunction::RegrAvgy
| AggregateFunction::RegrSXX
| AggregateFunction::RegrSYY
| AggregateFunction::RegrSXY => {
let valid_types = [NUMERICS.to_vec(), vec![Null]].concat();
let input_types_valid = // number of input already checked before
valid_types.contains(&input_types[0]) && valid_types.contains(&input_types[1]);
if !input_types_valid {
return plan_err!(
"The function {:?} does not support inputs of type {:?}.",
agg_fun,
input_types[0]
);
}
Ok(vec![Float64, Float64])
}
AggregateFunction::ApproxPercentileCont => {
if !is_approx_percentile_cont_supported_arg_type(&input_types[0]) {
return plan_err!(
Expand Down
19 changes: 19 additions & 0 deletions datafusion/functions-aggregate/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ pub mod covariance;
pub mod first_last;
pub mod hyperloglog;
pub mod median;
pub mod regr;
pub mod stddev;
pub mod sum;
pub mod variance;
Expand All @@ -85,6 +86,15 @@ pub mod expr_fn {
pub use super::first_last::first_value;
pub use super::first_last::last_value;
pub use super::median::median;
pub use super::regr::regr_avgx;
pub use super::regr::regr_avgy;
pub use super::regr::regr_count;
pub use super::regr::regr_intercept;
pub use super::regr::regr_r2;
pub use super::regr::regr_slope;
pub use super::regr::regr_sxx;
pub use super::regr::regr_sxy;
pub use super::regr::regr_syy;
pub use super::stddev::stddev;
pub use super::stddev::stddev_pop;
pub use super::sum::sum;
Expand All @@ -102,6 +112,15 @@ pub fn all_default_aggregate_functions() -> Vec<Arc<AggregateUDF>> {
covariance::covar_pop_udaf(),
median::median_udaf(),
count::count_udaf(),
regr::regr_slope_udaf(),
regr::regr_intercept_udaf(),
regr::regr_count_udaf(),
regr::regr_r2_udaf(),
regr::regr_avgx_udaf(),
regr::regr_avgy_udaf(),
regr::regr_sxx_udaf(),
regr::regr_syy_udaf(),
regr::regr_sxy_udaf(),
variance::var_samp_udaf(),
variance::var_pop_udaf(),
stddev::stddev_udaf(),
Expand Down
14 changes: 11 additions & 3 deletions datafusion/functions-aggregate/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,8 @@
// specific language governing permissions and limitations
// under the License.

macro_rules! make_udaf_expr_and_func {
($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => {
macro_rules! make_udaf_expr {
($EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => {
// "fluent expr_fn" style function
#[doc = $DOC]
pub fn $EXPR_FN(
Expand All @@ -48,7 +48,12 @@ macro_rules! make_udaf_expr_and_func {
None,
))
}
};
}

macro_rules! make_udaf_expr_and_func {
($UDAF:ty, $EXPR_FN:ident, $($arg:ident)*, $DOC:expr, $AGGREGATE_UDF_FN:ident) => {
make_udaf_expr!($EXPR_FN, $($arg)*, $DOC, $AGGREGATE_UDF_FN);
create_func!($UDAF, $AGGREGATE_UDF_FN);
};
($UDAF:ty, $EXPR_FN:ident, $DOC:expr, $AGGREGATE_UDF_FN:ident) => {
Expand All @@ -73,6 +78,9 @@ macro_rules! make_udaf_expr_and_func {

macro_rules! create_func {
($UDAF:ty, $AGGREGATE_UDF_FN:ident) => {
create_func!($UDAF, $AGGREGATE_UDF_FN, <$UDAF>::default());
};
($UDAF:ty, $AGGREGATE_UDF_FN:ident, $CREATE:expr) => {
paste::paste! {
/// Singleton instance of [$UDAF], ensures the UDAF is only created once
/// named STATIC_$(UDAF). For example `STATIC_FirstValue`
Expand All @@ -86,7 +94,7 @@ macro_rules! create_func {
pub fn $AGGREGATE_UDF_FN() -> std::sync::Arc<datafusion_expr::AggregateUDF> {
[< STATIC_ $UDAF >]
.get_or_init(|| {
std::sync::Arc::new(datafusion_expr::AggregateUDF::from(<$UDAF>::default()))
std::sync::Arc::new(datafusion_expr::AggregateUDF::from($CREATE))
})
.clone()
}
Expand Down
Loading
Loading