Skip to content

Commit 77fd002

Browse files
committed
Adding more unit tests on ffi aggregate udaf
1 parent f50ff52 commit 77fd002

File tree

2 files changed

+160
-20
lines changed

2 files changed

+160
-20
lines changed

datafusion/ffi/src/udaf/accumulator.rs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,11 @@ mod tests {
307307

308308
#[test]
309309
fn test_foreign_avg_accumulator() -> Result<()> {
310-
let boxed_accum: Box<dyn Accumulator> = Box::new(AvgAccumulator::default());
310+
let original_accum = AvgAccumulator::default();
311+
let original_size = original_accum.size();
312+
let original_supports_retract = original_accum.supports_retract_batch();
313+
314+
let boxed_accum: Box<dyn Accumulator> = Box::new(original_accum);
311315
let ffi_accum: FFI_Accumulator = boxed_accum.into();
312316
let mut foreign_accum: ForeignAccumulator = ffi_accum.into();
313317

@@ -341,6 +345,12 @@ mod tests {
341345
let avg = foreign_accum.evaluate()?;
342346
assert_eq!(avg, ScalarValue::Float64(Some(30.0)));
343347

348+
assert_eq!(original_size, foreign_accum.size());
349+
assert_eq!(
350+
original_supports_retract,
351+
foreign_accum.supports_retract_batch()
352+
);
353+
344354
Ok(())
345355
}
346356
}

datafusion/ffi/src/udaf/mod.rs

Lines changed: 149 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -528,25 +528,6 @@ impl AggregateUDFImpl for ForeignAggregateUDF {
528528
}
529529
}
530530

