From 9ed1fcad9d0d338bab8bf7e0b29e656552df7abc Mon Sep 17 00:00:00 2001 From: veeupup Date: Sun, 19 Nov 2023 23:03:46 +0800 Subject: [PATCH] make array_intersect handle empty/null arrays rightly Signed-off-by: veeupup --- datafusion/expr/src/built_in_function.rs | 11 +- .../physical-expr/src/array_expressions.rs | 112 +++++++++++------- datafusion/sqllogictest/test_files/array.slt | 20 ++++ 3 files changed, 98 insertions(+), 45 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 7269d4d7119a0..62b43b343ef96 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -599,7 +599,16 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayReplaceAll => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArraySlice => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayToString => Ok(Utf8), - BuiltinScalarFunction::ArrayIntersect => Ok(input_expr_types[0].clone()), + BuiltinScalarFunction::ArrayIntersect => { + match (input_expr_types[0].clone(), input_expr_types[1].clone()) { + (DataType::Null, DataType::Null) => Ok(DataType::List(Arc::new( + Field::new("item", DataType::Null, true), + ))), + (DataType::Null, dt) => Ok(dt), + (dt, DataType::Null) => Ok(dt), + (dt, _) => Ok(dt), + } + } BuiltinScalarFunction::ArrayUnion => { match (input_expr_types[0].clone(), input_expr_types[1].clone()) { (DataType::Null, DataType::Null) => Ok(DataType::List(Arc::new( diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 439ad37bbf5cf..97f1548888884 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -582,7 +582,7 @@ pub fn array_except(args: &[ArrayRef]) -> Result<ArrayRef> { match (array1.data_type(), array2.data_type()) { (DataType::Null, DataType::Null) => { // NullArray(1): means null, NullArray(0): means [] - // except([], null) = 
[], except(null, []) = null, except(null, null) = null + // except([], []) = [], except([], null) = [], except(null, []) = null, except(null, null) = null let nulls = match (array1.len(), array2.len()) { (1, _) => Some(NullBuffer::new_null(1)), _ => None, @@ -1527,7 +1527,7 @@ pub fn array_union(args: &[ArrayRef]) -> Result<ArrayRef> { match (array1.data_type(), array2.data_type()) { (DataType::Null, DataType::Null) => { // NullArray(1): means null, NullArray(0): means [] - // union([], null) = [], union(null, []) = [], union(null, null) = null + // union([], []) = [], union([], null) = [], union(null, []) = [], union(null, null) = null let nulls = match (array1.len(), array2.len()) { (1, 1) => Some(NullBuffer::new_null(1)), _ => None, @@ -2028,55 +2028,79 @@ pub fn string_to_array(args: &[ArrayRef]) -> Result pub fn array_intersect(args: &[ArrayRef]) -> Result<ArrayRef> { assert_eq!(args.len(), 2); - let first_array = as_list_array(&args[0])?; - let second_array = as_list_array(&args[1])?; + let first_array = &args[0]; + let second_array = &args[1]; - if first_array.value_type() != second_array.value_type() { - return internal_err!("array_intersect is not implemented for '{first_array:?}' and '{second_array:?}'"); - } - let dt = first_array.value_type(); + match (first_array.data_type(), second_array.data_type()) { + (DataType::Null, DataType::Null) => { + // NullArray(1): means null, NullArray(0): means [] + // intersect([], []) = [], intersect([], null) = [], intersect(null, []) = [], intersect(null, null) = null + let nulls = match (first_array.len(), second_array.len()) { + (1, 1) => Some(NullBuffer::new_null(1)), + _ => None, + }; + let arr = Arc::new(ListArray::try_new( + Arc::new(Field::new("item", DataType::Null, true)), + OffsetBuffer::new(vec![0; 2].into()), + Arc::new(NullArray::new(0)), + nulls, + )?) 
as ArrayRef; + Ok(arr) + } + _ => { + let first_array = as_list_array(&first_array)?; + let second_array = as_list_array(&second_array)?; - let mut offsets = vec![0]; - let mut new_arrays = vec![]; - - let converter = RowConverter::new(vec![SortField::new(dt.clone())])?; - for (first_arr, second_arr) in first_array.iter().zip(second_array.iter()) { - if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) { - let l_values = converter.convert_columns(&[first_arr])?; - let r_values = converter.convert_columns(&[second_arr])?; - - let values_set: HashSet<_> = l_values.iter().collect(); - let mut rows = Vec::with_capacity(r_values.num_rows()); - for r_val in r_values.iter().sorted().dedup() { - if values_set.contains(&r_val) { - rows.push(r_val); - } + if first_array.value_type() != second_array.value_type() { + return internal_err!("array_intersect is not implemented for '{first_array:?}' and '{second_array:?}'"); } - let last_offset: i32 = match offsets.last().copied() { - Some(offset) => offset, - None => return internal_err!("offsets should not be empty"), - }; - offsets.push(last_offset + rows.len() as i32); - let arrays = converter.convert_rows(rows)?; - let array = match arrays.get(0) { - Some(array) => array.clone(), - None => { - return internal_err!( - "array_intersect: failed to get array from rows" - ) + let dt = first_array.value_type(); + + let mut offsets = vec![0]; + let mut new_arrays = vec![]; + + let converter = RowConverter::new(vec![SortField::new(dt.clone())])?; + for (first_arr, second_arr) in first_array.iter().zip(second_array.iter()) { + if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) { + let l_values = converter.convert_columns(&[first_arr])?; + let r_values = converter.convert_columns(&[second_arr])?; + + let values_set: HashSet<_> = l_values.iter().collect(); + let mut rows = Vec::with_capacity(r_values.num_rows()); + for r_val in r_values.iter().sorted().dedup() { + if values_set.contains(&r_val) { + 
rows.push(r_val); + } + } + + let last_offset: i32 = match offsets.last().copied() { + Some(offset) => offset, + None => return internal_err!("offsets should not be empty"), + }; + offsets.push(last_offset + rows.len() as i32); + let arrays = converter.convert_rows(rows)?; + let array = match arrays.get(0) { + Some(array) => array.clone(), + None => { + return internal_err!( + "array_intersect: failed to get array from rows" + ) + } + }; + new_arrays.push(array); } - }; - new_arrays.push(array); + } + + let field = Arc::new(Field::new("item", dt, true)); + let offsets = OffsetBuffer::new(offsets.into()); + let new_arrays_ref = + new_arrays.iter().map(|v| v.as_ref()).collect::<Vec<_>>(); + let values = compute::concat(&new_arrays_ref)?; + let arr = Arc::new(ListArray::try_new(field, offsets, values, None)?); + Ok(arr) } } - - let field = Arc::new(Field::new("item", dt, true)); - let offsets = OffsetBuffer::new(offsets.into()); - let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::<Vec<_>>(); - let values = compute::concat(&new_arrays_ref)?; - let arr = Arc::new(ListArray::try_new(field, offsets, values, None)?); - Ok(arr) } #[cfg(test)] diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index db6d47056dfb4..6100a895c2312 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -2687,6 +2687,26 @@ SELECT array_intersect(make_array(1,2,3), make_array(2,3,4)), ---- [2, 3] [] [aa, cc] [true] [2.2, 3.3] [[2, 2], [3, 3]] +query ? +select array_intersect([], []); +---- +[] + +query ? +select array_intersect([], null); +---- +[] + +query ? +select array_intersect(null, []); +---- +[] + +query ? +select array_intersect(null, null); +---- +NULL + query ?????? SELECT list_intersect(make_array(1,2,3), make_array(2,3,4)), list_intersect(make_array(1,3,5), make_array(2,4,6)),