Skip to content

Commit cfb655d

Browse files
authored
approx_quantile() aggregation function (#1539)
* feat: implement TDigest for approx quantile Adds a [TDigest] implementation providing approximate quantile estimations of large inputs using a small amount of (bounded) memory. A TDigest is most accurate near either "end" of the quantile range (that is, 0.1, 0.9, 0.95, etc) due to the use of a scaling function that increases resolution at the tails. The paper claims single-digit parts-per-million errors for q ≤ 0.001 or q ≥ 0.999 using 100 centroids, and in practice I have found accuracy to be more than acceptable for an approximate function across the entire quantile range. The implementation is a modified copy of https://github.com/MnO2/t-digest, itself a Rust port of [Facebook's C++ implementation]. Both Facebook's implementation and MnO2's Rust port are Apache 2.0 licensed. [TDigest]: https://arxiv.org/abs/1902.04023 [Facebook's C++ implementation]: https://github.com/facebook/folly/blob/main/folly/stats/TDigest.h * feat: approx_quantile aggregation Adds the ApproxQuantile physical expression, plumbing & test cases. The function signature is: approx_quantile(column, quantile) Where column can be any numeric type (that can be cast to a float64) and quantile is a float64 literal between 0 and 1. * feat: approx_quantile dataframe function Adds the approx_quantile() dataframe function, and exports it in the prelude. * refactor: ballista approx_quantile support Adds ballista wire encoding for approx_quantile. Adding support for this required modifying the AggregateExprNode proto message to support propagating multiple LogicalExprNode aggregate arguments - all the existing aggregations take a single argument, so this wasn't needed before. This commit adds "repeated" to the expr field, which I believe is backwards compatible as described here: https://developers.google.com/protocol-buffers/docs/proto3#updating Specifically, adding "repeated" to an existing message field: "For ... 
message fields, optional is compatible with repeated" No existing tests needed fixing, and a new roundtrip test is included that covers the change to allow multiple expr. * refactor: use input type as return type Casts the calculated quantile value to the same type as the input data. * fixup! refactor: ballista approx_quantile support * refactor: rebase onto main * refactor: validate quantile value Ensures the quantile value is between 0 and 1, emitting a plan error if not. * refactor: rename to approx_percentile_cont * refactor: clippy lints
1 parent 7bec762 commit cfb655d

File tree

16 files changed

+1485
-47
lines changed

16 files changed

+1485
-47
lines changed

ballista/rust/core/proto/ballista.proto

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,11 +176,12 @@ enum AggregateFunction {
176176
STDDEV=11;
177177
STDDEV_POP=12;
178178
CORRELATION=13;
179+
APPROX_PERCENTILE_CONT = 14;
179180
}
180181

181182
message AggregateExprNode {
182183
AggregateFunction aggr_function = 1;
183-
LogicalExprNode expr = 2;
184+
repeated LogicalExprNode expr = 2;
184185
}
185186

186187
enum BuiltInWindowFunction {

ballista/rust/core/src/serde/logical_plan/from_proto.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1065,7 +1065,11 @@ impl TryInto<Expr> for &protobuf::LogicalExprNode {
10651065

10661066
Ok(Expr::AggregateFunction {
10671067
fun,
1068-
args: vec![parse_required_expr(&expr.expr)?],
1068+
args: expr
1069+
.expr
1070+
.iter()
1071+
.map(|e| e.try_into())
1072+
.collect::<Result<Vec<_>, _>>()?,
10691073
distinct: false, //TODO
10701074
})
10711075
}

ballista/rust/core/src/serde/logical_plan/mod.rs

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,16 +24,14 @@ mod roundtrip_tests {
2424
use super::super::{super::error::Result, protobuf};
2525
use crate::error::BallistaError;
2626
use core::panic;
27-
use datafusion::arrow::datatypes::UnionMode;
28-
use datafusion::logical_plan::Repartition;
2927
use datafusion::{
30-
arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit},
28+
arrow::datatypes::{DataType, Field, IntervalUnit, Schema, TimeUnit, UnionMode},
3129
datasource::object_store::local::LocalFileSystem,
3230
logical_plan::{
3331
col, CreateExternalTable, Expr, LogicalPlan, LogicalPlanBuilder,
34-
Partitioning, ToDFSchema,
32+
Partitioning, Repartition, ToDFSchema,
3533
},
36-
physical_plan::functions::BuiltinScalarFunction::Sqrt,
34+
physical_plan::{aggregates, functions::BuiltinScalarFunction::Sqrt},
3735
prelude::*,
3836
scalar::ScalarValue,
3937
sql::parser::FileType,
@@ -1001,4 +999,17 @@ mod roundtrip_tests {
1001999

10021000
Ok(())
10031001
}
1002+
1003+
#[test]
1004+
fn roundtrip_approx_percentile_cont() -> Result<()> {
1005+
let test_expr = Expr::AggregateFunction {
1006+
fun: aggregates::AggregateFunction::ApproxPercentileCont,
1007+
args: vec![col("bananas"), lit(0.42)],
1008+
distinct: false,
1009+
};
1010+
1011+
roundtrip_test!(test_expr, protobuf::LogicalExprNode, Expr);
1012+
1013+
Ok(())
1014+
}
10041015
}

ballista/rust/core/src/serde/logical_plan/to_proto.rs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1074,6 +1074,9 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
10741074
AggregateFunction::ApproxDistinct => {
10751075
protobuf::AggregateFunction::ApproxDistinct
10761076
}
1077+
AggregateFunction::ApproxPercentileCont => {
1078+
protobuf::AggregateFunction::ApproxPercentileCont
1079+
}
10771080
AggregateFunction::ArrayAgg => protobuf::AggregateFunction::ArrayAgg,
10781081
AggregateFunction::Min => protobuf::AggregateFunction::Min,
10791082
AggregateFunction::Max => protobuf::AggregateFunction::Max,
@@ -1099,11 +1102,13 @@ impl TryInto<protobuf::LogicalExprNode> for &Expr {
10991102
}
11001103
};
11011104

1102-
let arg = &args[0];
1103-
let aggregate_expr = Box::new(protobuf::AggregateExprNode {
1105+
let aggregate_expr = protobuf::AggregateExprNode {
11041106
aggr_function: aggr_function.into(),
1105-
expr: Some(Box::new(arg.try_into()?)),
1106-
});
1107+
expr: args
1108+
.iter()
1109+
.map(|v| v.try_into())
1110+
.collect::<Result<Vec<_>, _>>()?,
1111+
};
11071112
Ok(protobuf::LogicalExprNode {
11081113
expr_type: Some(ExprType::AggregateExpr(aggregate_expr)),
11091114
})
@@ -1334,6 +1339,7 @@ impl From<&AggregateFunction> for protobuf::AggregateFunction {
13341339
AggregateFunction::Stddev => Self::Stddev,
13351340
AggregateFunction::StddevPop => Self::StddevPop,
13361341
AggregateFunction::Correlation => Self::Correlation,
1342+
AggregateFunction::ApproxPercentileCont => Self::ApproxPercentileCont,
13371343
}
13381344
}
13391345
}

