Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Minor: Signature check for UDAF #10147

Merged
merged 2 commits into from
Apr 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion datafusion/expr/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ pub use logical_plan::*;
pub use operator::Operator;
pub use partition_evaluator::PartitionEvaluator;
pub use signature::{
FuncMonotonicity, Signature, TypeSignature, Volatility, TIMEZONE_WILDCARD,
ArrayFunctionSignature, FuncMonotonicity, Signature, TypeSignature, Volatility,
TIMEZONE_WILDCARD,
};
pub use table_source::{TableProviderFilterPushDown, TableSource, TableType};
pub use udaf::{AggregateUDF, AggregateUDFImpl};
Expand Down
21 changes: 9 additions & 12 deletions datafusion/expr/src/type_coercion/aggregates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ pub fn coerce_types(
) -> Result<Vec<DataType>> {
use DataType::*;
// Validate input_types matches (at least one of) the func signature.
check_arg_count(agg_fun, input_types, &signature.type_signature)?;
check_arg_count(agg_fun.name(), input_types, &signature.type_signature)?;

match agg_fun {
AggregateFunction::Count | AggregateFunction::ApproxDistinct => {
Expand Down Expand Up @@ -323,17 +323,16 @@ pub fn coerce_types(
/// This method DOES NOT validate the argument types - only that (at least one,
/// in the case of [`TypeSignature::OneOf`]) signature matches the desired
/// number of input types.
fn check_arg_count(
agg_fun: &AggregateFunction,
pub fn check_arg_count(
func_name: &str,
input_types: &[DataType],
signature: &TypeSignature,
) -> Result<()> {
match signature {
TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => {
if input_types.len() != *agg_count {
return plan_err!(
"The function {:?} expects {:?} arguments, but {:?} were provided",
agg_fun,
"The function {func_name} expects {:?} arguments, but {:?} were provided",
agg_count,
input_types.len()
);
Expand All @@ -342,8 +341,7 @@ fn check_arg_count(
TypeSignature::Exact(types) => {
if types.len() != input_types.len() {
return plan_err!(
"The function {:?} expects {:?} arguments, but {:?} were provided",
agg_fun,
"The function {func_name} expects {:?} arguments, but {:?} were provided",
types.len(),
input_types.len()
);
Expand All @@ -352,19 +350,18 @@ fn check_arg_count(
TypeSignature::OneOf(variants) => {
let ok = variants
.iter()
.any(|v| check_arg_count(agg_fun, input_types, v).is_ok());
.any(|v| check_arg_count(func_name, input_types, v).is_ok());
if !ok {
return plan_err!(
"The function {:?} does not accept {:?} function arguments.",
agg_fun,
"The function {func_name} does not accept {:?} function arguments.",
input_types.len()
);
}
}
TypeSignature::VariadicAny => {
if input_types.is_empty() {
return plan_err!(
"The function {agg_fun:?} expects at least one argument"
"The function {func_name} expects at least one argument"
);
}
}
Expand Down Expand Up @@ -594,7 +591,7 @@ mod tests {
let input_types = vec![DataType::Int64, DataType::Int32];
let signature = fun.signature();
let result = coerce_types(&fun, &input_types, &signature);
assert_eq!("Error during planning: The function Min expects 1 arguments, but 2 were provided", result.unwrap_err().strip_backtrace());
assert_eq!("Error during planning: The function MIN expects 1 arguments, but 2 were provided", result.unwrap_err().strip_backtrace());

// test input args is invalid data type for sum or avg
let fun = AggregateFunction::Sum;
Expand Down
14 changes: 12 additions & 2 deletions datafusion/functions-aggregate/src/first_last.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,10 @@ use datafusion_common::{
use datafusion_expr::function::AccumulatorArgs;
use datafusion_expr::type_coercion::aggregates::NUMERICS;
use datafusion_expr::utils::format_state_name;
use datafusion_expr::{Accumulator, AggregateUDFImpl, Expr, Signature, Volatility};
use datafusion_expr::{
Accumulator, AggregateUDFImpl, ArrayFunctionSignature, Expr, Signature,
TypeSignature, Volatility,
};
use datafusion_physical_expr_common::aggregate::utils::{
down_cast_any_ref, get_sort_options, ordering_fields,
};
Expand Down Expand Up @@ -73,7 +76,14 @@ impl FirstValue {
pub fn new() -> Self {
Self {
aliases: vec![String::from("FIRST_VALUE")],
signature: Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable),
signature: Signature::one_of(
vec![
// TODO: we can introduce a stricter signature that only numeric of array types are allowed
TypeSignature::ArraySignature(ArrayFunctionSignature::Array),
TypeSignature::Uniform(1, NUMERICS.to_vec()),
],
Volatility::Immutable,
),
}
}
}
Expand Down
7 changes: 7 additions & 0 deletions datafusion/physical-expr-common/src/aggregate/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ pub mod utils;

use arrow::datatypes::{DataType, Field, Schema};
use datafusion_common::{not_impl_err, Result};
use datafusion_expr::type_coercion::aggregates::check_arg_count;
use datafusion_expr::{
function::AccumulatorArgs, Accumulator, AggregateUDF, Expr, GroupsAccumulator,
};
Expand Down Expand Up @@ -46,6 +47,12 @@ pub fn create_aggregate_expr(
.map(|arg| arg.data_type(schema))
.collect::<Result<Vec<_>>>()?;

check_arg_count(
fun.name(),
&input_exprs_types,
&fun.signature().type_signature,
)?;

let ordering_types = ordering_req
.iter()
.map(|e| e.expr.data_type(schema))
Expand Down