Skip to content

Commit 2ee004d

Browse files
committed
feat: add more tests for avg distinct in the Rust API
1 parent 0b9d749 commit 2ee004d

File tree

2 files changed

+59
-52
lines changed

2 files changed

+59
-52
lines changed

datafusion/core/tests/dataframe/mod.rs

+58 -52
Original file line number | Diff line number | Diff line change
@@ -496,30 +496,32 @@ async fn aggregate() -> Result<()> {
496496
// build plan using DataFrame API
497497
let df = test_table().await?;
498498
let group_expr = vec![col("c1")];
499+
let avg_distinct = avg(col("c12")).distinct().build().unwrap();
499500
let aggr_expr = vec![
500501
min(col("c12")),
501502
max(col("c12")),
502503
avg(col("c12")),
503504
sum(col("c12")),
504505
count(col("c12")),
505506
count_distinct(col("c12")),
507+
avg_distinct,
506508
];
507509

508510
let df: Vec<RecordBatch> = df.aggregate(group_expr, aggr_expr)?.collect().await?;
509511

510512
assert_snapshot!(
511513
batches_to_sort_string(&df),
512-
@r###"
513-
+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+
514-
| c1 | min(aggregate_test_100.c12) | max(aggregate_test_100.c12) | avg(aggregate_test_100.c12) | sum(aggregate_test_100.c12) | count(aggregate_test_100.c12) | count(DISTINCT aggregate_test_100.c12) |
515-
+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+
516-
| a | 0.02182578039211991 | 0.9800193410444061 | 0.48754517466109415 | 10.238448667882977 | 21 | 21 |
517-
| b | 0.04893135681998029 | 0.9185813970744787 | 0.41040709263815384 | 7.797734760124923 | 19 | 19 |
518-
| c | 0.0494924465469434 | 0.991517828651004 | 0.6600456536439784 | 13.860958726523545 | 21 | 21 |
519-
| d | 0.061029375346466685 | 0.9748360509016578 | 0.48855379387549824 | 8.793968289758968 | 18 | 18 |
520-
| e | 0.01479305307777301 | 0.9965400387585364 | 0.48600669271341534 | 10.206140546981722 | 21 | 21 |
521-
+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+
522-
"###
514+
@r"
515+
+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+--------------------------------------+
516+
| c1 | min(aggregate_test_100.c12) | max(aggregate_test_100.c12) | avg(aggregate_test_100.c12) | sum(aggregate_test_100.c12) | count(aggregate_test_100.c12) | count(DISTINCT aggregate_test_100.c12) | avg(DISTINCT aggregate_test_100.c12) |
517+
+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+--------------------------------------+
518+
| a | 0.02182578039211991 | 0.9800193410444061 | 0.48754517466109415 | 10.238448667882977 | 21 | 21 | 0.48754517466109415 |
519+
| b | 0.04893135681998029 | 0.9185813970744787 | 0.41040709263815384 | 7.797734760124923 | 19 | 19 | 0.41040709263815384 |
520+
| c | 0.0494924465469434 | 0.991517828651004 | 0.6600456536439784 | 13.860958726523545 | 21 | 21 | 0.6600456536439784 |
521+
| d | 0.061029375346466685 | 0.9748360509016578 | 0.48855379387549824 | 8.793968289758968 | 18 | 18 | 0.48855379387549824 |
522+
| e | 0.01479305307777301 | 0.9965400387585364 | 0.48600669271341534 | 10.206140546981722 | 21 | 21 | 0.48600669271341534 |
523+
+----+-----------------------------+-----------------------------+-----------------------------+-----------------------------+-------------------------------+----------------------------------------+--------------------------------------+
524+
"
523525
);
524526

525527
Ok(())
@@ -530,6 +532,7 @@ async fn aggregate_assert_no_empty_batches() -> Result<()> {
530532
// build plan using DataFrame API
531533
let df = test_table().await?;
532534
let group_expr = vec![col("c1")];
535+
let avg_distinct = avg(col("c12")).distinct().build().unwrap();
533536
let aggr_expr = vec![
534537
min(col("c12")),
535538
max(col("c12")),
@@ -538,6 +541,7 @@ async fn aggregate_assert_no_empty_batches() -> Result<()> {
538541
count(col("c12")),
539542
count_distinct(col("c12")),
540543
median(col("c12")),
544+
avg_distinct,
541545
];
542546

543547
let df: Vec<RecordBatch> = df.aggregate(group_expr, aggr_expr)?.collect().await?;
@@ -3354,13 +3358,15 @@ async fn test_grouping_set_array_agg_with_overflow() -> Result<()> {
33543358
vec![col("c1"), col("c2")],
33553359
]));
33563360

3361+
let avg_distinct = avg(col("c3")).distinct().build().unwrap();
33573362
let df = aggregates_table(&ctx)
33583363
.await?
33593364
.aggregate(
33603365
vec![grouping_set_expr],
33613366
vec![
33623367
sum(col("c3")).alias("sum_c3"),
33633368
avg(col("c3")).alias("avg_c3"),
3369+
avg_distinct.alias("avg_distinct_c3"),
33643370
],
33653371
)?
33663372
.sort(vec![
@@ -3372,47 +3378,47 @@ async fn test_grouping_set_array_agg_with_overflow() -> Result<()> {
33723378

33733379
assert_snapshot!(
33743380
batches_to_string(&results),
3375-
@r###"
3376-
+----+----+--------+---------------------+
3377-
| c1 | c2 | sum_c3 | avg_c3 |
3378-
+----+----+--------+---------------------+
3379-
| | 5 | -194 | -13.857142857142858 |
3380-
| | 4 | 29 | 1.2608695652173914 |
3381-
| | 3 | 395 | 20.789473684210527 |
3382-
| | 2 | 184 | 8.363636363636363 |
3383-
| | 1 | 367 | 16.681818181818183 |
3384-
| e | | 847 | 40.333333333333336 |
3385-
| e | 5 | -22 | -11.0 |
3386-
| e | 4 | 261 | 37.285714285714285 |
3387-
| e | 3 | 192 | 48.0 |
3388-
| e | 2 | 189 | 37.8 |
3389-
| e | 1 | 227 | 75.66666666666667 |
3390-
| d | | 458 | 25.444444444444443 |
3391-
| d | 5 | -99 | -49.5 |
3392-
| d | 4 | 162 | 54.0 |
3393-
| d | 3 | 124 | 41.333333333333336 |
3394-
| d | 2 | 328 | 109.33333333333333 |
3395-
| d | 1 | -57 | -8.142857142857142 |
3396-
| c | | -28 | -1.3333333333333333 |
3397-
| c | 5 | 24 | 12.0 |
3398-
| c | 4 | -43 | -10.75 |
3399-
| c | 3 | 190 | 47.5 |
3400-
| c | 2 | -389 | -55.57142857142857 |
3401-
| c | 1 | 190 | 47.5 |
3402-
| b | | -111 | -5.842105263157895 |
3403-
| b | 5 | -1 | -0.2 |
3404-
| b | 4 | -223 | -44.6 |
3405-
| b | 3 | -84 | -42.0 |
3406-
| b | 2 | 102 | 25.5 |
3407-
| b | 1 | 95 | 31.666666666666668 |
3408-
| a | | -385 | -18.333333333333332 |
3409-
| a | 5 | -96 | -32.0 |
3410-
| a | 4 | -128 | -32.0 |
3411-
| a | 3 | -27 | -4.5 |
3412-
| a | 2 | -46 | -15.333333333333334 |
3413-
| a | 1 | -88 | -17.6 |
3414-
+----+----+--------+---------------------+
3415-
"###
3381+
@r"
3382+
+----+----+--------+---------------------+---------------------+
3383+
| c1 | c2 | sum_c3 | avg_c3 | avg_distinct_c3 |
3384+
+----+----+--------+---------------------+---------------------+
3385+
| | 5 | -194 | -13.857142857142858 | -13.857142857142858 |
3386+
| | 4 | 29 | 1.2608695652173914 | 1.2608695652173914 |
3387+
| | 3 | 395 | 20.789473684210527 | 20.789473684210527 |
3388+
| | 2 | 184 | 8.363636363636363 | 8.363636363636363 |
3389+
| | 1 | 367 | 16.681818181818183 | 16.681818181818183 |
3390+
| e | | 847 | 40.333333333333336 | 40.333333333333336 |
3391+
| e | 5 | -22 | -11.0 | -11.0 |
3392+
| e | 4 | 261 | 37.285714285714285 | 37.285714285714285 |
3393+
| e | 3 | 192 | 48.0 | 48.0 |
3394+
| e | 2 | 189 | 37.8 | 37.8 |
3395+
| e | 1 | 227 | 75.66666666666667 | 75.66666666666667 |
3396+
| d | | 458 | 25.444444444444443 | 25.444444444444443 |
3397+
| d | 5 | -99 | -49.5 | -49.5 |
3398+
| d | 4 | 162 | 54.0 | 54.0 |
3399+
| d | 3 | 124 | 41.333333333333336 | 41.333333333333336 |
3400+
| d | 2 | 328 | 109.33333333333333 | 109.33333333333333 |
3401+
| d | 1 | -57 | -8.142857142857142 | -8.142857142857142 |
3402+
| c | | -28 | -1.3333333333333333 | -1.3333333333333333 |
3403+
| c | 5 | 24 | 12.0 | 12.0 |
3404+
| c | 4 | -43 | -10.75 | -10.75 |
3405+
| c | 3 | 190 | 47.5 | 47.5 |
3406+
| c | 2 | -389 | -55.57142857142857 | -55.57142857142857 |
3407+
| c | 1 | 190 | 47.5 | 47.5 |
3408+
| b | | -111 | -5.842105263157895 | -5.842105263157895 |
3409+
| b | 5 | -1 | -0.2 | -0.2 |
3410+
| b | 4 | -223 | -44.6 | -44.6 |
3411+
| b | 3 | -84 | -42.0 | -42.0 |
3412+
| b | 2 | 102 | 25.5 | 25.5 |
3413+
| b | 1 | 95 | 31.666666666666668 | 31.666666666666668 |
3414+
| a | | -385 | -18.333333333333332 | -18.333333333333332 |
3415+
| a | 5 | -96 | -32.0 | -32.0 |
3416+
| a | 4 | -128 | -32.0 | -32.0 |
3417+
| a | 3 | -27 | -4.5 | -4.5 |
3418+
| a | 2 | -46 | -15.333333333333334 | -15.333333333333334 |
3419+
| a | 1 | -88 | -17.6 | -17.6 |
3420+
+----+----+--------+---------------------+---------------------+
3421+
"
34163422
);
34173423

34183424
Ok(())

datafusion/proto/tests/cases/roundtrip_logical_plan.rs

+1
Original file line number | Diff line number | Diff line change
@@ -957,6 +957,7 @@ async fn roundtrip_expr_api() -> Result<()> {
957957
functions_window::nth_value::last_value(lit(1)),
958958
functions_window::nth_value::nth_value(lit(1), 1),
959959
avg(lit(1.5)),
960+
avg(lit(1.5)).distinct().build().unwrap(),
960961
covar_samp(lit(1.5), lit(2.2)),
961962
covar_pop(lit(1.5), lit(2.2)),
962963
corr(lit(1.5), lit(2.2)),

0 commit comments

Comments
 (0)