Skip to content

Commit 41663a9

Browse files
committed
Initial Implementation of array_intersect
Signed-off-by: veeupup <code@tanweime.com>
1 parent 4512805 commit 41663a9

File tree

11 files changed

+258
-14
lines changed

11 files changed

+258
-14
lines changed

datafusion/core/src/datasource/avro_to_arrow/arrow_array_reader.rs

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1536,12 +1536,10 @@ mod test {
15361536
.unwrap()
15371537
.resolve(&schema)
15381538
.unwrap();
1539-
let r4 = apache_avro::to_value(serde_json::json!({
1540-
"col1": null
1541-
}))
1542-
.unwrap()
1543-
.resolve(&schema)
1544-
.unwrap();
1539+
let r4 = apache_avro::to_value(serde_json::json!({ "col1": null }))
1540+
.unwrap()
1541+
.resolve(&schema)
1542+
.unwrap();
15451543

15461544
let mut w = apache_avro::Writer::new(&schema, vec![]);
15471545
w.append(r1).unwrap();
@@ -1600,12 +1598,10 @@ mod test {
16001598
}"#,
16011599
)
16021600
.unwrap();
1603-
let r1 = apache_avro::to_value(serde_json::json!({
1604-
"col1": null
1605-
}))
1606-
.unwrap()
1607-
.resolve(&schema)
1608-
.unwrap();
1601+
let r1 = apache_avro::to_value(serde_json::json!({ "col1": null }))
1602+
.unwrap()
1603+
.resolve(&schema)
1604+
.unwrap();
16091605
let r2 = apache_avro::to_value(serde_json::json!({
16101606
"col1": {
16111607
"col2": "hello"

datafusion/expr/src/built_in_function.rs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,8 @@ pub enum BuiltinScalarFunction {
174174
ArraySlice,
175175
/// array_to_string
176176
ArrayToString,
177+
/// array_intersect
178+
ArrayIntersect,
177179
/// cardinality
178180
Cardinality,
179181
/// construct an array from columns
@@ -398,6 +400,7 @@ impl BuiltinScalarFunction {
398400
BuiltinScalarFunction::Flatten => Volatility::Immutable,
399401
BuiltinScalarFunction::ArraySlice => Volatility::Immutable,
400402
BuiltinScalarFunction::ArrayToString => Volatility::Immutable,
403+
BuiltinScalarFunction::ArrayIntersect => Volatility::Immutable,
401404
BuiltinScalarFunction::Cardinality => Volatility::Immutable,
402405
BuiltinScalarFunction::MakeArray => Volatility::Immutable,
403406
BuiltinScalarFunction::Ascii => Volatility::Immutable,
@@ -577,6 +580,34 @@ impl BuiltinScalarFunction {
577580
BuiltinScalarFunction::ArrayReplaceAll => Ok(input_expr_types[0].clone()),
578581
BuiltinScalarFunction::ArraySlice => Ok(input_expr_types[0].clone()),
579582
BuiltinScalarFunction::ArrayToString => Ok(Utf8),
583+
BuiltinScalarFunction::ArrayIntersect => {
584+
if input_expr_types.len() < 2 || input_expr_types.len() > 2 {
585+
Err(DataFusionError::Internal(format!(
586+
"The {self} function must have two arrays as parameters"
587+
)))
588+
} else {
589+
match (&input_expr_types[0], &input_expr_types[1]) {
590+
(List(l_field), List(r_field)) => {
591+
if !l_field.data_type().equals_datatype(r_field.data_type()) {
592+
Err(DataFusionError::Internal(format!(
593+
"The {self} function array data type not equal, [0]: {:?}, [1]: {:?}",
594+
l_field.data_type(), r_field.data_type()
595+
)))
596+
} else {
597+
Ok(List(Arc::new(Field::new(
598+
"item",
599+
l_field.data_type().clone(),
600+
true,
601+
))))
602+
}
603+
}
604+
_ => Err(DataFusionError::Internal(format!(
605+
"The {} parameters should be array, [0]: {:?}, [1]: {:?}",
606+
self, input_expr_types[0], input_expr_types[1]
607+
))),
608+
}
609+
}
610+
}
580611
BuiltinScalarFunction::Cardinality => Ok(UInt64),
581612
BuiltinScalarFunction::MakeArray => match input_expr_types.len() {
582613
0 => Ok(List(Arc::new(Field::new("item", Null, true)))),
@@ -880,6 +911,7 @@ impl BuiltinScalarFunction {
880911
BuiltinScalarFunction::ArrayToString => {
881912
Signature::variadic_any(self.volatility())
882913
}
914+
BuiltinScalarFunction::ArrayIntersect => Signature::any(2, self.volatility()),
883915
BuiltinScalarFunction::Cardinality => Signature::any(1, self.volatility()),
884916
BuiltinScalarFunction::MakeArray => {
885917
// 0 or more arguments of arbitrary type
@@ -1505,6 +1537,7 @@ fn aliases(func: &BuiltinScalarFunction) -> &'static [&'static str] {
15051537
],
15061538
BuiltinScalarFunction::Cardinality => &["cardinality"],
15071539
BuiltinScalarFunction::MakeArray => &["make_array", "make_list"],
1540+
BuiltinScalarFunction::ArrayIntersect => &["array_intersect", "list_interact"],
15081541

15091542
// struct functions
15101543
BuiltinScalarFunction::Struct => &["struct"],

datafusion/expr/src/expr_fn.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -715,6 +715,12 @@ nary_scalar_expr!(
715715
array,
716716
"returns an Arrow array using the specified input expressions."
717717
);
718+
scalar_expr!(
719+
ArrayIntersect,
720+
array_intersect,
721+
first_array second_array,
722+
"Returns an array of the elements in the intersection of array1 and array2."
723+
);
718724

719725
// string functions
720726
scalar_expr!(Ascii, ascii, chr, "ASCII code value of the character");

datafusion/physical-expr/src/array_expressions.rs

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,10 @@ use std::any::type_name;
2121
use std::sync::Arc;
2222

2323
use arrow::array::*;
24-
use arrow::buffer::OffsetBuffer;
25-
use arrow::compute;
24+
use arrow::buffer::{Buffer, OffsetBuffer};
25+
use arrow::compute::{self, concat};
2626
use arrow::datatypes::{DataType, Field, UInt64Type};
27+
use arrow::row::{RowConverter, SortField};
2728
use arrow_buffer::NullBuffer;
2829

2930
use datafusion_common::cast::{
@@ -35,6 +36,9 @@ use datafusion_common::{
3536
DataFusionError, Result,
3637
};
3738

39+
use datafusion_common::ScalarValue;
40+
use datafusion_expr::ColumnarValue;
41+
use hashbrown::{HashMap, HashSet};
3842
use itertools::Itertools;
3943

4044
macro_rules! downcast_arg {
@@ -1807,6 +1811,63 @@ pub fn string_to_array<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef
18071811
Ok(Arc::new(list_array) as ArrayRef)
18081812
}
18091813

1814+
/// array_intersect SQL function
1815+
pub fn array_intersect(args: &[ArrayRef]) -> Result<ArrayRef> {
1816+
assert_eq!(args.len(), 2);
1817+
1818+
let first_array = as_list_array(&args[0])?;
1819+
let second_array = as_list_array(&args[1])?;
1820+
1821+
if first_array.value_type() != second_array.value_type() {
1822+
return Err(DataFusionError::NotImplemented(format!(
1823+
"array_intersect is not implemented for '{first_array:?}' and '{second_array:?}'",
1824+
)));
1825+
}
1826+
let dt = first_array.value_type().clone();
1827+
1828+
let mut offsets = vec![0];
1829+
let mut tmp_values = vec![];
1830+
1831+
let mut converter = RowConverter::new(vec![SortField::new(dt.clone())])?;
1832+
for (first_arr, second_arr) in first_array.iter().zip(second_array.iter()) {
1833+
if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) {
1834+
let l_values = converter.convert_columns(&[first_arr])?;
1835+
let r_values = converter.convert_columns(&[second_arr])?;
1836+
1837+
let mut values_set: HashSet<_> = l_values.iter().collect();
1838+
let mut rows = Vec::with_capacity(r_values.num_rows());
1839+
for r_val in r_values.iter().sorted().dedup() {
1840+
if values_set.contains(&r_val) {
1841+
rows.push(r_val);
1842+
}
1843+
}
1844+
1845+
let last_offset: i32 = offsets.last().copied().ok_or_else(|| {
1846+
DataFusionError::Internal(format!("offsets should not be empty"))
1847+
})?;
1848+
offsets.push(last_offset + rows.len() as i32);
1849+
let tmp_value = converter.convert_rows(rows)?;
1850+
tmp_values.push(
1851+
tmp_value
1852+
.get(0)
1853+
.ok_or_else(|| {
1854+
DataFusionError::Internal(format!(
1855+
"array_intersect: failed to get value from rows"
1856+
))
1857+
})?
1858+
.clone(),
1859+
);
1860+
}
1861+
}
1862+
1863+
let field = Arc::new(Field::new("item", dt, true));
1864+
let offsets = OffsetBuffer::new(offsets.into());
1865+
let tmp_values_ref = tmp_values.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
1866+
let values = concat(&tmp_values_ref)?;
1867+
let arr = Arc::new(ListArray::try_new(field, offsets, values, None)?);
1868+
Ok(arr)
1869+
}
1870+
18101871
#[cfg(test)]
18111872
mod tests {
18121873
use super::*;

datafusion/physical-expr/src/functions.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -532,6 +532,9 @@ pub fn create_physical_fun(
532532
BuiltinScalarFunction::ArrayToString => Arc::new(|args| {
533533
make_scalar_function(array_expressions::array_to_string)(args)
534534
}),
535+
BuiltinScalarFunction::ArrayIntersect => Arc::new(|args| {
536+
make_scalar_function(array_expressions::array_intersect)(args)
537+
}),
535538
BuiltinScalarFunction::Cardinality => {
536539
Arc::new(|args| make_scalar_function(array_expressions::cardinality)(args))
537540
}

datafusion/proto/proto/datafusion.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -621,6 +621,7 @@ enum ScalarFunction {
621621
ArrayPopBack = 116;
622622
StringToArray = 117;
623623
ToTimestampNanos = 118;
624+
ArrayIntersect = 119;
624625
}
625626

626627
message ScalarFunctionNode {

datafusion/proto/src/generated/prost.rs

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/proto/src/logical_plan/from_proto.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
482482
ScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll,
483483
ScalarFunction::ArraySlice => Self::ArraySlice,
484484
ScalarFunction::ArrayToString => Self::ArrayToString,
485+
ScalarFunction::ArrayIntersect => Self::ArrayIntersect,
485486
ScalarFunction::Cardinality => Self::Cardinality,
486487
ScalarFunction::Array => Self::MakeArray,
487488
ScalarFunction::NullIf => Self::NullIf,

datafusion/proto/src/logical_plan/to_proto.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1481,6 +1481,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
14811481
BuiltinScalarFunction::ArrayReplaceAll => Self::ArrayReplaceAll,
14821482
BuiltinScalarFunction::ArraySlice => Self::ArraySlice,
14831483
BuiltinScalarFunction::ArrayToString => Self::ArrayToString,
1484+
BuiltinScalarFunction::ArrayIntersect => Self::ArrayIntersect,
14841485
BuiltinScalarFunction::Cardinality => Self::Cardinality,
14851486
BuiltinScalarFunction::MakeArray => Self::Array,
14861487
BuiltinScalarFunction::NullIf => Self::NullIf,

0 commit comments

Comments
 (0)