1717
1818//! [`ScalarUDFImpl`] definitions for array_union, array_intersect and array_distinct functions.
1919
20- use crate :: make_array:: { empty_array_type, make_array_inner} ;
2120use crate :: utils:: make_scalar_function;
22- use arrow:: array:: { new_empty_array , Array , ArrayRef , GenericListArray , OffsetSizeTrait } ;
21+ use arrow:: array:: { Array , ArrayRef , GenericListArray , OffsetSizeTrait } ;
2322use arrow:: buffer:: OffsetBuffer ;
2423use arrow:: compute;
2524use arrow:: datatypes:: { DataType , Field , FieldRef } ;
2625use arrow:: row:: { RowConverter , SortField } ;
26+ use arrow_array:: { new_null_array, LargeListArray , ListArray } ;
2727use arrow_schema:: DataType :: { FixedSizeList , LargeList , List , Null } ;
2828use datafusion_common:: cast:: { as_large_list_array, as_list_array} ;
29- use datafusion_common:: { exec_err, internal_err, Result } ;
29+ use datafusion_common:: { exec_err, internal_err, plan_err , Result } ;
3030use datafusion_expr:: scalar_doc_sections:: DOC_SECTION_ARRAY ;
3131use datafusion_expr:: {
3232 ColumnarValue , Documentation , ScalarUDFImpl , Signature , Volatility ,
@@ -92,7 +92,8 @@ impl ScalarUDFImpl for ArrayUnion {
9292
9393 fn return_type ( & self , arg_types : & [ DataType ] ) -> Result < DataType > {
9494 match ( & arg_types[ 0 ] , & arg_types[ 1 ] ) {
95- ( & Null , dt) => Ok ( dt. clone ( ) ) ,
95+ ( Null , Null ) => Ok ( DataType :: new_list ( Null , true ) ) ,
96+ ( Null , dt) => Ok ( dt. clone ( ) ) ,
9697 ( dt, Null ) => Ok ( dt. clone ( ) ) ,
9798 ( dt, _) => Ok ( dt. clone ( ) ) ,
9899 }
@@ -182,9 +183,10 @@ impl ScalarUDFImpl for ArrayIntersect {
182183
183184 fn return_type ( & self , arg_types : & [ DataType ] ) -> Result < DataType > {
184185 match ( arg_types[ 0 ] . clone ( ) , arg_types[ 1 ] . clone ( ) ) {
185- ( Null , Null ) | ( Null , _) => Ok ( Null ) ,
186- ( _, Null ) => Ok ( empty_array_type ( ) ) ,
187- ( dt, _) => Ok ( dt) ,
186+ ( Null , Null ) => Ok ( DataType :: new_list ( Null , true ) ) ,
187+ ( Null , dt) => Ok ( dt. clone ( ) ) ,
188+ ( dt, Null ) => Ok ( dt. clone ( ) ) ,
189+ ( dt, _) => Ok ( dt. clone ( ) ) ,
188190 }
189191 }
190192
@@ -332,22 +334,18 @@ fn array_distinct_inner(args: &[ArrayRef]) -> Result<ArrayRef> {
332334 return exec_err ! ( "array_distinct needs one argument" ) ;
333335 }
334336
335- // handle null
336- if args[ 0 ] . data_type ( ) == & Null {
337- return Ok ( Arc :: clone ( & args[ 0 ] ) ) ;
338- }
339-
340- // handle for list & largelist
341- match args[ 0 ] . data_type ( ) {
337+ let array = & args[ 0 ] ;
338+ match array. data_type ( ) {
339+ Null => Ok ( Arc :: clone ( array) ) ,
342340 List ( field) => {
343- let array = as_list_array ( & args [ 0 ] ) ?;
341+ let array = as_list_array ( array ) ?;
344342 general_array_distinct ( array, field)
345343 }
346344 LargeList ( field) => {
347- let array = as_large_list_array ( & args [ 0 ] ) ?;
345+ let array = as_large_list_array ( array ) ?;
348346 general_array_distinct ( array, field)
349347 }
350- array_type => exec_err ! ( "array_distinct does not support type '{array_type :?}'" ) ,
348+ arg_type => exec_err ! ( "array_distinct does not support type '{arg_type :?}'" ) ,
351349 }
352350}
353351
@@ -372,80 +370,69 @@ fn generic_set_lists<OffsetSize: OffsetSizeTrait>(
372370 field : Arc < Field > ,
373371 set_op : SetOp ,
374372) -> Result < ArrayRef > {
375- if matches ! ( l. value_type( ) , Null ) {
373+ if l . is_empty ( ) || l. value_type ( ) . is_null ( ) {
376374 let field = Arc :: new ( Field :: new_list_field ( r. value_type ( ) , true ) ) ;
377375 return general_array_distinct :: < OffsetSize > ( r, & field) ;
378- } else if matches ! ( r. value_type( ) , Null ) {
376+ } else if r . is_empty ( ) || r. value_type ( ) . is_null ( ) {
379377 let field = Arc :: new ( Field :: new_list_field ( l. value_type ( ) , true ) ) ;
380378 return general_array_distinct :: < OffsetSize > ( l, & field) ;
381379 }
382380
383- // Handle empty array at rhs case
384- // array_union(arr, []) -> arr;
385- // array_intersect(arr, []) -> [];
386- if r. value_length ( 0 ) . is_zero ( ) {
387- if set_op == SetOp :: Union {
388- return Ok ( Arc :: new ( l. clone ( ) ) as ArrayRef ) ;
389- } else {
390- return Ok ( Arc :: new ( r. clone ( ) ) as ArrayRef ) ;
391- }
392- }
393-
394381 if l. value_type ( ) != r. value_type ( ) {
395- return internal_err ! ( "{set_op:?} is not implemented for '{l:?}' and '{r:?}'" ) ;
382+ return internal_err ! (
383+ "{set_op} is not implemented for {} and {}" ,
384+ l. data_type( ) ,
385+ r. data_type( )
386+ ) ;
396387 }
397388
398- let dt = l. value_type ( ) ;
399-
400389 let mut offsets = vec ! [ OffsetSize :: usize_as( 0 ) ] ;
401390 let mut new_arrays = vec ! [ ] ;
402-
403- let converter = RowConverter :: new ( vec ! [ SortField :: new( dt) ] ) ?;
391+ let converter = RowConverter :: new ( vec ! [ SortField :: new( l. value_type( ) ) ] ) ?;
404392 for ( first_arr, second_arr) in l. iter ( ) . zip ( r. iter ( ) ) {
405- if let ( Some ( first_arr) , Some ( second_arr) ) = ( first_arr, second_arr) {
406- let l_values = converter. convert_columns ( & [ first_arr] ) ?;
407- let r_values = converter. convert_columns ( & [ second_arr] ) ?;
408-
409- let l_iter = l_values. iter ( ) . sorted ( ) . dedup ( ) ;
410- let values_set: HashSet < _ > = l_iter. clone ( ) . collect ( ) ;
411- let mut rows = if set_op == SetOp :: Union {
412- l_iter. collect :: < Vec < _ > > ( )
413- } else {
414- vec ! [ ]
415- } ;
416- for r_val in r_values. iter ( ) . sorted ( ) . dedup ( ) {
417- match set_op {
418- SetOp :: Union => {
419- if !values_set. contains ( & r_val) {
420- rows. push ( r_val) ;
421- }
422- }
423- SetOp :: Intersect => {
424- if values_set. contains ( & r_val) {
425- rows. push ( r_val) ;
426- }
427- }
428- }
429- }
393+ let l_values = if let Some ( first_arr) = first_arr {
394+ converter. convert_columns ( & [ first_arr] ) ?
395+ } else {
396+ converter. convert_columns ( & [ ] ) ?
397+ } ;
398+
399+ let r_values = if let Some ( second_arr) = second_arr {
400+ converter. convert_columns ( & [ second_arr] ) ?
401+ } else {
402+ converter. convert_columns ( & [ ] ) ?
403+ } ;
430404
431- let last_offset = match offsets. last ( ) . copied ( ) {
432- Some ( offset) => offset,
433- None => return internal_err ! ( "offsets should not be empty" ) ,
434- } ;
435- offsets. push ( last_offset + OffsetSize :: usize_as ( rows. len ( ) ) ) ;
436- let arrays = converter. convert_rows ( rows) ?;
437- let array = match arrays. first ( ) {
438- Some ( array) => Arc :: clone ( array) ,
439- None => {
440- return internal_err ! ( "{set_op}: failed to get array from rows" ) ;
441- }
442- } ;
443- new_arrays. push ( array) ;
405+ let l_iter = l_values. iter ( ) . sorted ( ) . dedup ( ) ;
406+ let values_set: HashSet < _ > = l_iter. clone ( ) . collect ( ) ;
407+ let mut rows = if set_op == SetOp :: Union {
408+ l_iter. collect ( )
409+ } else {
410+ vec ! [ ]
411+ } ;
412+
413+ for r_val in r_values. iter ( ) . sorted ( ) . dedup ( ) {
414+ match set_op {
415+ SetOp :: Union if !values_set. contains ( & r_val) => rows. push ( r_val) ,
416+ SetOp :: Intersect if values_set. contains ( & r_val) => rows. push ( r_val) ,
417+ _ => ( ) ,
418+ }
444419 }
420+
421+ let last_offset = match offsets. last ( ) {
422+ Some ( offset) => * offset,
423+ None => return internal_err ! ( "offsets should not be empty" ) ,
424+ } ;
425+
426+ offsets. push ( last_offset + OffsetSize :: usize_as ( rows. len ( ) ) ) ;
427+ let arrays = converter. convert_rows ( rows) ?;
428+ new_arrays. push ( match arrays. first ( ) {
429+ Some ( array) => Arc :: clone ( array) ,
430+ None => return internal_err ! ( "{set_op}: failed to get array from rows" ) ,
431+ } ) ;
445432 }
446433
447434 let offsets = OffsetBuffer :: new ( offsets. into ( ) ) ;
448- let new_arrays_ref = new_arrays. iter ( ) . map ( |v| v. as_ref ( ) ) . collect :: < Vec < _ > > ( ) ;
435+ let new_arrays_ref: Vec < _ > = new_arrays. iter ( ) . map ( |v| v. as_ref ( ) ) . collect ( ) ;
449436 let values = compute:: concat ( & new_arrays_ref) ?;
450437 let arr = GenericListArray :: < OffsetSize > :: try_new ( field, offsets, values, None ) ?;
451438 Ok ( Arc :: new ( arr) )
@@ -456,38 +443,60 @@ fn general_set_op(
456443 array2 : & ArrayRef ,
457444 set_op : SetOp ,
458445) -> Result < ArrayRef > {
446+ fn empty_array ( data_type : & DataType , len : usize , large : bool ) -> Result < ArrayRef > {
447+ let field = Arc :: new ( Field :: new_list_field ( data_type. clone ( ) , true ) ) ;
448+ let values = new_null_array ( data_type, len) ;
449+ if large {
450+ Ok ( Arc :: new ( LargeListArray :: try_new (
451+ field,
452+ OffsetBuffer :: new_zeroed ( len) ,
453+ values,
454+ None ,
455+ ) ?) )
456+ } else {
457+ Ok ( Arc :: new ( ListArray :: try_new (
458+ field,
459+ OffsetBuffer :: new_zeroed ( len) ,
460+ values,
461+ None ,
462+ ) ?) )
463+ }
464+ }
465+
459466 match ( array1. data_type ( ) , array2. data_type ( ) ) {
467+ ( Null , Null ) => Ok ( Arc :: new ( ListArray :: new_null (
468+ Arc :: new ( Field :: new_list_field ( Null , true ) ) ,
469+ array1. len ( ) ,
470+ ) ) ) ,
460471 ( Null , List ( field) ) => {
461472 if set_op == SetOp :: Intersect {
462- return Ok ( new_empty_array ( & Null ) ) ;
473+ return empty_array ( field . data_type ( ) , array1 . len ( ) , false ) ;
463474 }
464475 let array = as_list_array ( & array2) ?;
465476 general_array_distinct :: < i32 > ( array, field)
466477 }
467478
468479 ( List ( field) , Null ) => {
469480 if set_op == SetOp :: Intersect {
470- return make_array_inner ( & [ ] ) ;
481+ return empty_array ( field . data_type ( ) , array1 . len ( ) , false ) ;
471482 }
472483 let array = as_list_array ( & array1) ?;
473484 general_array_distinct :: < i32 > ( array, field)
474485 }
475486 ( Null , LargeList ( field) ) => {
476487 if set_op == SetOp :: Intersect {
477- return Ok ( new_empty_array ( & Null ) ) ;
488+ return empty_array ( field . data_type ( ) , array1 . len ( ) , true ) ;
478489 }
479490 let array = as_large_list_array ( & array2) ?;
480491 general_array_distinct :: < i64 > ( array, field)
481492 }
482493 ( LargeList ( field) , Null ) => {
483494 if set_op == SetOp :: Intersect {
484- return make_array_inner ( & [ ] ) ;
495+ return empty_array ( field . data_type ( ) , array1 . len ( ) , true ) ;
485496 }
486497 let array = as_large_list_array ( & array1) ?;
487498 general_array_distinct :: < i64 > ( array, field)
488499 }
489- ( Null , Null ) => Ok ( new_empty_array ( & Null ) ) ,
490-
491500 ( List ( field) , List ( _) ) => {
492501 let array1 = as_list_array ( & array1) ?;
493502 let array2 = as_list_array ( & array2) ?;
0 commit comments