1818use arrow:: array:: Array ;
1919use arrow:: datatypes:: { DataType , FieldRef , UnionFields } ;
2020use datafusion_common:: cast:: as_union_array;
21+ use datafusion_common:: utils:: take_function_args;
2122use datafusion_common:: {
2223 exec_datafusion_err, exec_err, internal_err, Result , ScalarValue ,
2324} ;
@@ -113,22 +114,15 @@ impl ScalarUDFImpl for UnionExtractFun {
113114 }
114115
115116 fn invoke_with_args ( & self , args : ScalarFunctionArgs ) -> Result < ColumnarValue > {
116- let args = args. args ;
117+ let [ array , target_name ] = take_function_args ( "union_extract" , args. args ) ? ;
117118
118- if args. len ( ) != 2 {
119- return exec_err ! (
120- "union_extract expects 2 arguments, got {} instead" ,
121- args. len( )
122- ) ;
123- }
124-
125- let target_name = match & args[ 1 ] {
119+ let target_name = match target_name {
126120 ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( Some ( target_name) ) ) => Ok ( target_name) ,
127121 ColumnarValue :: Scalar ( ScalarValue :: Utf8 ( None ) ) => exec_err ! ( "union_extract second argument must be a non-null string literal, got a null instead" ) ,
128- _ => exec_err ! ( "union_extract second argument must be a non-null string literal, got {} instead" , & args [ 1 ] . data_type( ) ) ,
129- } ;
122+ _ => exec_err ! ( "union_extract second argument must be a non-null string literal, got {} instead" , target_name . data_type( ) ) ,
123+ } ? ;
130124
131- match & args [ 0 ] {
125+ match array {
132126 ColumnarValue :: Array ( array) => {
133127 let union_array = as_union_array ( & array) . map_err ( |_| {
134128 exec_datafusion_err ! (
@@ -140,19 +134,16 @@ impl ScalarUDFImpl for UnionExtractFun {
140134 Ok ( ColumnarValue :: Array (
141135 arrow:: compute:: kernels:: union_extract:: union_extract (
142136 union_array,
143- target_name? ,
137+ & target_name,
144138 ) ?,
145139 ) )
146140 }
147141 ColumnarValue :: Scalar ( ScalarValue :: Union ( value, fields, _) ) => {
148- let target_name = target_name?;
149- let ( target_type_id, target) = find_field ( fields, target_name) ?;
142+ let ( target_type_id, target) = find_field ( & fields, & target_name) ?;
150143
151144 let result = match value {
152- Some ( ( type_id, value) ) if target_type_id == * type_id => {
153- * value. clone ( )
154- }
155- _ => ScalarValue :: try_from ( target. data_type ( ) ) ?,
145+ Some ( ( type_id, value) ) if target_type_id == type_id => * value,
146+ _ => ScalarValue :: try_new_null ( target. data_type ( ) ) ?,
156147 } ;
157148
158149 Ok ( ColumnarValue :: Scalar ( result) )
0 commit comments