ballista/rust/core/src/serde/mod.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -129,6 +129,9 @@ impl From<protobuf::AggregateFunction> for AggregateFunction {
129129
protobuf::AggregateFunction::Stddev => AggregateFunction::Stddev,
130130
protobuf::AggregateFunction::StddevPop => AggregateFunction::StddevPop,
131131
protobuf::AggregateFunction::Correlation => AggregateFunction::Correlation,
132+
protobuf::AggregateFunction::ApproxPercentileCont => {
133+
AggregateFunction::ApproxPercentileCont
134+
}
132135
}
133136
}
134137
}

datafusion/src/logical_plan/expr.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1647,6 +1647,15 @@ pub fn approx_distinct(expr: Expr) -> Expr {
16471647
}
16481648
}
16491649

1650+
/// Calculate an approximation of the specified `percentile` for `expr`.
1651+
pub fn approx_percentile_cont(expr: Expr, percentile: Expr) -> Expr {
1652+
Expr::AggregateFunction {
1653+
fun: aggregates::AggregateFunction::ApproxPercentileCont,
1654+
distinct: false,
1655+
args: vec![expr, percentile],
1656+
}
1657+
}
1658+
16501659
// TODO(kszucs): this seems buggy, unary_scalar_expr! is used for many
16511660
// varying arity functions
16521661
/// Create an convenience function representing a unary scalar function

