@@ -33,7 +33,7 @@ use datafusion_common::cast::{
3333 as_generic_list_array, as_generic_string_array, as_int64_array, as_list_array,
3434 as_null_array, as_string_array,
3535} ;
36- use datafusion_common:: utils:: array_into_list_array;
36+ use datafusion_common:: utils:: { array_into_list_array, list_ndims } ;
3737use datafusion_common:: {
3838 exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err,
3939 DataFusionError , Result ,
@@ -103,6 +103,7 @@ fn compare_element_to_list(
103103) -> Result < BooleanArray > {
104104 let indices = UInt32Array :: from ( vec ! [ row_index as u32 ] ) ;
105105 let element_array_row = arrow:: compute:: take ( element_array, & indices, None ) ?;
106+
106107 // Compute all positions in list_row_array (that is itself an
107108 // array) that are equal to `from_array_row`
108109 let res = match element_array_row. data_type ( ) {
@@ -176,35 +177,6 @@ fn compute_array_length(
176177 }
177178}
178179
179- /// Returns the dimension of the array
180- fn compute_array_ndims ( arr : Option < ArrayRef > ) -> Result < Option < u64 > > {
181- Ok ( compute_array_ndims_with_datatype ( arr) ?. 0 )
182- }
183-
184- /// Returns the dimension and the datatype of elements of the array
185- fn compute_array_ndims_with_datatype (
186- arr : Option < ArrayRef > ,
187- ) -> Result < ( Option < u64 > , DataType ) > {
188- let mut res: u64 = 1 ;
189- let mut value = match arr {
190- Some ( arr) => arr,
191- None => return Ok ( ( None , DataType :: Null ) ) ,
192- } ;
193- if value. is_empty ( ) {
194- return Ok ( ( None , DataType :: Null ) ) ;
195- }
196-
197- loop {
198- match value. data_type ( ) {
199- DataType :: List ( ..) => {
200- value = downcast_arg ! ( value, ListArray ) . value ( 0 ) ;
201- res += 1 ;
202- }
203- data_type => return Ok ( ( Some ( res) , data_type. clone ( ) ) ) ,
204- }
205- }
206- }
207-
208180/// Returns the length of each array dimension
209181fn compute_array_dims ( arr : Option < ArrayRef > ) -> Result < Option < Vec < Option < u64 > > > > {
210182 let mut value = match arr {
@@ -825,10 +797,7 @@ pub fn array_prepend(args: &[ArrayRef]) -> Result<ArrayRef> {
825797fn align_array_dimensions ( args : Vec < ArrayRef > ) -> Result < Vec < ArrayRef > > {
826798 let args_ndim = args
827799 . iter ( )
828- . map ( |arg| compute_array_ndims ( Some ( arg. to_owned ( ) ) ) )
829- . collect :: < Result < Vec < _ > > > ( ) ?
830- . into_iter ( )
831- . map ( |x| x. unwrap_or ( 0 ) )
800+ . map ( |arg| datafusion_common:: utils:: list_ndims ( arg. data_type ( ) ) )
832801 . collect :: < Vec < _ > > ( ) ;
833802 let max_ndim = args_ndim. iter ( ) . max ( ) . unwrap_or ( & 0 ) ;
834803
@@ -919,18 +888,19 @@ fn concat_internal(args: &[ArrayRef]) -> Result<ArrayRef> {
919888 Arc :: new ( compute:: concat ( elements. as_slice ( ) ) ?) ,
920889 Some ( NullBuffer :: new ( buffer) ) ,
921890 ) ;
891+
922892 Ok ( Arc :: new ( list_arr) )
923893}
924894
925895/// Array_concat/Array_cat SQL function
926896pub fn array_concat ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
927897 let mut new_args = vec ! [ ] ;
928898 for arg in args {
929- let ( ndim, lower_data_type ) =
930- compute_array_ndims_with_datatype ( Some ( arg. clone ( ) ) ) ? ;
931- if ndim. is_none ( ) || ndim == Some ( 1 ) {
932- return not_impl_err ! ( "Array is not type '{lower_data_type :?}'." ) ;
933- } else if !lower_data_type . equals_datatype ( & DataType :: Null ) {
899+ let ndim = list_ndims ( arg . data_type ( ) ) ;
900+ let base_type = datafusion_common :: utils :: base_type ( arg. data_type ( ) ) ;
901+ if ndim == 0 {
902+ return not_impl_err ! ( "Array is not type '{base_type :?}'." ) ;
903+ } else if !base_type . eq ( & DataType :: Null ) {
934904 new_args. push ( arg. clone ( ) ) ;
935905 }
936906 }
@@ -1765,14 +1735,22 @@ pub fn array_dims(args: &[ArrayRef]) -> Result<ArrayRef> {
17651735
17661736/// Array_ndims SQL function
17671737pub fn array_ndims ( args : & [ ArrayRef ] ) -> Result < ArrayRef > {
1768- let list_array = as_list_array ( & args[ 0 ] ) ?;
1738+ if let Some ( list_array) = args[ 0 ] . as_list_opt :: < i32 > ( ) {
1739+ let ndims = datafusion_common:: utils:: list_ndims ( list_array. data_type ( ) ) ;
17691740
1770- let result = list_array
1771- . iter ( )
1772- . map ( compute_array_ndims)
1773- . collect :: < Result < UInt64Array > > ( ) ?;
1741+ let mut data = vec ! [ ] ;
1742+ for arr in list_array. iter ( ) {
1743+ if arr. is_some ( ) {
1744+ data. push ( Some ( ndims) )
1745+ } else {
1746+ data. push ( None )
1747+ }
1748+ }
17741749
1775- Ok ( Arc :: new ( result) as ArrayRef )
1750+ Ok ( Arc :: new ( UInt64Array :: from ( data) ) as ArrayRef )
1751+ } else {
1752+ Ok ( Arc :: new ( UInt64Array :: from ( vec ! [ 0 ; args[ 0 ] . len( ) ] ) ) as ArrayRef )
1753+ }
17761754}
17771755
17781756/// Array_has SQL function
@@ -2034,10 +2012,10 @@ mod tests {
20342012 . unwrap ( ) ;
20352013
20362014 let expected = as_list_array ( & array2d_1) . unwrap ( ) ;
2037- let expected_dim = compute_array_ndims ( Some ( array2d_1. to_owned ( ) ) ) . unwrap ( ) ;
2015+ let expected_dim = datafusion_common :: utils :: list_ndims ( array2d_1. data_type ( ) ) ;
20382016 assert_ne ! ( as_list_array( & res[ 0 ] ) . unwrap( ) , expected) ;
20392017 assert_eq ! (
2040- compute_array_ndims ( Some ( res[ 0 ] . clone ( ) ) ) . unwrap ( ) ,
2018+ datafusion_common :: utils :: list_ndims ( res[ 0 ] . data_type ( ) ) ,
20412019 expected_dim
20422020 ) ;
20432021
@@ -2047,10 +2025,10 @@ mod tests {
20472025 align_array_dimensions ( vec ! [ array1d_1, Arc :: new( array3d_2. clone( ) ) ] ) . unwrap ( ) ;
20482026
20492027 let expected = as_list_array ( & array3d_1) . unwrap ( ) ;
2050- let expected_dim = compute_array_ndims ( Some ( array3d_1. to_owned ( ) ) ) . unwrap ( ) ;
2028+ let expected_dim = datafusion_common :: utils :: list_ndims ( array3d_1. data_type ( ) ) ;
20512029 assert_ne ! ( as_list_array( & res[ 0 ] ) . unwrap( ) , expected) ;
20522030 assert_eq ! (
2053- compute_array_ndims ( Some ( res[ 0 ] . clone ( ) ) ) . unwrap ( ) ,
2031+ datafusion_common :: utils :: list_ndims ( res[ 0 ] . data_type ( ) ) ,
20542032 expected_dim
20552033 ) ;
20562034 }
0 commit comments