Skip to content

Commit f5d10e5

Browse files
jayzhan211 and alamb authored
Rewrite array_ndims to fix List(Null) handling (#8320)
* done
  Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
* add more test
  Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
* cleanup
  Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
---------
Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
1 parent eb8aff7 commit f5d10e5

File tree

3 files changed

+97
-53
lines changed

3 files changed

+97
-53
lines changed

datafusion/common/src/utils.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ use arrow::compute::{partition, SortColumn, SortOptions};
2626
use arrow::datatypes::{Field, SchemaRef, UInt32Type};
2727
use arrow::record_batch::RecordBatch;
2828
use arrow_array::{Array, LargeListArray, ListArray};
29+
use arrow_schema::DataType;
2930
use sqlparser::ast::Ident;
3031
use sqlparser::dialect::GenericDialect;
3132
use sqlparser::parser::Parser;
@@ -402,6 +403,37 @@ pub fn arrays_into_list_array(
402403
))
403404
}
404405

406+
/// Get the base type of a data type.
407+
///
408+
/// Example
409+
/// ```
410+
/// use arrow::datatypes::{DataType, Field};
411+
/// use datafusion_common::utils::base_type;
412+
/// use std::sync::Arc;
413+
///
414+
/// let data_type = DataType::List(Arc::new(Field::new("item", DataType::Int32, true)));
415+
/// assert_eq!(base_type(&data_type), DataType::Int32);
416+
///
417+
/// let data_type = DataType::Int32;
418+
/// assert_eq!(base_type(&data_type), DataType::Int32);
419+
/// ```
420+
pub fn base_type(data_type: &DataType) -> DataType {
421+
if let DataType::List(field) = data_type {
422+
base_type(field.data_type())
423+
} else {
424+
data_type.to_owned()
425+
}
426+
}
427+
428+
/// Compute the number of dimensions in a list data type.
429+
pub fn list_ndims(data_type: &DataType) -> u64 {
430+
if let DataType::List(field) = data_type {
431+
1 + list_ndims(field.data_type())
432+
} else {
433+
0
434+
}
435+
}
436+
405437
/// An extension trait for smart pointers. Provides an interface to get a
406438
/// raw pointer to the data (with metadata stripped away).
407439
///

datafusion/physical-expr/src/array_expressions.rs