datafusion/src/logical_plan/mod.rs

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,13 @@ pub use builder::{
3636
pub use dfschema::{DFField, DFSchema, DFSchemaRef, ToDFSchema};
3737
pub use display::display_schema;
3838
pub use expr::{
39-
abs, acos, and, approx_distinct, array, ascii, asin, atan, avg, binary_expr,
40-
bit_length, btrim, case, ceil, character_length, chr, col, columnize_expr,
41-
combine_filters, concat, concat_ws, cos, count, count_distinct, create_udaf,
42-
create_udf, date_part, date_trunc, digest, exp, exprlist_to_fields, floor, in_list,
43-
initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2, lower, lpad, ltrim,
44-
max, md5, min, normalize_col, normalize_cols, now, octet_length, or, random,
45-
regexp_match, regexp_replace, repeat, replace, replace_col, reverse,
39+
abs, acos, and, approx_distinct, approx_percentile_cont, array, ascii, asin, atan,
40+
avg, binary_expr, bit_length, btrim, case, ceil, character_length, chr, col,
41+
columnize_expr, combine_filters, concat, concat_ws, cos, count, count_distinct,
42+
create_udaf, create_udf, date_part, date_trunc, digest, exp, exprlist_to_fields,
43+
floor, in_list, initcap, left, length, lit, lit_timestamp_nano, ln, log10, log2,
44+
lower, lpad, ltrim, max, md5, min, normalize_col, normalize_cols, now, octet_length,
45+
or, random, regexp_match, regexp_replace, repeat, replace, replace_col, reverse,
4646
rewrite_sort_cols_by_aggs, right, round, rpad, rtrim, sha224, sha256, sha384, sha512,
4747
signum, sin, split_part, sqrt, starts_with, strpos, substr, sum, tan, to_hex,
4848
translate, trim, trunc, unalias, unnormalize_col, unnormalize_cols, upper, when,

datafusion/src/physical_plan/aggregates.rs

Lines changed: 83 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727
//! * Return type: a function `(arg_types) -> return_type`. E.g. for min, ([f32]) -> f32, ([f64]) -> f64.
2828
2929
use super::{
30-
functions::{Signature, Volatility},
30+
functions::{Signature, TypeSignature, Volatility},
3131
Accumulator, AggregateExpr, PhysicalExpr,
3232
};
3333
use crate::error::{DataFusionError, Result};
@@ -80,6 +80,8 @@ pub enum AggregateFunction {
8080
CovariancePop,
8181
/// Correlation
8282
Correlation,
83+
/// Approximate continuous percentile function
84+
ApproxPercentileCont,
8385
}
8486

8587
impl fmt::Display for AggregateFunction {
@@ -110,6 +112,7 @@ impl FromStr for AggregateFunction {
110112
"covar_samp" => AggregateFunction::Covariance,
111113
"covar_pop" => AggregateFunction::CovariancePop,
112114
"corr" => AggregateFunction::Correlation,
115+
"approx_percentile_cont" => AggregateFunction::ApproxPercentileCont,
113116
_ => {
114117
return Err(DataFusionError::Plan(format!(
115118
"There is no built-in function named {}",
@@ -157,6 +160,7 @@ pub fn return_type(
157160
coerced_data_types[0].clone(),
158161
true,
159162
)))),
163+
AggregateFunction::ApproxPercentileCont => Ok(coerced_data_types[0].clone()),
160164
}
161165
}
162166

@@ -331,6 +335,20 @@ pub fn create_aggregate_expr(
331335
"CORR(DISTINCT) aggregations are not available".to_string(),
332336
));
333337
}
338+
(AggregateFunction::ApproxPercentileCont, false) => {
339+
Arc::new(expressions::ApproxPercentileCont::new(
340+
// Pass in the desired percentile expr
341+
coerced_phy_exprs,
342+
name,
343+
return_type,
344+
)?)
345+
}
346+
(AggregateFunction::ApproxPercentileCont, true) => {
347+
return Err(DataFusionError::NotImplemented(
348+
"approx_percentile_cont(DISTINCT) aggregations are not available"
349+
.to_string(),
350+
));
351+
}
334352
})
335353
}
336354

