1919
2020use crate :: utils:: { get_map_entry_field, make_scalar_function} ;
2121use arrow:: array:: { Array , ArrayRef , ListArray } ;
22- use arrow:: datatypes:: { DataType , Field } ;
22+ use arrow:: datatypes:: { DataType , Field , FieldRef } ;
2323use datafusion_common:: utils:: take_function_args;
24- use datafusion_common:: { cast:: as_map_array, exec_err, Result } ;
24+ use datafusion_common:: { cast:: as_map_array, exec_err, internal_err , Result } ;
2525use datafusion_expr:: {
2626 ArrayFunctionSignature , ColumnarValue , Documentation , ScalarUDFImpl , Signature ,
2727 TypeSignature , Volatility ,
2828} ;
2929use datafusion_macros:: user_doc;
3030use std:: any:: Any ;
31+ use std:: ops:: Deref ;
3132use std:: sync:: Arc ;
3233
3334make_udf_expr_and_func ! (
@@ -91,13 +92,22 @@ impl ScalarUDFImpl for MapValuesFunc {
9192 & self . signature
9293 }
9394
94- fn return_type ( & self , arg_types : & [ DataType ] ) -> Result < DataType > {
95- let [ map_type] = take_function_args ( self . name ( ) , arg_types) ?;
96- let map_fields = get_map_entry_field ( map_type) ?;
97- Ok ( DataType :: List ( Arc :: new ( Field :: new_list_field (
98- map_fields. last ( ) . unwrap ( ) . data_type ( ) . clone ( ) ,
99- true ,
100- ) ) ) )
95+ fn return_type ( & self , _arg_types : & [ DataType ] ) -> Result < DataType > {
96+ internal_err ! ( "return_field_from_args should be used instead" )
97+ }
98+
99+ fn return_field_from_args (
100+ & self ,
101+ args : datafusion_expr:: ReturnFieldArgs ,
102+ ) -> Result < Field > {
103+ let [ map_type] = take_function_args ( self . name ( ) , args. arg_fields ) ?;
104+
105+ Ok ( Field :: new (
106+ self . name ( ) ,
107+ DataType :: List ( get_map_values_field_as_list_field ( map_type. data_type ( ) ) ?) ,
108+ // Nullable if the map is nullable
109+ args. arg_fields . iter ( ) . any ( |x| x. is_nullable ( ) ) ,
110+ ) )
101111 }
102112
103113 fn invoke_with_args (
@@ -121,9 +131,137 @@ fn map_values_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
121131 } ;
122132
123133 Ok ( Arc :: new ( ListArray :: new (
124- Arc :: new ( Field :: new_list_field ( map_array . value_type ( ) . clone ( ) , true ) ) ,
134+ get_map_values_field_as_list_field ( map_arg . data_type ( ) ) ? ,
125135 map_array. offsets ( ) . clone ( ) ,
126136 Arc :: clone ( map_array. values ( ) ) ,
127137 map_array. nulls ( ) . cloned ( ) ,
128138 ) ) )
129139}
140+
141+ fn get_map_values_field_as_list_field ( map_type : & DataType ) -> Result < FieldRef > {
142+ let map_fields = get_map_entry_field ( map_type) ?;
143+
144+ let values_field = map_fields
145+ . last ( )
146+ . unwrap ( )
147+ . deref ( )
148+ . clone ( )
149+ . with_name ( Field :: LIST_FIELD_DEFAULT_NAME ) ;
150+
151+ Ok ( Arc :: new ( values_field) )
152+ }
153+
154+ #[ cfg( test) ]
155+ mod tests {
156+ use crate :: map_values:: MapValuesFunc ;
157+ use arrow:: datatypes:: { DataType , Field } ;
158+ use datafusion_common:: ScalarValue ;
159+ use datafusion_expr:: ScalarUDFImpl ;
160+ use std:: sync:: Arc ;
161+
162+ #[ test]
163+ fn return_type_field ( ) {
164+ fn get_map_field (
165+ is_map_nullable : bool ,
166+ is_keys_nullable : bool ,
167+ is_values_nullable : bool ,
168+ ) -> Field {
169+ Field :: new_map (
170+ "something" ,
171+ "entries" ,
172+ Arc :: new ( Field :: new ( "keys" , DataType :: Utf8 , is_keys_nullable) ) ,
173+ Arc :: new ( Field :: new (
174+ "values" ,
175+ DataType :: LargeUtf8 ,
176+ is_values_nullable,
177+ ) ) ,
178+ false ,
179+ is_map_nullable,
180+ )
181+ }
182+
183+ fn get_list_field (
184+ name : & str ,
185+ is_list_nullable : bool ,
186+ list_item_type : DataType ,
187+ is_list_items_nullable : bool ,
188+ ) -> Field {
189+ Field :: new_list (
190+ name,
191+ Arc :: new ( Field :: new_list_field (
192+ list_item_type,
193+ is_list_items_nullable,
194+ ) ) ,
195+ is_list_nullable,
196+ )
197+ }
198+
199+ fn get_return_field ( field : Field ) -> Field {
200+ let func = MapValuesFunc :: new ( ) ;
201+ let args = datafusion_expr:: ReturnFieldArgs {
202+ arg_fields : & [ field] ,
203+ scalar_arguments : & [ None :: < & ScalarValue > ] ,
204+ } ;
205+
206+ func. return_field_from_args ( args) . unwrap ( )
207+ }
208+
209+ // Test cases:
210+ //
211+ // | Input Map || Expected Output |
212+ // | ------------------------------------------------------ || ----------------------------------------------------- |
213+ // | map nullable | map keys nullable | map values nullable || expected list nullable | expected list items nullable |
214+ // | ------------ | ----------------- | ------------------- || ---------------------- | ---------------------------- |
215+ // | false | false | false || false | false |
216+ // | false | false | true || false | true |
217+ // | false | true | false || false | false |
218+ // | false | true | true || false | true |
219+ // | true | false | false || true | false |
220+ // | true | false | true || true | true |
221+ // | true | true | false || true | false |
222+ // | true | true | true || true | true |
223+ //
224+ // ---------------
225+ // We added the key nullability to show that it does not affect the nullability of the list or the list items.
226+
227+ assert_eq ! (
228+ get_return_field( get_map_field( false , false , false ) ) ,
229+ get_list_field( "map_values" , false , DataType :: LargeUtf8 , false )
230+ ) ;
231+
232+ assert_eq ! (
233+ get_return_field( get_map_field( false , false , true ) ) ,
234+ get_list_field( "map_values" , false , DataType :: LargeUtf8 , true )
235+ ) ;
236+
237+ assert_eq ! (
238+ get_return_field( get_map_field( false , true , false ) ) ,
239+ get_list_field( "map_values" , false , DataType :: LargeUtf8 , false )
240+ ) ;
241+
242+ assert_eq ! (
243+ get_return_field( get_map_field( false , true , true ) ) ,
244+ get_list_field( "map_values" , false , DataType :: LargeUtf8 , true )
245+ ) ;
246+
247+ assert_eq ! (
248+ get_return_field( get_map_field( true , false , false ) ) ,
249+ get_list_field( "map_values" , true , DataType :: LargeUtf8 , false )
250+ ) ;
251+
252+ assert_eq ! (
253+ get_return_field( get_map_field( true , false , true ) ) ,
254+ get_list_field( "map_values" , true , DataType :: LargeUtf8 , true )
255+ ) ;
256+
257+ assert_eq ! (
258+ get_return_field( get_map_field( true , true , false ) ) ,
259+ get_list_field( "map_values" , true , DataType :: LargeUtf8 , false )
260+ ) ;
261+
262+ assert_eq ! (
263+ get_return_field( get_map_field( true , true , true ) ) ,
264+ get_list_field( "map_values" , true , DataType :: LargeUtf8 , true )
265+ ) ;
266+ }
267+ }
0 commit comments