@@ -20,8 +20,9 @@ use arrow::array::{
2020} ;
2121use arrow:: datatypes:: DataType ;
2222use datafusion_common:: cast:: { as_map_array, as_struct_array} ;
23- use datafusion_common:: { exec_err, ExprSchema , Result , ScalarValue } ;
24- use datafusion_expr:: field_util:: GetFieldAccessSchema ;
23+ use datafusion_common:: {
24+ exec_err, plan_datafusion_err, plan_err, ExprSchema , Result , ScalarValue ,
25+ } ;
2526use datafusion_expr:: { ColumnarValue , Expr , ExprSchemable } ;
2627use datafusion_expr:: { ScalarUDFImpl , Signature , Volatility } ;
2728use std:: any:: Any ;
@@ -104,11 +105,37 @@ impl ScalarUDFImpl for GetFieldFunc {
104105 ) ;
105106 }
106107 } ;
107- let access_schema = GetFieldAccessSchema :: NamedStructField { name : name. clone ( ) } ;
108- let arg_dt = args[ 0 ] . get_type ( schema) ?;
109- access_schema
110- . get_accessed_field ( & arg_dt)
111- . map ( |f| f. data_type ( ) . clone ( ) )
108+ let data_type = args[ 0 ] . get_type ( schema) ?;
109+ match ( data_type, name) {
110+ ( DataType :: Map ( fields, _) , _) => {
111+ match fields. data_type ( ) {
112+ DataType :: Struct ( fields) if fields. len ( ) == 2 => {
113+ // Arrow's MapArray is essentially a ListArray of structs with two columns. They are
114+ // often named "key", and "value", but we don't require any specific naming here;
115+ // instead, we assume that the second column is the "value" column both here and in
116+ // execution.
117+ let value_field = fields. get ( 1 ) . expect ( "fields should have exactly two members" ) ;
118+ Ok ( value_field. data_type ( ) . clone ( ) )
119+ } ,
120+ _ => plan_err ! ( "Map fields must contain a Struct with exactly 2 fields" ) ,
121+ }
122+ }
123+ ( DataType :: Struct ( fields) , ScalarValue :: Utf8 ( Some ( s) ) ) => {
124+ if s. is_empty ( ) {
125+ plan_err ! (
126+ "Struct based indexed access requires a non empty string"
127+ )
128+ } else {
129+ let field = fields. iter ( ) . find ( |f| f. name ( ) == s) ;
130+ field. ok_or ( plan_datafusion_err ! ( "Field {s} not found in struct" ) ) . map ( |f| f. data_type ( ) . clone ( ) )
131+ }
132+ }
133+ ( DataType :: Struct ( _) , _) => plan_err ! (
134+ "Only UTF8 strings are valid as an indexed field in a struct"
135+ ) ,
136+ ( DataType :: Null , _) => Ok ( DataType :: Null ) ,
137+ ( other, _) => plan_err ! ( "The expression to get an indexed field is only valid for `List`, `Struct`, `Map` or `Null` types, got {other}" ) ,
138+ }
112139 }
113140
114141 fn invoke ( & self , args : & [ ColumnarValue ] ) -> Result < ColumnarValue > {
@@ -175,6 +202,7 @@ impl ScalarUDFImpl for GetFieldFunc {
175202 "get indexed field is only possible on struct with utf8 indexes. \
176203 Tried with {name:?} index"
177204 ) ,
205+ ( DataType :: Null , _) => Ok ( ColumnarValue :: Scalar ( ScalarValue :: Null ) ) ,
178206 ( dt, name) => exec_err ! (
179207 "get indexed field is only possible on lists with int64 indexes or struct \
180208 with utf8 indexes. Tried {dt:?} with {name:?} index"
0 commit comments