|  | 
| 27 | 27 | //! * Return type: a function `(arg_types) -> return_type`. E.g. for min, ([f32]) -> f32, ([f64]) -> f64. | 
| 28 | 28 | 
 | 
| 29 | 29 | use super::{ | 
| 30 |  | -    functions::{Signature, Volatility}, | 
|  | 30 | +    functions::{Signature, TypeSignature, Volatility}, | 
| 31 | 31 |     Accumulator, AggregateExpr, PhysicalExpr, | 
| 32 | 32 | }; | 
| 33 | 33 | use crate::error::{DataFusionError, Result}; | 
| @@ -80,6 +80,8 @@ pub enum AggregateFunction { | 
| 80 | 80 |     CovariancePop, | 
| 81 | 81 |     /// Correlation | 
| 82 | 82 |     Correlation, | 
|  | 83 | +    /// Approximate continuous percentile function | 
|  | 84 | +    ApproxPercentileCont, | 
| 83 | 85 | } | 
| 84 | 86 | 
 | 
| 85 | 87 | impl fmt::Display for AggregateFunction { | 
| @@ -110,6 +112,7 @@ impl FromStr for AggregateFunction { | 
| 110 | 112 |             "covar_samp" => AggregateFunction::Covariance, | 
| 111 | 113 |             "covar_pop" => AggregateFunction::CovariancePop, | 
| 112 | 114 |             "corr" => AggregateFunction::Correlation, | 
|  | 115 | +            "approx_percentile_cont" => AggregateFunction::ApproxPercentileCont, | 
| 113 | 116 |             _ => { | 
| 114 | 117 |                 return Err(DataFusionError::Plan(format!( | 
| 115 | 118 |                     "There is no built-in function named {}", | 
| @@ -157,6 +160,7 @@ pub fn return_type( | 
| 157 | 160 |             coerced_data_types[0].clone(), | 
| 158 | 161 |             true, | 
| 159 | 162 |         )))), | 
|  | 163 | +        AggregateFunction::ApproxPercentileCont => Ok(coerced_data_types[0].clone()), | 
| 160 | 164 |     } | 
| 161 | 165 | } | 
| 162 | 166 | 
 | 
| @@ -331,6 +335,20 @@ pub fn create_aggregate_expr( | 
| 331 | 335 |                 "CORR(DISTINCT) aggregations are not available".to_string(), | 
| 332 | 336 |             )); | 
| 333 | 337 |         } | 
|  | 338 | +        (AggregateFunction::ApproxPercentileCont, false) => { | 
|  | 339 | +            Arc::new(expressions::ApproxPercentileCont::new( | 
|  | 340 | +                // Pass in the desired percentile expr | 
|  | 341 | +                coerced_phy_exprs, | 
|  | 342 | +                name, | 
|  | 343 | +                return_type, | 
|  | 344 | +            )?) | 
|  | 345 | +        } | 
|  | 346 | +        (AggregateFunction::ApproxPercentileCont, true) => { | 
|  | 347 | +            return Err(DataFusionError::NotImplemented( | 
|  | 348 | +                "approx_percentile_cont(DISTINCT) aggregations are not available" | 
|  | 349 | +                    .to_string(), | 
|  | 350 | +            )); | 
|  | 351 | +        } | 
| 334 | 352 |     }) | 
| 335 | 353 | } | 
| 336 | 354 | 
 | 
| @@ -389,17 +407,25 @@ pub(super) fn signature(fun: &AggregateFunction) -> Signature { | 
| 389 | 407 |         AggregateFunction::Correlation => { | 
| 390 | 408 |             Signature::uniform(2, NUMERICS.to_vec(), Volatility::Immutable) | 
| 391 | 409 |         } | 
|  | 410 | +        AggregateFunction::ApproxPercentileCont => Signature::one_of( | 
|  | 411 | +            // Accept any numeric value paired with a float64 percentile | 
|  | 412 | +            NUMERICS | 
|  | 413 | +                .iter() | 
|  | 414 | +                .map(|t| TypeSignature::Exact(vec![t.clone(), DataType::Float64])) | 
|  | 415 | +                .collect(), | 
|  | 416 | +            Volatility::Immutable, | 
|  | 417 | +        ), | 
| 392 | 418 |     } | 
| 393 | 419 | } | 
| 394 | 420 | 
 | 