531-
#[cfg(test)]
532-
mod tests {
533-
use super::*;
534-
535-
#[test]
536-
fn test_round_trip_udaf() -> Result<()> {
537-
let original_udaf = datafusion::functions_aggregate::sum::Sum::new();
538-
let original_udaf = Arc::new(AggregateUDF::from(original_udaf));
539-
540-
let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into();
541-
542-
let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?;
543-
544-
assert!(original_udaf.name() == foreign_udaf.name());
545-
546-
Ok(())
547-
}
548-
}
549-
550531
#[repr(C)]
551532
#[derive(Debug, StableAbi)]
552533
#[allow(non_camel_case_types)]
@@ -575,3 +556,152 @@ impl From<AggregateOrderSensitivity> for FFI_AggregateOrderSensitivity {
575556
}
576557
}
577558
}
559+
560+
#[cfg(test)]
561+
mod tests {
562+
use arrow::datatypes::Schema;
563+
use datafusion::{
564+
common::create_array,
565+
functions_aggregate::sum::Sum,
566+
physical_expr::{LexOrdering, PhysicalSortExpr},
567+
physical_plan::expressions::col,
568+
scalar::ScalarValue,
569+
};
570+
571+
use super::*;
572+
573+
fn create_test_foreign_udaf(
574+
original_udaf: impl AggregateUDFImpl + 'static,
575+
) -> Result<AggregateUDF> {
576+
let original_udaf = Arc::new(AggregateUDF::from(original_udaf));
577+
578+
let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into();
579+
580+
let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?;
581+
Ok(foreign_udaf.into())
582+
}
583+
584+
#[test]
585+
fn test_round_trip_udaf() -> Result<()> {
586+
let original_udaf = Sum::new();
587+
let original_name = original_udaf.name().to_owned();
588+
589+
let foreign_udaf = create_test_foreign_udaf(original_udaf)?;
590+
// let original_udaf = Arc::new(AggregateUDF::from(original_udaf));
591+
592+
// let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into();
593+
594+
// let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?;
595+
// let foreign_udaf: AggregateUDF = foreign_udaf.into();
596+
597+
assert_eq!(original_name, foreign_udaf.name());
598+
Ok(())
599+
}
600+
601+
#[test]
602+
fn test_foreign_udaf_aliases() -> Result<()> {
603+
let foreign_udaf =
604+
create_test_foreign_udaf(Sum::new())?.with_aliases(["my_function"]);
605+
606+
let return_type = foreign_udaf.return_type(&[DataType::Float64])?;
607+
assert_eq!(return_type, DataType::Float64);
608+
Ok(())
609+
}
610+
611+
#[test]
612+
fn test_foreign_udaf_accumulator() -> Result<()> {
613+
let foreign_udaf = create_test_foreign_udaf(Sum::new())?;
614+
615+
let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
616+
let acc_args = AccumulatorArgs {
617+
return_type: &DataType::Float64,
618+
schema: &schema,
619+
ignore_nulls: true,
620+
ordering_req: &LexOrdering::new(vec![PhysicalSortExpr {
621+
expr: col("a", &schema)?,
622+
options: Default::default(),
623+
}]),
624+
is_reversed: false,
625+
name: "round_trip",
626+
is_distinct: true,
627+
exprs: &[col("a", &schema)?],
628+
};
629+
let mut accumulator = foreign_udaf.accumulator(acc_args)?;
630+
let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]);
631+
accumulator.update_batch(&[values])?;
632+
let resultant_value = accumulator.evaluate()?;
633+
assert_eq!(resultant_value, ScalarValue::Float64(Some(150.)));
634+
635+
Ok(())
636+
}
637+
638+
#[test]
639+
fn test_beneficial_ordering() -> Result<()> {
640+
let foreign_udaf = create_test_foreign_udaf(
641+
datafusion::functions_aggregate::first_last::FirstValue::new(),
642+
)?;
643+
644+
let foreign_udaf = foreign_udaf.with_beneficial_ordering(true)?.unwrap();
645+
646+
assert_eq!(
647+
foreign_udaf.order_sensitivity(),
648+
AggregateOrderSensitivity::Beneficial
649+
);
650+
651+
let a_field = Field::new("a", DataType::Float64, true);
652+
let state_fields = foreign_udaf.state_fields(StateFieldsArgs {
653+
name: "a",
654+
input_types: &[DataType::Float64],
655+
return_type: &DataType::Float64,
656+
ordering_fields: &[a_field.clone()],
657+
is_distinct: false,
658+
})?;
659+
660+
println!("{:#?}", state_fields);
661+
assert_eq!(state_fields.len(), 3);
662+
assert_eq!(state_fields[1], a_field);
663+
Ok(())
664+
}
665+
666+
#[test]
667+
fn test_sliding_accumulator() -> Result<()> {
668+
let foreign_udaf = create_test_foreign_udaf(Sum::new())?;
669+
670+
let schema = Schema::new(vec![Field::new("a", DataType::Float64, true)]);
671+
let acc_args = AccumulatorArgs {
672+
return_type: &DataType::Float64,
673+
schema: &schema,
674+
ignore_nulls: true,
675+
ordering_req: &LexOrdering::new(vec![PhysicalSortExpr {
676+
expr: col("a", &schema)?,
677+
options: Default::default(),
678+
}]),
679+
is_reversed: false,
680+
name: "round_trip",
681+
is_distinct: true,
682+
exprs: &[col("a", &schema)?],
683+
};
684+
685+
let mut accumulator = foreign_udaf.create_sliding_accumulator(acc_args)?;
686+
let values = create_array!(Float64, vec![10., 20., 30., 40., 50.]);
687+
accumulator.update_batch(&[values])?;
688+
let resultant_value = accumulator.evaluate()?;
689+
assert_eq!(resultant_value, ScalarValue::Float64(Some(150.)));
690+
691+
Ok(())
692+
}
693+
694+
fn test_round_trip_order_sensitivity(sensitivity: AggregateOrderSensitivity) {
695+
let ffi_sensitivity: FFI_AggregateOrderSensitivity = sensitivity.into();
696+
let round_trip_sensitivity: AggregateOrderSensitivity = ffi_sensitivity.into();
697+
698+
assert_eq!(sensitivity, round_trip_sensitivity);
699+
}
700+
701+
#[test]
702+
fn test_round_trip_all_order_sensitivities() {
703+
test_round_trip_order_sensitivity(AggregateOrderSensitivity::Insensitive);
704+
test_round_trip_order_sensitivity(AggregateOrderSensitivity::HardRequirement);
705+
test_round_trip_order_sensitivity(AggregateOrderSensitivity::Beneficial);
706+
}
707+
}

0 commit comments

Comments
 (0)