1717
1818//! [`ScalarUDFImpl`] definitions for array_union, array_intersect and array_distinct functions.
1919
20- use crate :: make_array:: { empty_array_type, make_array_inner} ;
2120use crate :: utils:: make_scalar_function;
22- use arrow:: array:: { new_empty_array , Array , ArrayRef , GenericListArray , OffsetSizeTrait } ;
21+ use arrow:: array:: { Array , ArrayRef , GenericListArray , OffsetSizeTrait } ;
2322use arrow:: buffer:: OffsetBuffer ;
2423use arrow:: compute;
2524use arrow:: datatypes:: { DataType , Field , FieldRef } ;
2625use arrow:: row:: { RowConverter , SortField } ;
26+ use arrow_array:: { new_null_array, LargeListArray , ListArray } ;
2727use arrow_schema:: DataType :: { FixedSizeList , LargeList , List , Null } ;
2828use datafusion_common:: cast:: { as_large_list_array, as_list_array} ;
29- use datafusion_common:: { exec_err, internal_err, Result } ;
29+ use datafusion_common:: { exec_err, internal_err, plan_err , Result } ;
3030use datafusion_expr:: { ColumnarValue , ScalarUDFImpl , Signature , Volatility } ;
3131use itertools:: Itertools ;
3232use std:: any:: Any ;
@@ -89,7 +89,8 @@ impl ScalarUDFImpl for ArrayUnion {
8989
9090 fn return_type ( & self , arg_types : & [ DataType ] ) -> Result < DataType > {
9191 match ( & arg_types[ 0 ] , & arg_types[ 1 ] ) {
92- ( & Null , dt) => Ok ( dt. clone ( ) ) ,
92+ ( Null , Null ) => Ok ( DataType :: new_list ( Null , true ) ) ,
93+ ( Null , dt) => Ok ( dt. clone ( ) ) ,
9394 ( dt, Null ) => Ok ( dt. clone ( ) ) ,
9495 ( dt, _) => Ok ( dt. clone ( ) ) ,
9596 }
@@ -134,9 +135,10 @@ impl ScalarUDFImpl for ArrayIntersect {
134135
135136 fn return_type ( & self , arg_types : & [ DataType ] ) -> Result < DataType > {
136137 match ( arg_types[ 0 ] . clone ( ) , arg_types[ 1 ] . clone ( ) ) {
137- ( Null , Null ) | ( Null , _) => Ok ( Null ) ,
138- ( _, Null ) => Ok ( empty_array_type ( ) ) ,
139- ( dt, _) => Ok ( dt) ,
138+ ( Null , Null ) => Ok ( DataType :: new_list ( Null , true ) ) ,
139+ ( Null , dt) => Ok ( dt. clone ( ) ) ,
140+ ( dt, Null ) => Ok ( dt. clone ( ) ) ,
141+ ( dt, _) => Ok ( dt. clone ( ) ) ,
140142 }
141143 }
142144
@@ -179,19 +181,13 @@ impl ScalarUDFImpl for ArrayDistinct {
179181
180182 fn return_type ( & self , arg_types : & [ DataType ] ) -> Result < DataType > {
181183 match & arg_types[ 0 ] {
182- List ( field) | FixedSizeList ( field, _) => Ok ( List ( Arc :: new ( Field :: new (
183- "item" ,
184- field. data_type ( ) . clone ( ) ,
185- true ,
186- ) ) ) ) ,
187- LargeList ( field) => Ok ( LargeList ( Arc :: new ( Field :: new (
188- "item" ,
189- field. data_type ( ) . clone ( ) ,
190- true ,
191- ) ) ) ) ,
192- _ => exec_err ! (
193- "Not reachable, data_type should be List, LargeList or FixedSizeList"
194- ) ,
184+ List ( field) | FixedSizeList ( field, _) => {
185+ Ok ( DataType :: new_list ( field. data_type ( ) . clone ( ) , true ) )
186+ }
187+ LargeList ( field) => {
188+ Ok ( DataType :: new_large_list ( field. data_type ( ) . clone ( ) , true ) )
189+ }
190+ arg_type => plan_err ! ( "{} does not support type {arg_type}" , self . name( ) ) ,
195191 }
196192 }
197193
@@ -211,22 +207,18 @@ fn array_distinct_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
211207 return exec_err ! ( "array_distinct needs one argument" ) ;
212208 }
213209
214- // handle null
215- if args[ 0 ] . data_type ( ) == & Null {
216- return Ok ( args[ 0 ] . clone ( ) ) ;
217- }
218-
219- // handle for list & largelist
220- match args[ 0 ] . data_type ( ) {
210+ let array = & args[ 0 ] ;
211+ match array. data_type ( ) {
212+ Null => Ok ( Arc :: clone ( array) ) ,
221213 List ( field) => {
222- let array = as_list_array ( & args [ 0 ] ) ?;
214+ let array = as_list_array ( array ) ?;
223215 general_array_distinct ( array, field)
224216 }
225217 LargeList ( field) => {
226- let array = as_large_list_array ( & args [ 0 ] ) ?;
218+ let array = as_large_list_array ( array ) ?;
227219 general_array_distinct ( array, field)
228220 }
229- array_type => exec_err ! ( "array_distinct does not support type '{array_type :?}'" ) ,
221+ arg_type => exec_err ! ( "array_distinct does not support type '{arg_type :?}'" ) ,
230222 }
231223}
232224
@@ -251,80 +243,69 @@ fn generic_set_lists<OffsetSize: OffsetSizeTrait>(
251243 field : Arc < Field > ,
252244 set_op : SetOp ,
253245) -> Result < ArrayRef > {
254- if matches ! ( l. value_type( ) , Null ) {
246+ if l . is_empty ( ) || l. value_type ( ) . is_null ( ) {
255247 let field = Arc :: new ( Field :: new ( "item" , r. value_type ( ) , true ) ) ;
256248 return general_array_distinct :: < OffsetSize > ( r, & field) ;
257- } else if matches ! ( r. value_type( ) , Null ) {
249+ } else if r . is_empty ( ) || r. value_type ( ) . is_null ( ) {
258250 let field = Arc :: new ( Field :: new ( "item" , l. value_type ( ) , true ) ) ;
259251 return general_array_distinct :: < OffsetSize > ( l, & field) ;
260252 }
261253
262- // Handle empty array at rhs case
263- // array_union(arr, []) -> arr;
264- // array_intersect(arr, []) -> [];
265- if r. value_length ( 0 ) . is_zero ( ) {
266- if set_op == SetOp :: Union {
267- return Ok ( Arc :: new ( l. clone ( ) ) as ArrayRef ) ;
268- } else {
269- return Ok ( Arc :: new ( r. clone ( ) ) as ArrayRef ) ;
270- }
271- }
272-
273254 if l. value_type ( ) != r. value_type ( ) {
274- return internal_err ! ( "{set_op:?} is not implemented for '{l:?}' and '{r:?}'" ) ;
255+ return internal_err ! (
256+ "{set_op} is not implemented for {} and {}" ,
257+ l. data_type( ) ,
258+ r. data_type( )
259+ ) ;
275260 }
276261
277- let dt = l. value_type ( ) ;
278-
279262 let mut offsets = vec ! [ OffsetSize :: usize_as( 0 ) ] ;
280263 let mut new_arrays = vec ! [ ] ;
281-
282- let converter = RowConverter :: new ( vec ! [ SortField :: new( dt) ] ) ?;
264+ let converter = RowConverter :: new ( vec ! [ SortField :: new( l. value_type( ) ) ] ) ?;
283265 for ( first_arr, second_arr) in l. iter ( ) . zip ( r. iter ( ) ) {
284- if let ( Some ( first_arr) , Some ( second_arr) ) = ( first_arr, second_arr) {
285- let l_values = converter. convert_columns ( & [ first_arr] ) ?;
286- let r_values = converter. convert_columns ( & [ second_arr] ) ?;
287-
288- let l_iter = l_values. iter ( ) . sorted ( ) . dedup ( ) ;
289- let values_set: HashSet < _ > = l_iter. clone ( ) . collect ( ) ;
290- let mut rows = if set_op == SetOp :: Union {
291- l_iter. collect :: < Vec < _ > > ( )
292- } else {
293- vec ! [ ]
294- } ;
295- for r_val in r_values. iter ( ) . sorted ( ) . dedup ( ) {
296- match set_op {
297- SetOp :: Union => {
298- if !values_set. contains ( & r_val) {
299- rows. push ( r_val) ;
300- }
301- }
302- SetOp :: Intersect => {
303- if values_set. contains ( & r_val) {
304- rows. push ( r_val) ;
305- }
306- }
307- }
308- }
266+ let l_values = if let Some ( first_arr) = first_arr {
267+ converter. convert_columns ( & [ first_arr] ) ?
268+ } else {
269+ converter. convert_columns ( & [ ] ) ?
270+ } ;
271+
272+ let r_values = if let Some ( second_arr) = second_arr {
273+ converter. convert_columns ( & [ second_arr] ) ?
274+ } else {
275+ converter. convert_columns ( & [ ] ) ?
276+ } ;
277+
278+ let l_iter = l_values. iter ( ) . sorted ( ) . dedup ( ) ;
279+ let values_set: HashSet < _ > = l_iter. clone ( ) . collect ( ) ;
280+ let mut rows = if set_op == SetOp :: Union {
281+ l_iter. collect ( )
282+ } else {
283+ vec ! [ ]
284+ } ;
309285
310- let last_offset = match offsets. last ( ) . copied ( ) {
311- Some ( offset) => offset,
312- None => return internal_err ! ( "offsets should not be empty" ) ,
313- } ;
314- offsets. push ( last_offset + OffsetSize :: usize_as ( rows. len ( ) ) ) ;
315- let arrays = converter. convert_rows ( rows) ?;
316- let array = match arrays. first ( ) {
317- Some ( array) => array. clone ( ) ,
318- None => {
319- return internal_err ! ( "{set_op}: failed to get array from rows" ) ;
320- }
321- } ;
322- new_arrays. push ( array) ;
286+ for r_val in r_values. iter ( ) . sorted ( ) . dedup ( ) {
287+ match set_op {
288+ SetOp :: Union if !values_set. contains ( & r_val) => rows. push ( r_val) ,
289+ SetOp :: Intersect if values_set. contains ( & r_val) => rows. push ( r_val) ,
290+ _ => ( ) ,
291+ }
323292 }
293+
294+ let last_offset = match offsets. last ( ) {
295+ Some ( offset) => * offset,
296+ None => return internal_err ! ( "offsets should not be empty" ) ,
297+ } ;
298+
299+ offsets. push ( last_offset + OffsetSize :: usize_as ( rows. len ( ) ) ) ;
300+ let arrays = converter. convert_rows ( rows) ?;
301+ new_arrays. push ( match arrays. first ( ) {
302+ Some ( array) => Arc :: clone ( array) ,
303+ None => return internal_err ! ( "{set_op}: failed to get array from rows" ) ,
304+ } ) ;
324305 }
325306
326307 let offsets = OffsetBuffer :: new ( offsets. into ( ) ) ;
327- let new_arrays_ref = new_arrays. iter ( ) . map ( |v| v. as_ref ( ) ) . collect :: < Vec < _ > > ( ) ;
308+ let new_arrays_ref: Vec < _ > = new_arrays. iter ( ) . map ( |v| v. as_ref ( ) ) . collect ( ) ;
328309 let values = compute:: concat ( & new_arrays_ref) ?;
329310 let arr = GenericListArray :: < OffsetSize > :: try_new ( field, offsets, values, None ) ?;
330311 Ok ( Arc :: new ( arr) )
@@ -335,38 +316,60 @@ fn general_set_op(
335316 array2 : & ArrayRef ,
336317 set_op : SetOp ,
337318) -> Result < ArrayRef > {
319+ fn empty_array ( data_type : & DataType , len : usize , large : bool ) -> Result < ArrayRef > {
320+ let field = Arc :: new ( Field :: new_list_field ( data_type. clone ( ) , true ) ) ;
321+ let values = new_null_array ( data_type, len) ;
322+ if large {
323+ Ok ( Arc :: new ( LargeListArray :: try_new (
324+ field,
325+ OffsetBuffer :: new_zeroed ( len) ,
326+ values,
327+ None ,
328+ ) ?) )
329+ } else {
330+ Ok ( Arc :: new ( ListArray :: try_new (
331+ field,
332+ OffsetBuffer :: new_zeroed ( len) ,
333+ values,
334+ None ,
335+ ) ?) )
336+ }
337+ }
338+
338339 match ( array1. data_type ( ) , array2. data_type ( ) ) {
340+ ( Null , Null ) => Ok ( Arc :: new ( ListArray :: new_null (
341+ Arc :: new ( Field :: new_list_field ( Null , true ) ) ,
342+ array1. len ( ) ,
343+ ) ) ) ,
339344 ( Null , List ( field) ) => {
340345 if set_op == SetOp :: Intersect {
341- return Ok ( new_empty_array ( & Null ) ) ;
346+ return empty_array ( field . data_type ( ) , array1 . len ( ) , false ) ;
342347 }
343348 let array = as_list_array ( & array2) ?;
344349 general_array_distinct :: < i32 > ( array, field)
345350 }
346351
347352 ( List ( field) , Null ) => {
348353 if set_op == SetOp :: Intersect {
349- return make_array_inner ( & [ ] ) ;
354+ return empty_array ( field . data_type ( ) , array1 . len ( ) , false ) ;
350355 }
351356 let array = as_list_array ( & array1) ?;
352357 general_array_distinct :: < i32 > ( array, field)
353358 }
354359 ( Null , LargeList ( field) ) => {
355360 if set_op == SetOp :: Intersect {
356- return Ok ( new_empty_array ( & Null ) ) ;
361+ return empty_array ( field . data_type ( ) , array1 . len ( ) , true ) ;
357362 }
358363 let array = as_large_list_array ( & array2) ?;
359364 general_array_distinct :: < i64 > ( array, field)
360365 }
361366 ( LargeList ( field) , Null ) => {
362367 if set_op == SetOp :: Intersect {
363- return make_array_inner ( & [ ] ) ;
368+ return empty_array ( field . data_type ( ) , array1 . len ( ) , true ) ;
364369 }
365370 let array = as_large_list_array ( & array1) ?;
366371 general_array_distinct :: < i64 > ( array, field)
367372 }
368- ( Null , Null ) => Ok ( new_empty_array ( & Null ) ) ,
369-
370373 ( List ( field) , List ( _) ) => {
371374 let array1 = as_list_array ( & array1) ?;
372375 let array2 = as_list_array ( & array2) ?;
0 commit comments