| 395 | 421 | #[cfg(test)] | 
| 396 | 422 | mod tests { | 
| 397 | 423 |     use super::*; | 
| 398 |  | -    use crate::error::Result; | 
| 399 | 424 |     use crate::physical_plan::expressions::{ | 
| 400 |  | -        ApproxDistinct, ArrayAgg, Avg, Correlation, Count, Covariance, DistinctArrayAgg, | 
| 401 |  | -        DistinctCount, Max, Min, Stddev, Sum, Variance, | 
|  | 425 | +        ApproxDistinct, ApproxPercentileCont, ArrayAgg, Avg, Correlation, Count, | 
|  | 426 | +        Covariance, DistinctArrayAgg, DistinctCount, Max, Min, Stddev, Sum, Variance, | 
| 402 | 427 |     }; | 
|  | 428 | +    use crate::{error::Result, scalar::ScalarValue}; | 
| 403 | 429 | 
 | 
| 404 | 430 |     #[test] | 
| 405 | 431 |     fn test_count_arragg_approx_expr() -> Result<()> { | 
| @@ -513,6 +539,59 @@ mod tests { | 
| 513 | 539 |         Ok(()) | 
| 514 | 540 |     } | 
| 515 | 541 | 
 | 
|  | 542 | +    #[test] | 
|  | 543 | +    fn test_agg_approx_percentile_phy_expr() { | 
|  | 544 | +        for data_type in NUMERICS { | 
|  | 545 | +            let input_schema = | 
|  | 546 | +                Schema::new(vec![Field::new("c1", data_type.clone(), true)]); | 
|  | 547 | +            let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![ | 
|  | 548 | +                Arc::new( | 
|  | 549 | +                    expressions::Column::new_with_schema("c1", &input_schema).unwrap(), | 
|  | 550 | +                ), | 
|  | 551 | +                Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(0.2)))), | 
|  | 552 | +            ]; | 
|  | 553 | +            let result_agg_phy_exprs = create_aggregate_expr( | 
|  | 554 | +                &AggregateFunction::ApproxPercentileCont, | 
|  | 555 | +                false, | 
|  | 556 | +                &input_phy_exprs[..], | 
|  | 557 | +                &input_schema, | 
|  | 558 | +                "c1", | 
|  | 559 | +            ) | 
|  | 560 | +            .expect("failed to create aggregate expr"); | 
|  | 561 | + | 
|  | 562 | +            assert!(result_agg_phy_exprs.as_any().is::<ApproxPercentileCont>()); | 
|  | 563 | +            assert_eq!("c1", result_agg_phy_exprs.name()); | 
|  | 564 | +            assert_eq!( | 
|  | 565 | +                Field::new("c1", data_type.clone(), false), | 
|  | 566 | +                result_agg_phy_exprs.field().unwrap() | 
|  | 567 | +            ); | 
|  | 568 | +        } | 
|  | 569 | +    } | 
|  | 570 | + | 
|  | 571 | +    #[test] | 
|  | 572 | +    fn test_agg_approx_percentile_invalid_phy_expr() { | 
|  | 573 | +        for data_type in NUMERICS { | 
|  | 574 | +            let input_schema = | 
|  | 575 | +                Schema::new(vec![Field::new("c1", data_type.clone(), true)]); | 
|  | 576 | +            let input_phy_exprs: Vec<Arc<dyn PhysicalExpr>> = vec![ | 
|  | 577 | +                Arc::new( | 
|  | 578 | +                    expressions::Column::new_with_schema("c1", &input_schema).unwrap(), | 
|  | 579 | +                ), | 
|  | 580 | +                Arc::new(expressions::Literal::new(ScalarValue::Float64(Some(4.2)))), | 
|  | 581 | +            ]; | 
|  | 582 | +            let err = create_aggregate_expr( | 
|  | 583 | +                &AggregateFunction::ApproxPercentileCont, | 
|  | 584 | +                false, | 
|  | 585 | +                &input_phy_exprs[..], | 
|  | 586 | +                &input_schema, | 
|  | 587 | +                "c1", | 
|  | 588 | +            ) | 
|  | 589 | +            .expect_err("should fail due to invalid percentile"); | 
|  | 590 | + | 
|  | 591 | +            assert!(matches!(err, DataFusionError::Plan(_))); | 
|  | 592 | +        } | 
|  | 593 | +    } | 
|  | 594 | + | 
| 516 | 595 |     #[test] | 
| 517 | 596 |     fn test_min_max_expr() -> Result<()> { | 
| 518 | 597 |         let funcs = vec![AggregateFunction::Min, AggregateFunction::Max]; | 
|  | 
0 commit comments