Skip to content

Commit d5fc006

Browse files
committed
feat: approx_quantile aggregation
Adds the ApproxQuantile physical expression, plumbing & test cases. The function signature is: approx_quantile(column, quantile) Where column can be any numeric type (that can be cast to a float64) and quantile is a float64 literal between 0 and 1.
1 parent b72d21c commit d5fc006

File tree

5 files changed

+537
-23
lines changed

5 files changed

+537
-23
lines changed

datafusion/src/physical_plan/aggregates.rs

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
//! * Return type: a function `(arg_types) -> return_type`. E.g. for min, ([f32]) -> f32, ([f64]) -> f64.
2828
2929
use super::{
30-
functions::{Signature, Volatility},
30+
functions::{Signature, TypeSignature, Volatility},
3131
Accumulator, AggregateExpr, PhysicalExpr,
3232
};
3333
use crate::error::{DataFusionError, Result};
@@ -74,6 +74,8 @@ pub enum AggregateFunction {
7474
Stddev,
7575
/// Standard Deviation (Population)
7676
StddevPop,
77+
/// Approximate quantile function
78+
ApproxQuantile,
7779
}
7880

7981
impl fmt::Display for AggregateFunction {
@@ -100,6 +102,7 @@ impl FromStr for AggregateFunction {
100102
"stddev" => AggregateFunction::Stddev,
101103
"stddev_samp" => AggregateFunction::Stddev,
102104
"stddev_pop" => AggregateFunction::StddevPop,
105+
"approx_quantile" => AggregateFunction::ApproxQuantile,
103106
_ => {
104107
return Err(DataFusionError::Plan(format!(
105108
"There is no built-in function named {}",
@@ -142,6 +145,7 @@ pub fn return_type(
142145
coerced_data_types[0].clone(),
143146
true,
144147
)))),
148+
AggregateFunction::ApproxQuantile => Ok(DataType::Float64),
145149
}
146150
}
147151

@@ -279,6 +283,19 @@ pub fn create_aggregate_expr(
279283
"STDDEV_POP(DISTINCT) aggregations are not available".to_string(),
280284
));
281285
}
286+
(AggregateFunction::ApproxQuantile, false) => {
287+
Arc::new(expressions::ApproxQuantile::new(
288+
// Pass in the desired quantile expr
289+
coerced_phy_exprs,
290+
name,
291+
return_type,
292+
)?)
293+
}
294+
(AggregateFunction::ApproxQuantile, true) => {
295+
return Err(DataFusionError::NotImplemented(
296+
"approx_quantile(DISTINCT) aggregations are not available".to_string(),
297+
));
298+
}
282299
})
283300
}
284301

@@ -331,17 +348,28 @@ pub fn signature(fun: &AggregateFunction) -> Signature {
331348
| AggregateFunction::StddevPop => {
332349
Signature::uniform(1, NUMERICS.to_vec(), Volatility::Immutable)
333350
}
351+
AggregateFunction::ApproxQuantile => Signature::one_of(
352+
// Accept any numeric value paired with a float64 quantile
353+
NUMERICS
354+
.iter()
355+
.map(|t| TypeSignature::Exact(vec![t.clone(), DataType::Float64]))
356+
.collect(),
357+
Volatility::Immutable,
358+
),
334359
}
335360
}
336361