Lines changed: 27 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ use datafusion_common::cast::{
3333
as_generic_list_array, as_generic_string_array, as_int64_array, as_list_array,
3434
as_null_array, as_string_array,
3535
};
36-
use datafusion_common::utils::array_into_list_array;
36+
use datafusion_common::utils::{array_into_list_array, list_ndims};
3737
use datafusion_common::{
3838
exec_err, internal_datafusion_err, internal_err, not_impl_err, plan_err,
3939
DataFusionError, Result,
@@ -103,6 +103,7 @@ fn compare_element_to_list(
103103
) -> Result<BooleanArray> {
104104
let indices = UInt32Array::from(vec![row_index as u32]);
105105
let element_array_row = arrow::compute::take(element_array, &indices, None)?;
106+
106107
// Compute all positions in list_row_array (that is itself an
107108
// array) that are equal to `from_array_row`
108109
let res = match element_array_row.data_type() {
@@ -176,35 +177,6 @@ fn compute_array_length(
176177
}
177178
}
178179

179-
/// Returns the dimension of the array
180-
fn compute_array_ndims(arr: Option<ArrayRef>) -> Result<Option<u64>> {
181-
Ok(compute_array_ndims_with_datatype(arr)?.0)
182-
}
183-
184-
/// Returns the dimension and the datatype of elements of the array
185-
fn compute_array_ndims_with_datatype(
186-
arr: Option<ArrayRef>,
187-
) -> Result<(Option<u64>, DataType)> {
188-
let mut res: u64 = 1;
189-
let mut value = match arr {
190-
Some(arr) => arr,
191-
None => return Ok((None, DataType::Null)),
192-
};
193-
if value.is_empty() {
194-
return Ok((None, DataType::Null));
195-
}
196-
197-
loop {
198-
match value.data_type() {
199-
DataType::List(..) => {
200-
value = downcast_arg!(value, ListArray).value(0);
201-
res += 1;
202-
}
203-
data_type => return Ok((Some(res), data_type.clone())),
204-
}
205-
}
206-
}
207-
208180
/// Returns the length of each array dimension
209181
fn compute_array_dims(arr: Option<ArrayRef>) -> Result<Option<Vec<Option<u64>>>> {
210182
let mut value = match arr {
@@ -825,10 +797,7 @@ pub fn array_prepend(args: &[ArrayRef]) -> Result<ArrayRef> {
825797
fn align_array_dimensions(args: Vec<ArrayRef>) -> Result<Vec<ArrayRef>> {
826798
let args_ndim = args
827799
.iter()
828-
.map(|arg| compute_array_ndims(Some(arg.to_owned())))
829-
.collect::<Result<Vec<_>>>()?
830-
.into_iter()
831-
.map(|x| x.unwrap_or(0))
800+
.map(|arg| datafusion_common::utils::list_ndims(arg.data_type()))
832801
.collect::<Vec<_>>();
833802
let max_ndim = args_ndim.iter().max().unwrap_or(&0);
834803

@@ -919,18 +888,19 @@ fn concat_internal(args: &[ArrayRef]) -> Result<ArrayRef> {
919888
Arc::new(compute::concat(elements.as_slice())?),
920889
Some(NullBuffer::new(buffer)),
921890
);
891+
922892
Ok(Arc::new(list_arr))
923893
}
924894

925895
/// Array_concat/Array_cat SQL function
926896
pub fn array_concat(args: &[ArrayRef]) -> Result<ArrayRef> {
927897
let mut new_args = vec![];
928898
for arg in args {
929-
let (ndim, lower_data_type) =
930-
compute_array_ndims_with_datatype(Some(arg.clone()))?;
931-
if ndim.is_none() || ndim == Some(1) {
932-
return not_impl_err!("Array is not type '{lower_data_type:?}'.");
933-
} else if !lower_data_type.equals_datatype(&DataType::Null) {
899+
let ndim = list_ndims(arg.data_type());
900+
let base_type = datafusion_common::utils::base_type(arg.data_type());
901+
if ndim == 0 {
902+
return not_impl_err!("Array is not type '{base_type:?}'.");
903+
} else if !base_type.eq(&DataType::Null) {
934904
new_args.push(arg.clone());
935905
}
936906
}
@@ -1765,14 +1735,22 @@ pub fn array_dims(args: &[ArrayRef]) -> Result<ArrayRef> {
17651735

17661736
/// Array_ndims SQL function
17671737
pub fn array_ndims(args: &[ArrayRef]) -> Result<ArrayRef> {
1768-
let list_array = as_list_array(&args[0])?;
1738+
if let Some(list_array) = args[0].as_list_opt::<i32>() {
1739+
let ndims = datafusion_common::utils::list_ndims(list_array.data_type());
17691740

1770-
let result = list_array
1771-
.iter()
1772-
.map(compute_array_ndims)
1773-
.collect::<Result<UInt64Array>>()?;
1741+
let mut data = vec![];
1742+
for arr in list_array.iter() {
1743+
if arr.is_some() {
1744+
data.push(Some(ndims))
1745+
} else {
1746+
data.push(None)
1747+
}
1748+
}
17741749

1775-
Ok(Arc::new(result) as ArrayRef)
1750+
Ok(Arc::new(UInt64Array::from(data)) as ArrayRef)
1751+
} else {
1752+
Ok(Arc::new(UInt64Array::from(vec![0; args[0].len()])) as ArrayRef)
1753+
}
17761754
}
17771755

17781756
/// Array_has SQL function
@@ -2034,10 +2012,10 @@ mod tests {
20342012
.unwrap();
20352013

20362014
let expected = as_list_array(&array2d_1).unwrap();
2037-
let expected_dim = compute_array_ndims(Some(array2d_1.to_owned())).unwrap();
2015+
let expected_dim = datafusion_common::utils::list_ndims(array2d_1.data_type());
20382016
assert_ne!(as_list_array(&res[0]).unwrap(), expected);
20392017
assert_eq!(
2040-
compute_array_ndims(Some(res[0].clone())).unwrap(),
2018+
datafusion_common::utils::list_ndims(res[0].data_type()),
20412019
expected_dim
20422020
);
20432021

@@ -2047,10 +2025,10 @@ mod tests {
20472025
align_array_dimensions(vec![array1d_1, Arc::new(array3d_2.clone())]).unwrap();
20482026

20492027
let expected = as_list_array(&array3d_1).unwrap();
2050-
let expected_dim = compute_array_ndims(Some(array3d_1.to_owned())).unwrap();
2028+
let expected_dim = datafusion_common::utils::list_ndims(array3d_1.data_type());
20512029
assert_ne!(as_list_array(&res[0]).unwrap(), expected);
20522030
assert_eq!(
2053-
compute_array_ndims(Some(res[0].clone())).unwrap(),
2031+
datafusion_common::utils::list_ndims(res[0].data_type()),
20542032
expected_dim
20552033
);
20562034
}

datafusion/sqllogictest/test_files/array.slt

Lines changed: 38 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2479,10 +2479,44 @@ NULL [3] [4]
24792479
## array_ndims (aliases: `list_ndims`)
24802480

24812481
# array_ndims scalar function #1
2482+
24822483
query III
2483-
select array_ndims(make_array(1, 2, 3)), array_ndims(make_array([1, 2], [3, 4])), array_ndims(make_array([[[[1], [2]]]]));
2484+
select
2485+
array_ndims(1),
2486+
array_ndims(null),
2487+
array_ndims([2, 3]);
24842488
----
2485-
1 2 5
2489+
0 0 1
2490+
2491+
statement ok
2492+
CREATE TABLE array_ndims_table
2493+
AS VALUES
2494+
(1, [1, 2, 3], [[7]], [[[[[10]]]]]),
2495+
(2, [4, 5], [[8]], [[[[[10]]]]]),
2496+
(null, [6], [[9]], [[[[[10]]]]]),
2497+
(3, [6], [[9]], [[[[[10]]]]])
2498+
;
2499+
2500+
query IIII
2501+
select
2502+
array_ndims(column1),
2503+
array_ndims(column2),
2504+
array_ndims(column3),
2505+
array_ndims(column4)
2506+
from array_ndims_table;
2507+
----
2508+
0 1 2 5
2509+
0 1 2 5
2510+
0 1 2 5
2511+
0 1 2 5
2512+
2513+
statement ok
2514+
drop table array_ndims_table;
2515+
2516+
query I
2517+
select array_ndims(arrow_cast([null], 'List(List(List(Int64)))'));
2518+
----
2519+
3
24862520

24872521
# array_ndims scalar function #2
24882522
query II
@@ -2494,7 +2528,7 @@ select array_ndims(array_repeat(array_repeat(array_repeat(1, 3), 2), 1)), array_
24942528
query II
24952529
select array_ndims(make_array()), array_ndims(make_array(make_array()))
24962530
----
2497-
NULL 2
2531+
1 2
24982532

24992533
# list_ndims scalar function #4 (function alias `array_ndims`)
25002534
query III
@@ -2505,7 +2539,7 @@ select list_ndims(make_array(1, 2, 3)), list_ndims(make_array([1, 2], [3, 4])),
25052539
query II
25062540
select array_ndims(make_array()), array_ndims(make_array(make_array()))
25072541
----
2508-
NULL 2
2542+
1 2
25092543

25102544
# array_ndims with columns
25112545
query III

0 commit comments

Comments
 (0)