@@ -69,11 +69,11 @@ pub struct FFI_AggregateUDF {
6969 /// FFI equivalent to the `volatility` of a [`AggregateUDF`]
7070 pub volatility : FFI_Volatility ,
7171
72- /// Determines the return type of the underlying [`AggregateUDF`] based on the
73- /// argument types .
74- pub return_type : unsafe extern "C" fn (
72+ /// Determines the return field of the underlying [`AggregateUDF`] based on the
73+ /// argument fields .
74+ pub return_field : unsafe extern "C" fn (
7575 udaf : & Self ,
76- arg_types : RVec < WrappedSchema > ,
76+ arg_fields : RVec < WrappedSchema > ,
7777 ) -> RResult < WrappedSchema , RString > ,
7878
7979 /// FFI equivalent to the `is_nullable` of a [`AggregateUDF`]
@@ -160,20 +160,22 @@ impl FFI_AggregateUDF {
160160 }
161161}
162162
163- unsafe extern "C" fn return_type_fn_wrapper (
163+ unsafe extern "C" fn return_field_fn_wrapper (
164164 udaf : & FFI_AggregateUDF ,
165- arg_types : RVec < WrappedSchema > ,
165+ arg_fields : RVec < WrappedSchema > ,
166166) -> RResult < WrappedSchema , RString > {
167167 let udaf = udaf. inner ( ) ;
168168
169- let arg_types = rresult_return ! ( rvec_wrapped_to_vec_datatype ( & arg_types ) ) ;
169+ let arg_fields = rresult_return ! ( rvec_wrapped_to_vec_fieldref ( & arg_fields ) ) ;
170170
171- let return_type = udaf
172- . return_type ( & arg_types)
173- . and_then ( |v| FFI_ArrowSchema :: try_from ( v) . map_err ( DataFusionError :: from) )
171+ let return_field = udaf
172+ . return_field ( & arg_fields)
173+ . and_then ( |v| {
174+ FFI_ArrowSchema :: try_from ( v. as_ref ( ) ) . map_err ( DataFusionError :: from)
175+ } )
174176 . map ( WrappedSchema ) ;
175177
176- rresult ! ( return_type )
178+ rresult ! ( return_field )
177179}
178180
179181unsafe extern "C" fn accumulator_fn_wrapper (
@@ -346,7 +348,7 @@ impl From<Arc<AggregateUDF>> for FFI_AggregateUDF {
346348 is_nullable,
347349 volatility,
348350 aliases,
349- return_type : return_type_fn_wrapper ,
351+ return_field : return_field_fn_wrapper ,
350352 accumulator : accumulator_fn_wrapper,
351353 create_sliding_accumulator : create_sliding_accumulator_fn_wrapper,
352354 create_groups_accumulator : create_groups_accumulator_fn_wrapper,
@@ -425,14 +427,22 @@ impl AggregateUDFImpl for ForeignAggregateUDF {
425427 & self . signature
426428 }
427429
428- fn return_type ( & self , arg_types : & [ DataType ] ) -> Result < DataType > {
429- let arg_types = vec_datatype_to_rvec_wrapped ( arg_types) ?;
430+ fn return_type ( & self , _arg_types : & [ DataType ] ) -> Result < DataType > {
431+ unimplemented ! ( )
432+ }
433+
434+ fn return_field ( & self , arg_fields : & [ FieldRef ] ) -> Result < FieldRef > {
435+ let arg_fields = vec_fieldref_to_rvec_wrapped ( arg_fields) ?;
430436
431- let result = unsafe { ( self . udaf . return_type ) ( & self . udaf , arg_types ) } ;
437+ let result = unsafe { ( self . udaf . return_field ) ( & self . udaf , arg_fields ) } ;
432438
433439 let result = df_result ! ( result) ;
434440
435- result. and_then ( |r| ( & r. 0 ) . try_into ( ) . map_err ( DataFusionError :: from) )
441+ result. and_then ( |r| {
442+ Field :: try_from ( & r. 0 )
443+ . map ( Arc :: new)
444+ . map_err ( DataFusionError :: from)
445+ } )
436446 }
437447
438448 fn is_nullable ( & self ) -> bool {
@@ -608,9 +618,43 @@ mod tests {
608618 physical_expr:: PhysicalSortExpr , physical_plan:: expressions:: col,
609619 scalar:: ScalarValue ,
610620 } ;
621+ use std:: any:: Any ;
622+ use std:: collections:: HashMap ;
611623
612624 use super :: * ;
613625
626+ #[ derive( Default , Debug , Hash , Eq , PartialEq ) ]
627+ struct SumWithCopiedMetadata {
628+ inner : Sum ,
629+ }
630+
631+ impl AggregateUDFImpl for SumWithCopiedMetadata {
632+ fn as_any ( & self ) -> & dyn Any {
633+ self
634+ }
635+
636+ fn name ( & self ) -> & str {
637+ self . inner . name ( )
638+ }
639+
640+ fn signature ( & self ) -> & Signature {
641+ self . inner . signature ( )
642+ }
643+
644+ fn return_type ( & self , _arg_types : & [ DataType ] ) -> Result < DataType > {
645+ unimplemented ! ( )
646+ }
647+
648+ fn return_field ( & self , arg_fields : & [ FieldRef ] ) -> Result < FieldRef > {
649+ // Copy the input field, so any metadata gets returned
650+ Ok ( Arc :: clone ( & arg_fields[ 0 ] ) )
651+ }
652+
653+ fn accumulator ( & self , acc_args : AccumulatorArgs ) -> Result < Box < dyn Accumulator > > {
654+ self . inner . accumulator ( acc_args)
655+ }
656+ }
657+
614658 fn create_test_foreign_udaf (
615659 original_udaf : impl AggregateUDFImpl + ' static ,
616660 ) -> Result < AggregateUDF > {
@@ -644,8 +688,11 @@ mod tests {
644688 let foreign_udaf =
645689 create_test_foreign_udaf ( Sum :: new ( ) ) ?. with_aliases ( [ "my_function" ] ) ;
646690
647- let return_type = foreign_udaf. return_type ( & [ DataType :: Float64 ] ) ?;
648- assert_eq ! ( return_type, DataType :: Float64 ) ;
691+ let return_field =
692+ foreign_udaf
693+ . return_field ( & [ Field :: new ( "a" , DataType :: Float64 , true ) . into ( ) ] ) ?;
694+ let return_type = return_field. data_type ( ) ;
695+ assert_eq ! ( return_type, & DataType :: Float64 ) ;
649696 Ok ( ( ) )
650697 }
651698
@@ -673,6 +720,31 @@ mod tests {
673720 Ok ( ( ) )
674721 }
675722
723+ #[ test]
724+ fn test_round_trip_udaf_metadata ( ) -> Result < ( ) > {
725+ let original_udaf = SumWithCopiedMetadata :: default ( ) ;
726+ let original_udaf = Arc :: new ( AggregateUDF :: from ( original_udaf) ) ;
727+
728+ // Convert to FFI format
729+ let local_udaf: FFI_AggregateUDF = Arc :: clone ( & original_udaf) . into ( ) ;
730+
731+ // Convert back to native format
732+ let foreign_udaf: ForeignAggregateUDF = ( & local_udaf) . try_into ( ) ?;
733+ let foreign_udaf: AggregateUDF = foreign_udaf. into ( ) ;
734+
735+ let metadata: HashMap < String , String > =
736+ [ ( "a_key" . to_string ( ) , "a_value" . to_string ( ) ) ]
737+ . into_iter ( )
738+ . collect ( ) ;
739+ let input_field = Arc :: new (
740+ Field :: new ( "a" , DataType :: Float64 , false ) . with_metadata ( metadata. clone ( ) ) ,
741+ ) ;
742+ let return_field = foreign_udaf. return_field ( & [ input_field] ) ?;
743+
744+ assert_eq ! ( & metadata, return_field. metadata( ) ) ;
745+ Ok ( ( ) )
746+ }
747+
676748 #[ test]
677749 fn test_beneficial_ordering ( ) -> Result < ( ) > {
678750 let foreign_udaf = create_test_foreign_udaf (
0 commit comments