Skip to content

Commit

Permalink
Initial Implementation of array_intersect
Browse files Browse the repository at this point in the history
Signed-off-by: veeupup <code@tanweime.com>
  • Loading branch information
Veeupup committed Nov 7, 2023
1 parent 01d7dba commit 9b050db
Show file tree
Hide file tree
Showing 10 changed files with 259 additions and 0 deletions.
5 changes: 5 additions & 0 deletions datafusion/core/tests/sqllogictests/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -1695,6 +1695,11 @@ select array_has_all(make_array(1,2,3), make_array(1,3)),
----
true false true false false false true true false false true false true

query ??
SELECT array_intersect(make_array(1,2,3), make_array(2,3,4)), array_intersect(make_array(1,3,5), make_array(2,4,6));
----
[2, 3] []

query BBBB
select list_has_all(make_array(1,2,3), make_array(4,5,6)),
list_has_all(make_array(1,2,3), make_array(1,2)),
Expand Down
33 changes: 33 additions & 0 deletions datafusion/expr/src/built_in_function.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,8 @@ pub enum BuiltinScalarFunction {
ArrayReplaceAll,
/// array_to_string
ArrayToString,
/// array_intersect
ArrayIntersect,
/// cardinality
Cardinality,
/// construct an array from columns
Expand Down Expand Up @@ -359,6 +361,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayReplaceN => Volatility::Immutable,
BuiltinScalarFunction::ArrayReplaceAll => Volatility::Immutable,
BuiltinScalarFunction::ArrayToString => Volatility::Immutable,
BuiltinScalarFunction::ArrayIntersect => Volatility::Immutable,
BuiltinScalarFunction::Cardinality => Volatility::Immutable,
BuiltinScalarFunction::MakeArray => Volatility::Immutable,
BuiltinScalarFunction::TrimArray => Volatility::Immutable,
Expand Down Expand Up @@ -543,6 +546,34 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayReplaceN => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayReplaceAll => Ok(input_expr_types[0].clone()),
BuiltinScalarFunction::ArrayToString => Ok(Utf8),
BuiltinScalarFunction::ArrayIntersect => {
if input_expr_types.len() < 2 || input_expr_types.len() > 2 {
Err(DataFusionError::Internal(format!(
"The {self} function must have two arrays as parameters"
)))
} else {
match (&input_expr_types[0], &input_expr_types[1]) {
(List(l_field), List(r_field)) => {
if !l_field.data_type().equals_datatype(r_field.data_type()) {
Err(DataFusionError::Internal(format!(
"The {self} function array data type not equal, [0]: {:?}, [1]: {:?}",
l_field.data_type(), r_field.data_type()
)))
} else {
Ok(List(Arc::new(Field::new(
"item",
l_field.data_type().clone(),
true,
))))
}
}
_ => Err(DataFusionError::Internal(format!(
"The {} parameters should be array, [0]: {:?}, [1]: {:?}",
self, input_expr_types[0], input_expr_types[1]
))),
}
}
}
BuiltinScalarFunction::Cardinality => Ok(UInt64),
BuiltinScalarFunction::MakeArray => match input_expr_types.len() {
0 => Ok(List(Arc::new(Field::new("item", Null, true)))),
Expand Down Expand Up @@ -834,6 +865,7 @@ impl BuiltinScalarFunction {
BuiltinScalarFunction::ArrayToString => {
Signature::variadic_any(self.volatility())
}
BuiltinScalarFunction::ArrayIntersect => Signature::any(2, self.volatility()),
BuiltinScalarFunction::Cardinality => Signature::any(1, self.volatility()),
BuiltinScalarFunction::MakeArray => {
Signature::variadic_any(self.volatility())
Expand Down Expand Up @@ -1324,6 +1356,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] {
BuiltinScalarFunction::Cardinality => &["cardinality"],
BuiltinScalarFunction::MakeArray => &["make_array", "make_list"],
BuiltinScalarFunction::TrimArray => &["trim_array"],
BuiltinScalarFunction::ArrayIntersect => &["array_intersect", "list_interact"],
}
}

Expand Down
6 changes: 6 additions & 0 deletions datafusion/expr/src/expr_fn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,12 @@ scalar_expr!(
array n,
"removes the last n elements from the array."
);
scalar_expr!(
ArrayIntersect,
array_intersect,
first_array second_array,
"Returns an array of the elements in the intersection of array1 and array2."
);

// string functions
scalar_expr!(Ascii, ascii, chr, "ASCII code value of the character");
Expand Down
206 changes: 206 additions & 0 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use datafusion_common::cast::{as_generic_string_array, as_int64_array, as_list_a
use datafusion_common::ScalarValue;
use datafusion_common::{DataFusionError, Result};
use datafusion_expr::ColumnarValue;
use hashbrown::{HashMap, HashSet};
use itertools::Itertools;
use std::sync::Arc;

Expand Down Expand Up @@ -1820,6 +1821,211 @@ pub fn array_has_all(args: &[ArrayRef]) -> Result<ArrayRef> {
Ok(Arc::new(boolean_builder.finish()))
}

/// Computes the row-wise intersection of two list arrays whose child values
/// are primitive arrays of type `$ARRAY_TYPE`.
///
/// Arguments:
/// * `$FIRST_ARRAY` / `$SECOND_ARRAY` - the two list arrays, iterated in
///   lock-step (one output row per zipped input row).
/// * `$DATA_TYPE` - the element data type; kept for call-site compatibility
///   and checked against the produced array in debug builds.
/// * `$ARRAY_TYPE` - the concrete primitive array type of the child values.
///
/// For every row the output contains each element of the second list that
/// also occurs in the first list, emitted once, in order of first appearance
/// in the second list. Null elements inside a list are ignored. If either
/// input row is a NULL list, the output row is NULL.
///
/// Evaluates to `Result<ArrayRef>`. It must be expanded inside a function
/// returning `Result`, because `downcast_arg!` propagates errors with `?`.
macro_rules! array_intersect_normal {
    ($FIRST_ARRAY:expr, $SECOND_ARRAY:expr, $DATA_TYPE:expr, $ARRAY_TYPE:ident) => {{
        // A single list builder accumulates every row, so the child values are
        // written exactly once instead of being re-concatenated per row.
        let mut list_builder = arrow::array::ListBuilder::with_capacity(
            $ARRAY_TYPE::builder(0),
            $FIRST_ARRAY.len(),
        );

        for (first_arr, second_arr) in $FIRST_ARRAY.iter().zip($SECOND_ARRAY.iter()) {
            match (first_arr, second_arr) {
                (Some(first_arr), Some(second_arr)) => {
                    let first_arr = downcast_arg!(first_arr, $ARRAY_TYPE);
                    let second_arr = downcast_arg!(second_arr, $ARRAY_TYPE);

                    // Distinct non-null values of the first list that have not
                    // been emitted yet for this row.
                    let mut candidates =
                        first_arr.iter().flatten().collect::<HashSet<_>>();
                    for elem in second_arr.iter().flatten() {
                        // `remove` doubles as the membership test and as
                        // de-duplication: a value can only be matched once.
                        if candidates.remove(&elem) {
                            list_builder.values().append_value(elem);
                        }
                    }
                    list_builder.append(true);
                }
                // A NULL list on either side yields a NULL result row.
                _ => list_builder.append(false),
            }
        }

        let result = list_builder.finish();
        debug_assert_eq!(result.value_type(), $DATA_TYPE);
        Ok(Arc::new(result) as ArrayRef)
    }};
}

/// array_intersect SQL function
pub fn array_intersect(args: &[ArrayRef]) -> Result<ArrayRef> {
assert_eq!(args.len(), 2);

let first_array = as_list_array(&args[0])?;
let second_array = as_list_array(&args[1])?;

// write array interact method

match (first_array.value_type(), second_array.value_type()) {
// (DataType::List(_), DataType::List(_)) => concat_internal(args)?,
// (DataType::Utf8, DataType::Utf8) => array_intersect_normal!(arr, element, StringArray),
// (DataType::LargeUtf8, DataType::LargeUtf8) => array_intersect_normal!(arr, element, LargeStringArray),
// (DataType::Boolean, DataType::Boolean) => array_intersect_normal!(arr, element, BooleanArray),
// (DataType::Float32, DataType::Float32) => array_intersect_normal!(arr, element, Float32Array),
// (DataType::Float64, DataType::Float64) => array_intersect_normal!(arr, element, Float64Array),
(DataType::Int8, DataType::Int8) => array_intersect_normal!(first_array, second_array, DataType::Int8, Int8Array),
(DataType::Int16, DataType::Int16) => array_intersect_normal!(first_array, second_array, DataType::Int16, Int16Array),
(DataType::Int32, DataType::Int32) => array_intersect_normal!(first_array, second_array, DataType::Int32, Int32Array),
(DataType::Int64, DataType::Int64) => array_intersect_normal!(first_array, second_array, DataType::Int64, Int64Array),
(DataType::UInt8, DataType::UInt8) => array_intersect_normal!(first_array, second_array, DataType::UInt8, UInt8Array),
(DataType::UInt16, DataType::UInt16) => array_intersect_normal!(first_array, second_array, DataType::UInt16, UInt16Array),
(DataType::UInt32, DataType::UInt32) => array_intersect_normal!(first_array, second_array, DataType::UInt32, UInt32Array),
(DataType::UInt64, DataType::UInt64) => array_intersect_normal!(first_array, second_array, DataType::UInt64, UInt64Array),
// (DataType::Null, _) => return Ok(array(&[ColumnarValue::Array(args[1].clone())])?.into_array(1)),
(DataType::Int64, DataType::Int64) => {
let mut offsets: Vec<i32> = vec![0];
let mut values =
downcast_arg!(new_empty_array(&DataType::Int64), Int64Array).clone();

for (first_arr, second_arr) in first_array.iter().zip(second_array.iter()) {
let last_offset: i32 = offsets.last().copied().ok_or_else(|| {
DataFusionError::Internal(format!("offsets should not be empty"))
})?;
match (first_arr, second_arr) {
(Some(first_arr), Some(second_arr)) => {
let first_arr = downcast_arg!(first_arr, Int64Array);
let first_set = first_arr.iter().dedup().flatten().collect::<HashSet<_>>();
println!("{:?}", first_set);
let second_arr = downcast_arg!(second_arr, Int64Array);
print!("{:?}", second_arr);

let mut builder = Int64Array::builder(first_arr.len().min(second_arr.len()));
for elem in second_arr.iter().dedup().flatten() {
println!("second_arr: {:?}", elem);
if first_set.contains(&elem) {
builder.append_value(elem);
}
}

let arr = builder.finish();
values = downcast_arg!(
compute::concat(&[
&values,
&arr
])?
.clone(),
Int64Array
)
.clone();
offsets.push(last_offset + arr.len() as i32);
},
_ => {
todo!()
}
}
}
let field = Arc::new(Field::new("item", DataType::Int64, true));

Ok(Arc::new(ListArray::try_new(
field,
OffsetBuffer::new(offsets.into()),
Arc::new(values),
None,
)?))
},
(first_value_dt, second_value_dt) => {
Err(DataFusionError::NotImplemented(format!(
"array_intersect is not implemented for '{first_value_dt:?}' and '{second_value_dt:?}'",
)))
}
}
// for (first_arr, second_arr) in first_array.iter().zip(second_array.iter()) {
// if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) {
// let ret: ArrayRef = match (first_arr.data_type(), second_arr.data_type()) {
// (DataType::List(_), DataType::List(_)) => {
// todo!()
// },
// // Int64, Int32, Int16, Int8
// // UInt64, UInt32, UInt16, UInt8
// (DataType::Int64, DataType::Int64) => {
// // array_intersect_non_list_check!(first_arr, second_arr, Int64Array)
// let first_arr = downcast_arg!(first_array, Int64Array);
// let second_arr = downcast_arg!(second_array, Int64Array);

// let mut offsets: Vec<i32> = vec![0];
// let mut values =
// Int64Array::builder(first_arr.len().min(second_arr.len()));
// let first_set = first_arr.iter().dedup().flatten().collect::<HashSet<_>>();
// for elem in second_arr.iter().dedup().flatten() {
// if first_set.contains(&elem) {
// let last_offset: i32 = offsets.last().copied().ok_or_else(|| {
// DataFusionError::Internal(format!("offsets should not be empty"))
// })?;
// values.append_value(elem);
// offsets.push(last_offset + 1);
// }
// }

// let field = Arc::new(Field::new("item", DataType::Int64, true));

// Arc::new(ListArray::try_new(
// field,
// OffsetBuffer::new(offsets.into()),
// Arc::new(values.finish()),
// None,
// )?)
// },
// (DataType::Int32, DataType::Int32) => {
// array_intersect_non_list_check!(first_arr, second_arr, Int32Array)
// }
// (DataType::Int16, DataType::Int16) => {
// array_intersect_non_list_check!(first_arr, second_arr, Int16Array)
// }
// (DataType::Int8, DataType::Int8) => {
// array_intersect_non_list_check!(first_arr, second_arr, Int8Array)
// }
// (DataType::UInt64, DataType::UInt64) => {
// array_intersect_non_list_check!(first_arr, second_arr, UInt64Array)
// }
// (DataType::UInt32, DataType::UInt32) => {
// array_intersect_non_list_check!(first_arr, second_arr, UInt32Array)
// }
// (DataType::UInt16, DataType::UInt16) => {
// array_intersect_non_list_check!(first_arr, second_arr, UInt16Array)
// }
// (DataType::UInt8, DataType::UInt8) => {
// array_intersect_non_list_check!(first_arr, second_arr, UInt8Array)
// }

// (first_arr_type, second_arr_type) => Err(DataFusionError::NotImplemented(format!(
// "array_intersect is not implemented for '{first_arr_type:?}' and '{second_arr_type:?}'",
// )))?,
// };
// return Ok(ret);
// }
// }
// Err(DataFusionError::Internal(format!(
// "array_intersect does not support Null type for element in sub_array"
// )))
}

#[cfg(test)]
mod tests {
use super::*;
Expand Down
3 changes: 3 additions & 0 deletions datafusion/physical-expr/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -465,6 +465,9 @@ pub fn create_physical_fun(
BuiltinScalarFunction::ArrayToString => Arc::new(|args| {
make_scalar_function(array_expressions::array_to_string)(args)
}),
BuiltinScalarFunction::ArrayIntersect => Arc::new(|args| {
make_scalar_function(array_expressions::array_intersect)(args)
}),
BuiltinScalarFunction::Cardinality => {
Arc::new(|args| make_scalar_function(array_expressions::cardinality)(args))
}
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/proto/datafusion.proto
Original file line number Diff line number Diff line change
Expand Up @@ -577,6 +577,7 @@ enum ScalarFunction {
ArrayReplaceN = 108;
ArrayRemoveAll = 109;
ArrayReplaceAll = 110;
ArrayIntersect = 111;
}

message ScalarFunctionNode {
Expand Down
2 changes: 2 additions & 0 deletions datafusion/proto/src/generated/prost.rs

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions datafusion/proto/src/logical_plan/from_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
ScalarFunction::ArrayReplaceN => Self::ArrayReplaceN,
ScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll,
ScalarFunction::ArrayToString => Self::ArrayToString,
ScalarFunction::ArrayIntersect => Self::ArrayIntersect,
ScalarFunction::Cardinality => Self::Cardinality,
ScalarFunction::Array => Self::MakeArray,
ScalarFunction::TrimArray => Self::TrimArray,
Expand Down
1 change: 1 addition & 0 deletions datafusion/proto/src/logical_plan/to_proto.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1417,6 +1417,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
BuiltinScalarFunction::ArrayReplaceN => Self::ArrayReplaceN,
BuiltinScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll,
BuiltinScalarFunction::ArrayToString => Self::ArrayToString,
BuiltinScalarFunction::ArrayIntersect => Self::ArrayIntersect,
BuiltinScalarFunction::Cardinality => Self::Cardinality,
BuiltinScalarFunction::MakeArray => Self::Array,
BuiltinScalarFunction::TrimArray => Self::TrimArray,
Expand Down
1 change: 1 addition & 0 deletions docs/source/user-guide/expressions.md
Original file line number Diff line number Diff line change
Expand Up @@ -200,6 +200,7 @@ Unlike to some databases the math functions in Datafusion works the same way as
| array_replace_n(array, from, to, max) | Replaces the first `max` occurrences of the specified element with another specified element. `array_replace_n([1, 2, 2, 3, 2, 1, 4], 2, 5, 2) -> [1, 5, 5, 3, 2, 1, 4]` |
| array_replace_all(array, from, to) | Replaces all occurrences of the specified element with another specified element. `array_replace_all([1, 2, 2, 3, 2, 1, 4], 2, 5) -> [1, 5, 5, 3, 5, 1, 4]` |
| array_to_string(array, delimiter) | Converts each element to its text representation. `array_to_string([1, 2, 3, 4], ',') -> 1,2,3,4` |
| array_intersect(array1, array2) | Returns an array of the elements in the intersection of array1 and array2. `array_intersect([1, 2, 3, 4], [5, 6, 3, 4]) -> [3, 4]` |
| cardinality(array) | Returns the total number of elements in the array. `cardinality([[1, 2, 3], [4, 5, 6]]) -> 6` |
| make_array(value1, [value2 [, ...]]) | Returns an Arrow array using the specified input expressions. `make_array(1, 2, 3) -> [1, 2, 3]` |
| trim_array(array, n) | Removes the last n elements from the array. |
Expand Down

0 comments on commit 9b050db

Please sign in to comment.