From 8966dc005656e329dcf0cf6a47768240a4257c5e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=AD=E5=B7=8D?= Date: Sat, 11 Nov 2023 10:16:20 +0800 Subject: [PATCH] Implementation of `array_intersect` (#8081) * Initial Implementation of array_intersect Signed-off-by: veeupup * fix comments Signed-off-by: veeupup x --------- Signed-off-by: veeupup --- datafusion/expr/src/built_in_function.rs | 6 + datafusion/expr/src/expr_fn.rs | 6 + .../physical-expr/src/array_expressions.rs | 69 +++++++- datafusion/physical-expr/src/functions.rs | 3 + datafusion/proto/proto/datafusion.proto | 1 + datafusion/proto/src/generated/pbjson.rs | 3 + datafusion/proto/src/generated/prost.rs | 3 + .../proto/src/logical_plan/from_proto.rs | 1 + datafusion/proto/src/logical_plan/to_proto.rs | 1 + datafusion/sqllogictest/test_files/array.slt | 150 ++++++++++++++++++ docs/source/user-guide/expressions.md | 1 + 11 files changed, 238 insertions(+), 6 deletions(-) diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index f3f52e9dafb6..ca3ca18e4d77 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -174,6 +174,8 @@ pub enum BuiltinScalarFunction { ArraySlice, /// array_to_string ArrayToString, + /// array_intersect + ArrayIntersect, /// cardinality Cardinality, /// construct an array from columns @@ -398,6 +400,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::Flatten => Volatility::Immutable, BuiltinScalarFunction::ArraySlice => Volatility::Immutable, BuiltinScalarFunction::ArrayToString => Volatility::Immutable, + BuiltinScalarFunction::ArrayIntersect => Volatility::Immutable, BuiltinScalarFunction::Cardinality => Volatility::Immutable, BuiltinScalarFunction::MakeArray => Volatility::Immutable, BuiltinScalarFunction::Ascii => Volatility::Immutable, @@ -577,6 +580,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayReplaceAll => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArraySlice => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::ArrayToString => Ok(Utf8), + BuiltinScalarFunction::ArrayIntersect => Ok(input_expr_types[0].clone()), BuiltinScalarFunction::Cardinality => Ok(UInt64), BuiltinScalarFunction::MakeArray => match input_expr_types.len() { 0 => Ok(List(Arc::new(Field::new("item", Null, true)))), @@ -880,6 +884,7 @@ impl BuiltinScalarFunction { BuiltinScalarFunction::ArrayToString => { Signature::variadic_any(self.volatility()) } + BuiltinScalarFunction::ArrayIntersect => Signature::any(2, self.volatility()), BuiltinScalarFunction::Cardinality => Signature::any(1, self.volatility()), BuiltinScalarFunction::MakeArray => { // 0 or more arguments of arbitrary type @@ -1505,6 +1510,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] { ], BuiltinScalarFunction::Cardinality => &["cardinality"], BuiltinScalarFunction::MakeArray => &["make_array", "make_list"], + BuiltinScalarFunction::ArrayIntersect => &["array_intersect", "list_intersect"], // struct functions BuiltinScalarFunction::Struct => &["struct"], diff --git a/datafusion/expr/src/expr_fn.rs b/datafusion/expr/src/expr_fn.rs index 5b1050020755..98cacc039228 100644 --- a/datafusion/expr/src/expr_fn.rs +++ b/datafusion/expr/src/expr_fn.rs @@ -715,6 +715,12 @@ nary_scalar_expr!( array, "returns an Arrow array using the specified input expressions." ); +scalar_expr!( + ArrayIntersect, + array_intersect, + first_array second_array, + "Returns an array of the elements in the intersection of array1 and array2." +); // string functions scalar_expr!(Ascii, ascii, chr, "ASCII code value of the character"); diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 1bc25f56104b..87ba77b497b2 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -24,6 +24,7 @@ use arrow::array::*; use arrow::buffer::OffsetBuffer; use arrow::compute; use arrow::datatypes::{DataType, Field, UInt64Type}; +use arrow::row::{RowConverter, SortField}; use arrow_buffer::NullBuffer; use datafusion_common::cast::{ @@ -35,6 +36,7 @@ use datafusion_common::{ DataFusionError, Result, }; +use hashbrown::HashSet; use itertools::Itertools; macro_rules! downcast_arg { @@ -347,7 +349,7 @@ fn array_array(args: &[ArrayRef], data_type: DataType) -> Result { let data_type = arrays[0].data_type(); let field = Arc::new(Field::new("item", data_type.to_owned(), true)); let elements = arrays.iter().map(|x| x.as_ref()).collect::>(); - let values = arrow::compute::concat(elements.as_slice())?; + let values = compute::concat(elements.as_slice())?; let list_arr = ListArray::new( field, OffsetBuffer::from_lengths(array_lengths), @@ -368,7 +370,7 @@ fn array_array(args: &[ArrayRef], data_type: DataType) -> Result { .iter() .map(|x| x as &dyn Array) .collect::>(); - let values = arrow::compute::concat(elements.as_slice())?; + let values = compute::concat(elements.as_slice())?; let list_arr = ListArray::new( field, OffsetBuffer::from_lengths(list_array_lengths), @@ -767,7 +769,7 @@ fn concat_internal(args: &[ArrayRef]) -> Result { .collect::>(); // Concatenated array on i-th row - let concated_array = arrow::compute::concat(elements.as_slice())?; + let concated_array = compute::concat(elements.as_slice())?; array_lengths.push(concated_array.len()); arrays.push(concated_array); valid.append(true); @@ -785,7 +787,7 @@ fn concat_internal(args: &[ArrayRef]) -> Result { let list_arr = ListArray::new( Arc::new(Field::new("item", data_type, true)), OffsetBuffer::from_lengths(array_lengths), - Arc::new(arrow::compute::concat(elements.as_slice())?), + Arc::new(compute::concat(elements.as_slice())?), Some(NullBuffer::new(buffer)), ); Ok(Arc::new(list_arr)) @@ -879,7 +881,7 @@ fn general_repeat(array: &ArrayRef, count_array: &Int64Array) -> Result = new_values.iter().map(|a| a.as_ref()).collect(); - let values = arrow::compute::concat(&new_values)?; + let values = compute::concat(&new_values)?; Ok(Arc::new(ListArray::try_new( Arc::new(Field::new("item", data_type.to_owned(), true)), @@ -947,7 +949,7 @@ fn general_list_repeat( let lengths = new_values.iter().map(|a| a.len()).collect::>(); let new_values: Vec<_> = new_values.iter().map(|a| a.as_ref()).collect(); - let values = arrow::compute::concat(&new_values)?; + let values = compute::concat(&new_values)?; Ok(Arc::new(ListArray::try_new( Arc::new(Field::new("item", data_type.to_owned(), true)), @@ -1798,6 +1800,61 @@ pub fn string_to_array(args: &[ArrayRef]) -> Result Result { + assert_eq!(args.len(), 2); + + let first_array = as_list_array(&args[0])?; + let second_array = as_list_array(&args[1])?; + + if first_array.value_type() != second_array.value_type() { + return internal_err!("array_intersect is not implemented for '{first_array:?}' and '{second_array:?}'"); + } + let dt = first_array.value_type().clone(); + + let mut offsets = vec![0]; + let mut new_arrays = vec![]; + + let converter = RowConverter::new(vec![SortField::new(dt.clone())])?; + for (first_arr, second_arr) in first_array.iter().zip(second_array.iter()) { + if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) { + let l_values = converter.convert_columns(&[first_arr])?; + let r_values = converter.convert_columns(&[second_arr])?; + + let values_set: HashSet<_> = l_values.iter().collect(); + let mut rows = Vec::with_capacity(r_values.num_rows()); + for r_val in r_values.iter().sorted().dedup() { + if values_set.contains(&r_val) { + rows.push(r_val); + } + } + + let last_offset: i32 = match offsets.last().copied() { + Some(offset) => offset, + None => return internal_err!("offsets should not be empty"), + }; + offsets.push(last_offset + rows.len() as i32); + let arrays = converter.convert_rows(rows)?; + let array = match arrays.get(0) { + Some(array) => array.clone(), + None => { + return internal_err!( + "array_intersect: failed to get array from rows" + ) + } + }; + new_arrays.push(array); + } + } + + let field = Arc::new(Field::new("item", dt, true)); + let offsets = OffsetBuffer::new(offsets.into()); + let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::>(); + let values = compute::concat(&new_arrays_ref)?; + let arr = Arc::new(ListArray::try_new(field, offsets, values, None)?); + Ok(arr) +} + #[cfg(test)] mod tests { use super::*; diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 088bac100978..c973232c75a6 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -398,6 +398,9 @@ pub fn create_physical_fun( BuiltinScalarFunction::ArrayToString => Arc::new(|args| { make_scalar_function(array_expressions::array_to_string)(args) }), + BuiltinScalarFunction::ArrayIntersect => Arc::new(|args| { + make_scalar_function(array_expressions::array_intersect)(args) + }), BuiltinScalarFunction::Cardinality => { Arc::new(|args| make_scalar_function(array_expressions::cardinality)(args)) } diff --git a/datafusion/proto/proto/datafusion.proto b/datafusion/proto/proto/datafusion.proto index bc6de2348e8d..f9deca2f1e52 100644 --- a/datafusion/proto/proto/datafusion.proto +++ b/datafusion/proto/proto/datafusion.proto @@ -621,6 +621,7 @@ enum ScalarFunction { ArrayPopBack = 116; StringToArray = 117; ToTimestampNanos = 118; + ArrayIntersect = 119; } message ScalarFunctionNode { diff --git a/datafusion/proto/src/generated/pbjson.rs b/datafusion/proto/src/generated/pbjson.rs index 659a25f9fa35..81f260c28bed 100644 --- a/datafusion/proto/src/generated/pbjson.rs +++ b/datafusion/proto/src/generated/pbjson.rs @@ -20730,6 +20730,7 @@ impl serde::Serialize for ScalarFunction { Self::ArrayPopBack => "ArrayPopBack", Self::StringToArray => "StringToArray", Self::ToTimestampNanos => "ToTimestampNanos", + Self::ArrayIntersect => "ArrayIntersect", }; serializer.serialize_str(variant) } @@ -20860,6 +20861,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayPopBack", "StringToArray", "ToTimestampNanos", + "ArrayIntersect", ]; struct GeneratedVisitor; @@ -21019,6 +21021,7 @@ impl<'de> serde::Deserialize<'de> for ScalarFunction { "ArrayPopBack" => Ok(ScalarFunction::ArrayPopBack), "StringToArray" => Ok(ScalarFunction::StringToArray), "ToTimestampNanos" => Ok(ScalarFunction::ToTimestampNanos), + "ArrayIntersect" => Ok(ScalarFunction::ArrayIntersect), _ => Err(serde::de::Error::unknown_variant(value, FIELDS)), } } diff --git a/datafusion/proto/src/generated/prost.rs b/datafusion/proto/src/generated/prost.rs index 75050e9d3dfa..ae64c11b3b74 100644 --- a/datafusion/proto/src/generated/prost.rs +++ b/datafusion/proto/src/generated/prost.rs @@ -2539,6 +2539,7 @@ pub enum ScalarFunction { ArrayPopBack = 116, StringToArray = 117, ToTimestampNanos = 118, + ArrayIntersect = 119, } impl ScalarFunction { /// String value of the enum field names used in the ProtoBuf definition. @@ -2666,6 +2667,7 @@ impl ScalarFunction { ScalarFunction::ArrayPopBack => "ArrayPopBack", ScalarFunction::StringToArray => "StringToArray", ScalarFunction::ToTimestampNanos => "ToTimestampNanos", + ScalarFunction::ArrayIntersect => "ArrayIntersect", } } /// Creates an enum from field names used in the ProtoBuf definition. @@ -2790,6 +2792,7 @@ impl ScalarFunction { "ArrayPopBack" => Some(Self::ArrayPopBack), "StringToArray" => Some(Self::StringToArray), "ToTimestampNanos" => Some(Self::ToTimestampNanos), + "ArrayIntersect" => Some(Self::ArrayIntersect), _ => None, } } diff --git a/datafusion/proto/src/logical_plan/from_proto.rs b/datafusion/proto/src/logical_plan/from_proto.rs index a3dcbc3fc80a..e5bcc934036d 100644 --- a/datafusion/proto/src/logical_plan/from_proto.rs +++ b/datafusion/proto/src/logical_plan/from_proto.rs @@ -483,6 +483,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction { ScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll, ScalarFunction::ArraySlice => Self::ArraySlice, ScalarFunction::ArrayToString => Self::ArrayToString, + ScalarFunction::ArrayIntersect => Self::ArrayIntersect, ScalarFunction::Cardinality => Self::Cardinality, ScalarFunction::Array => Self::MakeArray, ScalarFunction::NullIf => Self::NullIf, diff --git a/datafusion/proto/src/logical_plan/to_proto.rs b/datafusion/proto/src/logical_plan/to_proto.rs index 687b73cfc886..803becbcaece 100644 --- a/datafusion/proto/src/logical_plan/to_proto.rs +++ b/datafusion/proto/src/logical_plan/to_proto.rs @@ -1481,6 +1481,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction { BuiltinScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll, BuiltinScalarFunction::ArraySlice => Self::ArraySlice, BuiltinScalarFunction::ArrayToString => Self::ArrayToString, + BuiltinScalarFunction::ArrayIntersect => Self::ArrayIntersect, BuiltinScalarFunction::Cardinality => Self::Cardinality, BuiltinScalarFunction::MakeArray => Self::Array, BuiltinScalarFunction::NullIf => Self::NullIf, diff --git a/datafusion/sqllogictest/test_files/array.slt b/datafusion/sqllogictest/test_files/array.slt index c57369c167f4..f83ed5a95ff3 100644 --- a/datafusion/sqllogictest/test_files/array.slt +++ b/datafusion/sqllogictest/test_files/array.slt @@ -182,6 +182,55 @@ AS VALUES (make_array([[1], [2]], [[2], [3]]), make_array([1], [2])) ; +statement ok +CREATE TABLE array_intersect_table_1D +AS VALUES + (make_array(1, 2), make_array(1), make_array(1,2,3), make_array(1,3), make_array(1,3,5), make_array(2,4,6,8,1,3)), + (make_array(11, 22), make_array(11), make_array(11,22,33), make_array(11,33), make_array(11,33,55), make_array(22,44,66,88,11,33)) +; + +statement ok +CREATE TABLE array_intersect_table_1D_Float +AS VALUES + (make_array(1.0, 2.0), make_array(1.0), make_array(1.0,2.0,3.0), make_array(1.0,3.0), make_array(1.11), make_array(2.22, 3.33)), + (make_array(3.0, 4.0, 5.0), make_array(2.0), make_array(1.0,2.0,3.0,4.0), make_array(2.0,5.0), make_array(2.22, 1.11), make_array(1.11, 3.33)) +; + +statement ok +CREATE TABLE array_intersect_table_1D_Boolean +AS VALUES + (make_array(true, true, true), make_array(false), make_array(true, true, false, true, false), make_array(true, false, true), make_array(false), make_array(true, false)), + (make_array(false, false, false), make_array(false), make_array(true, false, true), make_array(true, true), make_array(true, true), make_array(false,false,true)) +; + +statement ok +CREATE TABLE array_intersect_table_1D_UTF8 +AS VALUES + (make_array('a', 'bc', 'def'), make_array('bc'), make_array('datafusion', 'rust', 'arrow'), make_array('rust', 'arrow'), make_array('rust', 'arrow', 'python'), make_array('data')), + (make_array('a', 'bc', 'def'), make_array('defg'), make_array('datafusion', 'rust', 'arrow'), make_array('datafusion', 'rust', 'arrow', 'python'), make_array('rust', 'arrow'), make_array('datafusion', 'rust', 'arrow')) +; + +statement ok +CREATE TABLE array_intersect_table_2D +AS VALUES + (make_array([1,2]), make_array([1,3]), make_array([1,2,3], [4,5], [6,7]), make_array([4,5], [6,7])), + (make_array([3,4], [5]), make_array([3,4]), make_array([1,2,3,4], [5,6,7], [8,9,10]), make_array([1,2,3], [5,6,7], [8,9,10])) +; + +statement ok +CREATE TABLE array_intersect_table_2D_float +AS VALUES + (make_array([1.0, 2.0, 3.0], [1.1, 2.2], [3.3]), make_array([1.1, 2.2], [3.3])), + (make_array([1.0, 2.0, 3.0], [1.1, 2.2], [3.3]), make_array([1.0], [1.1, 2.2], [3.3])) +; + +statement ok +CREATE TABLE array_intersect_table_3D +AS VALUES + (make_array([[1,2]]), make_array([[1]])), + (make_array([[1,2]]), make_array([[1,2]])) +; + statement ok CREATE TABLE arrays_values_without_nulls AS VALUES @@ -2316,6 +2365,86 @@ select array_has_all(make_array(1,2,3), make_array(1,3)), ---- true false true false false false true true false false true false true +query ??? +select array_intersect(column1, column2), + array_intersect(column3, column4), + array_intersect(column5, column6) +from array_intersect_table_1D; +---- +[1] [1, 3] [1, 3] +[11] [11, 33] [11, 33] + +query ??? +select array_intersect(column1, column2), + array_intersect(column3, column4), + array_intersect(column5, column6) +from array_intersect_table_1D_Float; +---- +[1.0] [1.0, 3.0] [] +[] [2.0] [1.11] + +query ??? +select array_intersect(column1, column2), + array_intersect(column3, column4), + array_intersect(column5, column6) +from array_intersect_table_1D_Boolean; +---- +[] [false, true] [false] +[false] [true] [true] + +query ??? +select array_intersect(column1, column2), + array_intersect(column3, column4), + array_intersect(column5, column6) +from array_intersect_table_1D_UTF8; +---- +[bc] [arrow, rust] [] +[] [arrow, datafusion, rust] [arrow, rust] + +query ?? +select array_intersect(column1, column2), + array_intersect(column3, column4) +from array_intersect_table_2D; +---- +[] [[4, 5], [6, 7]] +[[3, 4]] [[5, 6, 7], [8, 9, 10]] + +query ? +select array_intersect(column1, column2) +from array_intersect_table_2D_float; +---- +[[1.1, 2.2], [3.3]] +[[1.1, 2.2], [3.3]] + +query ? +select array_intersect(column1, column2) +from array_intersect_table_3D; +---- +[] +[[[1, 2]]] + +query ?????? +SELECT array_intersect(make_array(1,2,3), make_array(2,3,4)), + array_intersect(make_array(1,3,5), make_array(2,4,6)), + array_intersect(make_array('aa','bb','cc'), make_array('cc','aa','dd')), + array_intersect(make_array(true, false), make_array(true)), + array_intersect(make_array(1.1, 2.2, 3.3), make_array(2.2, 3.3, 4.4)), + array_intersect(make_array([1, 1], [2, 2], [3, 3]), make_array([2, 2], [3, 3], [4, 4])) +; +---- +[2, 3] [] [aa, cc] [true] [2.2, 3.3] [[2, 2], [3, 3]] + +query ?????? +SELECT list_intersect(make_array(1,2,3), make_array(2,3,4)), + list_intersect(make_array(1,3,5), make_array(2,4,6)), + list_intersect(make_array('aa','bb','cc'), make_array('cc','aa','dd')), + list_intersect(make_array(true, false), make_array(true)), + list_intersect(make_array(1.1, 2.2, 3.3), make_array(2.2, 3.3, 4.4)), + list_intersect(make_array([1, 1], [2, 2], [3, 3]), make_array([2, 2], [3, 3], [4, 4])) +; +---- +[2, 3] [] [aa, cc] [true] [2.2, 3.3] [[2, 2], [3, 3]] + query BBBB select list_has_all(make_array(1,2,3), make_array(4,5,6)), list_has_all(make_array(1,2,3), make_array(1,2)), @@ -2608,6 +2737,27 @@ drop table array_has_table_2D_float; statement ok drop table array_has_table_3D; +statement ok +drop table array_intersect_table_1D; + +statement ok +drop table array_intersect_table_1D_Float; + +statement ok +drop table array_intersect_table_1D_Boolean; + +statement ok +drop table array_intersect_table_1D_UTF8; + +statement ok +drop table array_intersect_table_2D; + +statement ok +drop table array_intersect_table_2D_float; + +statement ok +drop table array_intersect_table_3D; + statement ok drop table arrays_values_without_nulls; diff --git a/docs/source/user-guide/expressions.md b/docs/source/user-guide/expressions.md index dbe12df33564..27384dccffe0 100644 --- a/docs/source/user-guide/expressions.md +++ b/docs/source/user-guide/expressions.md @@ -232,6 +232,7 @@ Unlike to some databases the math functions in Datafusion works the same way as | array_replace_all(array, from, to) | Replaces all occurrences of the specified element with another specified element. `array_replace_all([1, 2, 2, 3, 2, 1, 4], 2, 5) -> [1, 5, 5, 3, 5, 1, 4]` | | array_slice(array, index) | Returns a slice of the array. `array_slice([1, 2, 3, 4, 5, 6, 7, 8], 3, 6) -> [3, 4, 5, 6]` | | array_to_string(array, delimiter) | Converts each element to its text representation. `array_to_string([1, 2, 3, 4], ',') -> 1,2,3,4` | +| array_intersect(array1, array2) | Returns an array of the elements in the intersection of array1 and array2. `array_intersect([1, 2, 3, 4], [5, 6, 3, 4]) -> [3, 4]` | | cardinality(array) | Returns the total number of elements in the array. `cardinality([[1, 2, 3], [4, 5, 6]]) -> 6` | | make_array(value1, [value2 [, ...]]) | Returns an Arrow array using the specified input expressions. `make_array(1, 2, 3) -> [1, 2, 3]` | | trim_array(array, n) | Deprecated |