@@ -22,12 +22,16 @@ use arrow_schema::{DataType, Field, Schema};
2222use datafusion:: prelude:: * ;
2323use datafusion:: { execution:: registry:: FunctionRegistry , test_util} ;
2424use datafusion_common:: cast:: as_float64_array;
25- use datafusion_common:: { assert_batches_eq, cast:: as_int32_array, Result , ScalarValue } ;
25+ use datafusion_common:: {
26+ assert_batches_eq, assert_batches_sorted_eq, cast:: as_int32_array, not_impl_err,
27+ plan_err, DFSchema , DataFusionError , Result , ScalarValue ,
28+ } ;
2629use datafusion_expr:: {
27- create_udaf, create_udf, Accumulator , ColumnarValue , LogicalPlanBuilder , ScalarUDF ,
28- ScalarUDFImpl , Signature , Volatility ,
30+ create_udaf, create_udf, Accumulator , ColumnarValue , ExprSchemable ,
31+ LogicalPlanBuilder , ScalarUDF , ScalarUDFImpl , Signature , Volatility ,
2932} ;
3033use rand:: { thread_rng, Rng } ;
34+ use std:: any:: Any ;
3135use std:: iter;
3236use std:: sync:: Arc ;
3337
@@ -494,6 +498,127 @@ async fn test_user_defined_functions_zero_argument() -> Result<()> {
494498 Ok ( ( ) )
495499}
496500
501+ #[ derive( Debug ) ]
502+ struct TakeUDF {
503+ signature : Signature ,
504+ }
505+
506+ impl TakeUDF {
507+ fn new ( ) -> Self {
508+ Self {
509+ signature : Signature :: any ( 3 , Volatility :: Immutable ) ,
510+ }
511+ }
512+ }
513+
514+ /// Implement a ScalarUDFImpl whose return type is a function of the input values
515+ impl ScalarUDFImpl for TakeUDF {
516+ fn as_any ( & self ) -> & dyn Any {
517+ self
518+ }
519+ fn name ( & self ) -> & str {
520+ "take"
521+ }
522+ fn signature ( & self ) -> & Signature {
523+ & self . signature
524+ }
525+ fn return_type ( & self , _args : & [ DataType ] ) -> Result < DataType > {
526+ not_impl_err ! ( "Not called because the return_type_from_exprs is implemented" )
527+ }
528+
529+ /// Thus function returns the type of the first or second argument based on
530+ /// the third argument:
531+ ///
532+ /// 1. If the third argument is '0', return the type of the first argument
533+ /// 2. If the third argument is '1', return the type of the second argument
534+ fn return_type_from_exprs (
535+ & self ,
536+ arg_exprs : & [ Expr ] ,
537+ schema : & DFSchema ,
538+ ) -> Result < DataType > {
539+ if arg_exprs. len ( ) != 3 {
540+ return plan_err ! ( "Expected 3 arguments, got {}." , arg_exprs. len( ) ) ;
541+ }
542+
543+ let take_idx = if let Some ( Expr :: Literal ( ScalarValue :: Int64 ( Some ( idx) ) ) ) =
544+ arg_exprs. get ( 2 )
545+ {
546+ if * idx == 0 || * idx == 1 {
547+ * idx as usize
548+ } else {
549+ return plan_err ! ( "The third argument must be 0 or 1, got: {idx}" ) ;
550+ }
551+ } else {
552+ return plan_err ! (
553+ "The third argument must be a literal of type int64, but got {:?}" ,
554+ arg_exprs. get( 2 )
555+ ) ;
556+ } ;
557+
558+ arg_exprs. get ( take_idx) . unwrap ( ) . get_type ( schema)
559+ }
560+
561+ // The actual implementation rethr
562+ fn invoke ( & self , args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
563+ let take_idx = match & args[ 2 ] {
564+ ColumnarValue :: Scalar ( ScalarValue :: Int64 ( Some ( v) ) ) if v < & 2 => * v as usize ,
565+ _ => unreachable ! ( ) ,
566+ } ;
567+ match & args[ take_idx] {
568+ ColumnarValue :: Array ( array) => Ok ( ColumnarValue :: Array ( array. clone ( ) ) ) ,
569+ ColumnarValue :: Scalar ( _) => unimplemented ! ( ) ,
570+ }
571+ }
572+ }
573+
574+ #[ tokio:: test]
575+ async fn verify_udf_return_type ( ) -> Result < ( ) > {
576+ // Create a new ScalarUDF from the implementation
577+ let take = ScalarUDF :: from ( TakeUDF :: new ( ) ) ;
578+
579+ // SELECT
580+ // take(smallint_col, double_col, 0) as take0,
581+ // take(smallint_col, double_col, 1) as take1
582+ // FROM alltypes_plain;
583+ let exprs = vec ! [
584+ take. call( vec![ col( "smallint_col" ) , col( "double_col" ) , lit( 0_i64 ) ] )
585+ . alias( "take0" ) ,
586+ take. call( vec![ col( "smallint_col" ) , col( "double_col" ) , lit( 1_i64 ) ] )
587+ . alias( "take1" ) ,
588+ ] ;
589+
590+ let ctx = SessionContext :: new ( ) ;
591+ register_alltypes_parquet ( & ctx) . await ?;
592+
593+ let df = ctx. table ( "alltypes_plain" ) . await ?. select ( exprs) ?;
594+
595+ let schema = df. schema ( ) ;
596+
597+ // The output schema should be
598+ // * type of column smallint_col (float64)
599+ // * type of column double_col (float32)
600+ assert_eq ! ( schema. field( 0 ) . data_type( ) , & DataType :: Int32 ) ;
601+ assert_eq ! ( schema. field( 1 ) . data_type( ) , & DataType :: Float64 ) ;
602+
603+ let expected = [
604+ "+-------+-------+" ,
605+ "| take0 | take1 |" ,
606+ "+-------+-------+" ,
607+ "| 0 | 0.0 |" ,
608+ "| 0 | 0.0 |" ,
609+ "| 0 | 0.0 |" ,
610+ "| 0 | 0.0 |" ,
611+ "| 1 | 10.1 |" ,
612+ "| 1 | 10.1 |" ,
613+ "| 1 | 10.1 |" ,
614+ "| 1 | 10.1 |" ,
615+ "+-------+-------+" ,
616+ ] ;
617+ assert_batches_sorted_eq ! ( & expected, & df. collect( ) . await ?) ;
618+
619+ Ok ( ( ) )
620+ }
621+
497622fn create_udf_context ( ) -> SessionContext {
498623 let ctx = SessionContext :: new ( ) ;
499624 // register a custom UDF
@@ -531,6 +656,17 @@ async fn register_aggregate_csv(ctx: &SessionContext) -> Result<()> {
531656 Ok ( ( ) )
532657}
533658
659+ async fn register_alltypes_parquet ( ctx : & SessionContext ) -> Result < ( ) > {
660+ let testdata = datafusion:: test_util:: parquet_test_data ( ) ;
661+ ctx. register_parquet (
662+ "alltypes_plain" ,
663+ & format ! ( "{testdata}/alltypes_plain.parquet" ) ,
664+ ParquetReadOptions :: default ( ) ,
665+ )
666+ . await ?;
667+ Ok ( ( ) )
668+ }
669+
534670/// Execute SQL and return results as a RecordBatch
535671async fn plan_and_collect ( ctx : & SessionContext , sql : & str ) -> Result < Vec < RecordBatch > > {
536672 ctx. sql ( sql) . await ?. collect ( ) . await
0 commit comments