337362
#[cfg(test)]
338363
mod tests {
339364
use super::*;
340-
use crate::error::DataFusionError::NotImplemented;
341-
use crate::error::Result;
342365
use crate::physical_plan::distinct_expressions::DistinctCount;
343366
use crate::physical_plan::expressions::{
344-
ApproxDistinct, ArrayAgg, Avg, Count, Max, Min, Stddev, Sum, Variance,
367+
ApproxDistinct, ApproxQuantile, ArrayAgg, Avg, Count, Max, Min, Stddev, Sum,
368+
Variance,
369+
};
370+
use crate::{
371+
error::{DataFusionError::NotImplemented, Result},
372+
scalar::ScalarValue,
345373
};
346374

347375
#[test]
@@ -458,6 +486,35 @@ mod tests {
458486
Ok(())
459487
}
460488

489+
#[test]
490+
fn test_agg_approx_quantile_phy_expr() {
491+
for data_type in NUMERICS {
492+
let input_schema =
493+
Schema::new(vec![Field::new("c1", data_type.clone(), true)]);
494+
let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![
495+
Arc::new(
496+
expressions::Column::new_with_schema("c1", &input_schema).unwrap(),
497+
),
498+
Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(4.2)))),
499+
];
500+
let result_agg_phy_exprs = create_aggregate_expr(
501+
&AggregateFunction::ApproxQuantile,
502+
false,
503+
&input_phy_exprs[..],
504+
&input_schema,
505+
"c1",
506+
)
507+
.expect("failed to create aggregate expr");
508+
509+
assert!(result_agg_phy_exprs.as_any().is::<ApproxQuantile>());
510+
assert_eq!("c1", result_agg_phy_exprs.name());
511+
assert_eq!(
512+
Field::new("c1", DataType::Float64, false),
513+
result_agg_phy_exprs.field().unwrap()
514+
);
515+
}
516+
}
517+
461518
#[test]
462519
fn test_min_max_expr() -> Result<()> {
463520
let funcs = vec![AggregateFunction::Min, AggregateFunction::Max];

datafusion/src/physical_plan/coercion_rule/aggregate_rule.rs

Lines changed: 95 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717

1818
//! Support the coercion rule for aggregate function.
1919
20-
use crate::arrow::datatypes::Schema;
2120
use crate::error::{DataFusionError, Result};
2221
use crate::physical_plan::aggregates::AggregateFunction;
2322
use crate::physical_plan::expressions::{
@@ -26,6 +25,10 @@ use crate::physical_plan::expressions::{
2625
};
2726
use crate::physical_plan::functions::{Signature, TypeSignature};
2827
use crate::physical_plan::PhysicalExpr;
28+
use crate::{
29+
arrow::datatypes::Schema,
30+
physical_plan::expressions::is_approx_quantile_supported_arg_type,
31+
};
2932
use arrow::datatypes::DataType;
3033
use std::ops::Deref;
3134
use std::sync::Arc;
@@ -37,24 +40,9 @@ pub(crate) fn coerce_types(
3740
input_types: &[DataType],
3841
signature: &Signature,
3942
) -> Result<Vec<DataType>> {
40-
match signature.type_signature {
41-
TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => {
42-
if input_types.len() != agg_count {
43-
return Err(DataFusionError::Plan(format!(
44-
"The function {:?} expects {:?} arguments, but {:?} were provided",
45-
agg_fun,
46-
agg_count,
47-
input_types.len()
48-
)));
49-
}
50-
}
51-
_ => {
52-
return Err(DataFusionError::Internal(format!(
53-
"Aggregate functions do not support this {:?}",
54-
signature
55-
)));
56-
}
57-
};
43+
// Validate input_types matches (at least one of) the func signature.
44+
check_arg_count(agg_fun, input_types, &signature.type_signature)?;
45+
5846
match agg_fun {
5947
AggregateFunction::Count | AggregateFunction::ApproxDistinct => {
6048
Ok(input_types.to_vec())
@@ -123,7 +111,75 @@ pub(crate) fn coerce_types(
123111
}
124112
Ok(input_types.to_vec())
125113
}
114+
AggregateFunction::ApproxQuantile => {
115+
if !is_approx_quantile_supported_arg_type(&input_types[0]) {
116+
return Err(DataFusionError::Plan(format!(
117+
"The function {:?} does not support inputs of type {:?}.",
118+
agg_fun, input_types[0]
119+
)));
120+
}
121+
if !matches!(input_types[1], DataType::Float64) {
122+
return Err(DataFusionError::Plan(format!(
123+
"The quantile argument for {:?} must be Float64, not {:?}.",
124+
agg_fun, input_types[1]
125+
)));
126+
}
127+
Ok(input_types.to_vec())
128+
}
129+
}
130+
}
131+
132+
/// Validate the length of `input_types` matches the `signature` for `agg_fun`.
133+
///
134+
/// This method DOES NOT validate the argument types - only that (at least one,
135+
/// in the case of [`TypeSignature::OneOf`]) signature matches the desired
136+
/// number of input types.
137+
fn check_arg_count(
138+
agg_fun: &AggregateFunction,
139+
input_types: &[DataType],
140+
signature: &TypeSignature,
141+
) -> Result<()> {
142+
match signature {
143+
TypeSignature::Uniform(agg_count, _) | TypeSignature::Any(agg_count) => {
144+
if input_types.len() != *agg_count {
145+
return Err(DataFusionError::Plan(format!(
146+
"The function {:?} expects {:?} arguments, but {:?} were provided",
147+
agg_fun,
148+
agg_count,
149+
input_types.len()
150+
)));
151+
}
152+
}
153+
TypeSignature::Exact(types) => {
154+
if types.len() != input_types.len() {
155+
return Err(DataFusionError::Plan(format!(
156+
"The function {:?} expects {:?} arguments, but {:?} were provided",
157+
agg_fun,
158+
types.len(),
159+
input_types.len()
160+
)));
161+
}
162+
}
163+
TypeSignature::OneOf(variants) => {
164+
let ok = variants
165+
.iter()
166+
.any(|v| check_arg_count(agg_fun, input_types, v).is_ok());
167+
if !ok {
168+
return Err(DataFusionError::Plan(format!(
169+
"The function {:?} does not accept {:?} function arguments.",
170+
agg_fun,
171+
input_types.len()
172+
)));
173+
}
174+
}
175+
_ => {
176+
return Err(DataFusionError::Internal(format!(
177+
"Aggregate functions do not support this {:?}",
178+
signature
179+
)));
180+
}
126181
}
182+
Ok(())
127183
}
128184

129185
fn get_min_max_result_type(input_types: &[DataType]) -> Result<Vec<DataType>> {
@@ -239,5 +295,25 @@ mod tests {
239295
assert_eq!(*input_type, result.unwrap());
240296
}
241297
}
298+
299+
// ApproxQuantile input types
300+
let input_types = vec![
301+
vec![DataType::Int8, DataType::Float64],
302+
vec![DataType::Int16, DataType::Float64],
303+
vec![DataType::Int32, DataType::Float64],
304+
vec![DataType::Int64, DataType::Float64],
305+
vec![DataType::UInt8, DataType::Float64],
306+
vec![DataType::UInt16, DataType::Float64],
307+
vec![DataType::UInt32, DataType::Float64],
308+
vec![DataType::UInt64, DataType::Float64],
309+
vec![DataType::Float32, DataType::Float64],
310+
vec![DataType::Float64, DataType::Float64],
311+
];
312+
for input_type in &input_types {
313+
let signature = aggregates::signature(&AggregateFunction::ApproxQuantile);
314+
let result =
315+
coerce_types(&AggregateFunction::ApproxQuantile, input_type, &signature);
316+
assert_eq!(*input_type, result.unwrap());
317+
}
242318
}
243319
}

0 commit comments

Comments
 (0)