@@ -389,17 +407,25 @@ pub(super) fn signature(fun: &AggregateFunction) -> Signature {
389407
AggregateFunction::Correlation => {
390408
Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable)
391409
}
410+
AggregateFunction::ApproxPercentileCont => Signature::one_of(
411+
// Accept any numeric value paired with a float64 percentile
412+
NUMERICS
413+
.iter()
414+
.map(|t| TypeSignature::Exact(vec![t.clone(), DataType::Float64]))
415+
.collect(),
416+
Volatility::Immutable,
417+
),
392418
}
393419
}
394420

395421
#[cfg(test)]
396422
mod tests {
397423
use super::*;
398-
use crate::error::Result;
399424
use crate::physical_plan::expressions::{
400-
ApproxDistinct, ArrayAgg, Avg, Correlation, Count, Covariance, DistinctArrayAgg,
401-
DistinctCount, Max, Min, Stddev, Sum, Variance,
425+
ApproxDistinct, ApproxPercentileCont, ArrayAgg, Avg, Correlation, Count,
426+
Covariance, DistinctArrayAgg, DistinctCount, Max, Min, Stddev, Sum, Variance,
402427
};
428+
use crate::{error::Result, scalar::ScalarValue};
403429

404430
#[test]
405431
fn test_count_arragg_approx_expr() -> Result<()> {
@@ -513,6 +539,59 @@ mod tests {
513539
Ok(())
514540
}
515541

542+
#[test]
543+
fn test_agg_approx_percentile_phy_expr() {
544+
for data_type in NUMERICS {
545+
let input_schema =
546+
Schema::new(vec![Field::new("c1", data_type.clone(), true)]);
547+
let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![
548+
Arc::new(
549+
expressions::Column::new_with_schema("c1", &input_schema).unwrap(),
550+
),
551+
Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.2)))),
552+
];
553+
let result_agg_phy_exprs = create_aggregate_expr(
554+
&AggregateFunction::ApproxPercentileCont,
555+
false,
556+
&input_phy_exprs[..],
557+
&input_schema,
558+
"c1",
559+
)
560+
.expect("failed to create aggregate expr");
561+
562+
assert!(result_agg_phy_exprs.as_any().is::<ApproxPercentileCont>());
563+
assert_eq!("c1", result_agg_phy_exprs.name());
564+
assert_eq!(
565+
Field::new("c1", data_type.clone(), false),
566+
result_agg_phy_exprs.field().unwrap()
567+
);
568+
}
569+
}
570+
571+
#[test]
572+
fn test_agg_approx_percentile_invalid_phy_expr() {
573+
for data_type in NUMERICS {
574+
let input_schema =
575+
Schema::new(vec![Field::new("c1", data_type.clone(), true)]);
576+
let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![
577+
Arc::new(
578+
expressions::Column::new_with_schema("c1", &input_schema).unwrap(),
579+
),
580+
Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(4.2)))),
581+
];
582+
let err = create_aggregate_expr(
583+
&AggregateFunction::ApproxPercentileCont,
584+
false,
585+
&input_phy_exprs[..],
586+
&input_schema,
587+
"c1",
588+
)
589+
.expect_err("should fail due to invalid percentile");
590+
591+
assert!(matches!(err, DataFusionError::Plan(_)));
592+
}
593+
}
594+
516595
#[test]
517596
fn test_min_max_expr() -> Result<()> {
518597
let funcs = vec![AggregateFunction::Min, AggregateFunction::Max];

0 commit comments

Comments
 (0)