@@ -528,25 +528,6 @@ impl AggregateUDFImpl for ForeignAggregateUDF {
528
528
}
529
529
}
530
530
531
- #[ cfg( test) ]
532
- mod tests {
533
- use super :: * ;
534
-
535
- #[ test]
536
- fn test_round_trip_udaf ( ) -> Result < ( ) > {
537
- let original_udaf = datafusion:: functions_aggregate:: sum:: Sum :: new ( ) ;
538
- let original_udaf = Arc :: new ( AggregateUDF :: from ( original_udaf) ) ;
539
-
540
- let local_udaf: FFI_AggregateUDF = Arc :: clone ( & original_udaf) . into ( ) ;
541
-
542
- let foreign_udaf: ForeignAggregateUDF = ( & local_udaf) . try_into ( ) ?;
543
-
544
- assert ! ( original_udaf. name( ) == foreign_udaf. name( ) ) ;
545
-
546
- Ok ( ( ) )
547
- }
548
- }
549
-
550
531
#[ repr( C ) ]
551
532
#[ derive( Debug , StableAbi ) ]
552
533
#[ allow( non_camel_case_types) ]
@@ -575,3 +556,152 @@ impl From<AggregateOrderSensitivity> for FFI_AggregateOrderSensitivity {
575
556
}
576
557
}
577
558
}
559
+
560
+ #[ cfg( test) ]
561
+ mod tests {
562
+ use arrow:: datatypes:: Schema ;
563
+ use datafusion:: {
564
+ common:: create_array,
565
+ functions_aggregate:: sum:: Sum ,
566
+ physical_expr:: { LexOrdering , PhysicalSortExpr } ,
567
+ physical_plan:: expressions:: col,
568
+ scalar:: ScalarValue ,
569
+ } ;
570
+
571
+ use super :: * ;
572
+
573
+ fn create_test_foreign_udaf (
574
+ original_udaf : impl AggregateUDFImpl + ' static ,
575
+ ) -> Result < AggregateUDF > {
576
+ let original_udaf = Arc :: new ( AggregateUDF :: from ( original_udaf) ) ;
577
+
578
+ let local_udaf: FFI_AggregateUDF = Arc :: clone ( & original_udaf) . into ( ) ;
579
+
580
+ let foreign_udaf: ForeignAggregateUDF = ( & local_udaf) . try_into ( ) ?;
581
+ Ok ( foreign_udaf. into ( ) )
582
+ }
583
+
584
+ #[ test]
585
+ fn test_round_trip_udaf ( ) -> Result < ( ) > {
586
+ let original_udaf = Sum :: new ( ) ;
587
+ let original_name = original_udaf. name ( ) . to_owned ( ) ;
588
+
589
+ let foreign_udaf = create_test_foreign_udaf ( original_udaf) ?;
590
+ // let original_udaf = Arc::new(AggregateUDF::from(original_udaf));
591
+
592
+ // let local_udaf: FFI_AggregateUDF = Arc::clone(&original_udaf).into();
593
+
594
+ // let foreign_udaf: ForeignAggregateUDF = (&local_udaf).try_into()?;
595
+ // let foreign_udaf: AggregateUDF = foreign_udaf.into();
596
+
597
+ assert_eq ! ( original_name, foreign_udaf. name( ) ) ;
598
+ Ok ( ( ) )
599
+ }
600
+
601
+ #[ test]
602
+ fn test_foreign_udaf_aliases ( ) -> Result < ( ) > {
603
+ let foreign_udaf =
604
+ create_test_foreign_udaf ( Sum :: new ( ) ) ?. with_aliases ( [ "my_function" ] ) ;
605
+
606
+ let return_type = foreign_udaf. return_type ( & [ DataType :: Float64 ] ) ?;
607
+ assert_eq ! ( return_type, DataType :: Float64 ) ;
608
+ Ok ( ( ) )
609
+ }
610
+
611
+ #[ test]
612
+ fn test_foreign_udaf_accumulator ( ) -> Result < ( ) > {
613
+ let foreign_udaf = create_test_foreign_udaf ( Sum :: new ( ) ) ?;
614
+
615
+ let schema = Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Float64 , true ) ] ) ;
616
+ let acc_args = AccumulatorArgs {
617
+ return_type : & DataType :: Float64 ,
618
+ schema : & schema,
619
+ ignore_nulls : true ,
620
+ ordering_req : & LexOrdering :: new ( vec ! [ PhysicalSortExpr {
621
+ expr: col( "a" , & schema) ?,
622
+ options: Default :: default ( ) ,
623
+ } ] ) ,
624
+ is_reversed : false ,
625
+ name : "round_trip" ,
626
+ is_distinct : true ,
627
+ exprs : & [ col ( "a" , & schema) ?] ,
628
+ } ;
629
+ let mut accumulator = foreign_udaf. accumulator ( acc_args) ?;
630
+ let values = create_array ! ( Float64 , vec![ 10. , 20. , 30. , 40. , 50. ] ) ;
631
+ accumulator. update_batch ( & [ values] ) ?;
632
+ let resultant_value = accumulator. evaluate ( ) ?;
633
+ assert_eq ! ( resultant_value, ScalarValue :: Float64 ( Some ( 150. ) ) ) ;
634
+
635
+ Ok ( ( ) )
636
+ }
637
+
638
+ #[ test]
639
+ fn test_beneficial_ordering ( ) -> Result < ( ) > {
640
+ let foreign_udaf = create_test_foreign_udaf (
641
+ datafusion:: functions_aggregate:: first_last:: FirstValue :: new ( ) ,
642
+ ) ?;
643
+
644
+ let foreign_udaf = foreign_udaf. with_beneficial_ordering ( true ) ?. unwrap ( ) ;
645
+
646
+ assert_eq ! (
647
+ foreign_udaf. order_sensitivity( ) ,
648
+ AggregateOrderSensitivity :: Beneficial
649
+ ) ;
650
+
651
+ let a_field = Field :: new ( "a" , DataType :: Float64 , true ) ;
652
+ let state_fields = foreign_udaf. state_fields ( StateFieldsArgs {
653
+ name : "a" ,
654
+ input_types : & [ DataType :: Float64 ] ,
655
+ return_type : & DataType :: Float64 ,
656
+ ordering_fields : & [ a_field. clone ( ) ] ,
657
+ is_distinct : false ,
658
+ } ) ?;
659
+
660
+ println ! ( "{:#?}" , state_fields) ;
661
+ assert_eq ! ( state_fields. len( ) , 3 ) ;
662
+ assert_eq ! ( state_fields[ 1 ] , a_field) ;
663
+ Ok ( ( ) )
664
+ }
665
+
666
+ #[ test]
667
+ fn test_sliding_accumulator ( ) -> Result < ( ) > {
668
+ let foreign_udaf = create_test_foreign_udaf ( Sum :: new ( ) ) ?;
669
+
670
+ let schema = Schema :: new ( vec ! [ Field :: new( "a" , DataType :: Float64 , true ) ] ) ;
671
+ let acc_args = AccumulatorArgs {
672
+ return_type : & DataType :: Float64 ,
673
+ schema : & schema,
674
+ ignore_nulls : true ,
675
+ ordering_req : & LexOrdering :: new ( vec ! [ PhysicalSortExpr {
676
+ expr: col( "a" , & schema) ?,
677
+ options: Default :: default ( ) ,
678
+ } ] ) ,
679
+ is_reversed : false ,
680
+ name : "round_trip" ,
681
+ is_distinct : true ,
682
+ exprs : & [ col ( "a" , & schema) ?] ,
683
+ } ;
684
+
685
+ let mut accumulator = foreign_udaf. create_sliding_accumulator ( acc_args) ?;
686
+ let values = create_array ! ( Float64 , vec![ 10. , 20. , 30. , 40. , 50. ] ) ;
687
+ accumulator. update_batch ( & [ values] ) ?;
688
+ let resultant_value = accumulator. evaluate ( ) ?;
689
+ assert_eq ! ( resultant_value, ScalarValue :: Float64 ( Some ( 150. ) ) ) ;
690
+
691
+ Ok ( ( ) )
692
+ }
693
+
694
+ fn test_round_trip_order_sensitivity ( sensitivity : AggregateOrderSensitivity ) {
695
+ let ffi_sensitivity: FFI_AggregateOrderSensitivity = sensitivity. into ( ) ;
696
+ let round_trip_sensitivity: AggregateOrderSensitivity = ffi_sensitivity. into ( ) ;
697
+
698
+ assert_eq ! ( sensitivity, round_trip_sensitivity) ;
699
+ }
700
+
701
+ #[ test]
702
+ fn test_round_trip_all_order_sensitivities ( ) {
703
+ test_round_trip_order_sensitivity ( AggregateOrderSensitivity :: Insensitive ) ;
704
+ test_round_trip_order_sensitivity ( AggregateOrderSensitivity :: HardRequirement ) ;
705
+ test_round_trip_order_sensitivity ( AggregateOrderSensitivity :: Beneficial ) ;
706
+ }
707
+ }
0 commit comments