Skip to content

Commit cd02c40

Browse files
my-vegetable-has-explodedWeijun-Halamb
authored
Support array_distinct function. (#8268)
* implement distinct func implement slt & proto fix null & empty list * add comment for slt Co-authored-by: Alex Huang <huangweijun1001@gmail.com> * fix largelist * add largelist for slt * Use collect for rows & init capcity for offsets. * fixup: remove useless match * fix fmt * fix fmt --------- Co-authored-by: Alex Huang <huangweijun1001@gmail.com> Co-authored-by: Andrew Lamb <andrew@nerdnetworks.org>
1 parent 047fb33 commit cd02c40

File tree

11 files changed

+198
-11
lines changed

11 files changed

+198
-11
lines changed

datafusion/expr/src/built_in_function.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,8 @@ pub enum BuiltinScalarFunction {
146146
ArrayPopBack,
147147
/// array_dims
148148
ArrayDims,
149+
/// array_distinct
150+
ArrayDistinct,
149151
/// array_element
150152
ArrayElement,
151153
/// array_empty
@@ -407,6 +409,7 @@ impl BuiltinScalarFunction {
407409
BuiltinScalarFunction::ArrayHasAny => Volatility::Immutable,
408410
BuiltinScalarFunction::ArrayHas => Volatility::Immutable,
409411
BuiltinScalarFunction::ArrayDims => Volatility::Immutable,
412+
BuiltinScalarFunction::ArrayDistinct => Volatility::Immutable,
410413
BuiltinScalarFunction::ArrayElement => Volatility::Immutable,
411414
BuiltinScalarFunction::ArrayExcept => Volatility::Immutable,
412415
BuiltinScalarFunction::ArrayLength => Volatility::Immutable,
@@ -586,6 +589,7 @@ impl BuiltinScalarFunction {
586589
BuiltinScalarFunction::ArrayDims => {
587590
Ok(List(Arc::new(Field::new("item", UInt64, true))))
588591
}
592+
BuiltinScalarFunction::ArrayDistinct => Ok(input_expr_types[0].clone()),
589593
BuiltinScalarFunction::ArrayElement => match &input_expr_types[0] {
590594
List(field) => Ok(field.data_type().clone()),
591595
_ => plan_err!(
@@ -933,6 +937,7 @@ impl BuiltinScalarFunction {
933937
Signature::variadic_any(self.volatility())
934938
}
935939
BuiltinScalarFunction::ArrayNdims => Signature::any(1, self.volatility()),
940+
BuiltinScalarFunction::ArrayDistinct => Signature::any(1, self.volatility()),
936941
BuiltinScalarFunction::ArrayPosition => {
937942
Signature::variadic_any(self.volatility())
938943
}
@@ -1570,6 +1575,7 @@ impl BuiltinScalarFunction {
15701575
&["array_concat", "array_cat", "list_concat", "list_cat"]
15711576
}
15721577
BuiltinScalarFunction::ArrayDims => &["array_dims", "list_dims"],
1578+
BuiltinScalarFunction::ArrayDistinct => &["array_distinct", "list_distinct"],
15731579
BuiltinScalarFunction::ArrayEmpty => &["empty"],
15741580
BuiltinScalarFunction::ArrayElement => &[
15751581
"array_element",

datafusion/expr/src/expr_fn.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -660,6 +660,12 @@ scalar_expr!(
660660
array,
661661
"returns the number of dimensions of the array."
662662
);
663+
scalar_expr!(
664+
ArrayDistinct,
665+
array_distinct,
666+
array,
667+
"return distinct values from the array after removing duplicates."
668+
);
663669
scalar_expr!(
664670
ArrayPosition,
665671
array_position,

datafusion/physical-expr/src/array_expressions.rs

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ use arrow_buffer::NullBuffer;
3131

3232
use arrow_schema::{FieldRef, SortOptions};
3333
use datafusion_common::cast::{
34-
as_generic_list_array, as_generic_string_array, as_int64_array, as_list_array,
35-
as_null_array, as_string_array,
34+
as_generic_list_array, as_generic_string_array, as_int64_array, as_large_list_array,
35+
as_list_array, as_null_array, as_string_array,
3636
};
3737
use datafusion_common::utils::{array_into_list_array, list_ndims};
3838
use datafusion_common::{
@@ -2111,6 +2111,66 @@ pub fn array_intersect(args: &[ArrayRef]) -> Result<ArrayRef> {
21112111
}
21122112
}
21132113

2114+
pub fn general_array_distinct<OffsetSize: OffsetSizeTrait>(
2115+
array: &GenericListArray<OffsetSize>,
2116+
field: &FieldRef,
2117+
) -> Result<ArrayRef> {
2118+
let dt = array.value_type();
2119+
let mut offsets = Vec::with_capacity(array.len());
2120+
offsets.push(OffsetSize::usize_as(0));
2121+
let mut new_arrays = Vec::with_capacity(array.len());
2122+
let converter = RowConverter::new(vec![SortField::new(dt.clone())])?;
2123+
// distinct for each list in ListArray
2124+
for arr in array.iter().flatten() {
2125+
let values = converter.convert_columns(&[arr])?;
2126+
// sort elements in list and remove duplicates
2127+
let rows = values.iter().sorted().dedup().collect::<Vec<_>>();
2128+
let last_offset: OffsetSize = offsets.last().copied().unwrap();
2129+
offsets.push(last_offset + OffsetSize::usize_as(rows.len()));
2130+
let arrays = converter.convert_rows(rows)?;
2131+
let array = match arrays.get(0) {
2132+
Some(array) => array.clone(),
2133+
None => {
2134+
return internal_err!("array_distinct: failed to get array from rows")
2135+
}
2136+
};
2137+
new_arrays.push(array);
2138+
}
2139+
let offsets = OffsetBuffer::new(offsets.into());
2140+
let new_arrays_ref = new_arrays.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
2141+
let values = compute::concat(&new_arrays_ref)?;
2142+
Ok(Arc::new(GenericListArray::<OffsetSize>::try_new(
2143+
field.clone(),
2144+
offsets,
2145+
values,
2146+
None,
2147+
)?))
2148+
}
2149+
2150+
/// array_distinct SQL function
2151+
/// example: from list [1, 3, 2, 3, 1, 2, 4] to [1, 2, 3, 4]
2152+
pub fn array_distinct(args: &[ArrayRef]) -> Result<ArrayRef> {
2153+
assert_eq!(args.len(), 1);
2154+
2155+
// handle null
2156+
if args[0].data_type() == &DataType::Null {
2157+
return Ok(args[0].clone());
2158+
}
2159+
2160+
// handle for list & largelist
2161+
match args[0].data_type() {
2162+
DataType::List(field) => {
2163+
let array = as_list_array(&args[0])?;
2164+
general_array_distinct(array, field)
2165+
}
2166+
DataType::LargeList(field) => {
2167+
let array = as_large_list_array(&args[0])?;
2168+
general_array_distinct(array, field)
2169+
}
2170+
_ => internal_err!("array_distinct only support list array"),
2171+
}
2172+
}
2173+
21142174
#[cfg(test)]
21152175
mod tests {
21162176
use super::*;

datafusion/physical-expr/src/functions.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,9 @@ pub fn create_physical_fun(
350350
BuiltinScalarFunction::ArrayDims => {
351351
Arc::new(|args| make_scalar_function(array_expressions::array_dims)(args))
352352
}
353+
BuiltinScalarFunction::ArrayDistinct => {
354+
Arc::new(|args| make_scalar_function(array_expressions::array_distinct)(args))
355+
}
353356
BuiltinScalarFunction::ArrayElement => {
354357
Arc::new(|args| make_scalar_function(array_expressions::array_element)(args))
355358
}

datafusion/proto/proto/datafusion.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -645,6 +645,7 @@ enum ScalarFunction {
645645
SubstrIndex = 126;
646646
FindInSet = 127;
647647
ArraySort = 128;
648+
ArrayDistinct = 129;
648649
}
649650

650651
message ScalarFunctionNode {

datafusion/proto/src/generated/pbjson.rs

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/proto/src/generated/prost.rs

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/proto/src/logical_plan/from_proto.rs

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -41,15 +41,15 @@ use datafusion_common::{
4141
};
4242
use datafusion_expr::window_frame::{check_window_frame, regularize_window_order_by};
4343
use datafusion_expr::{
44-
abs, acos, acosh, array, array_append, array_concat, array_dims, array_element,
45-
array_except, array_has, array_has_all, array_has_any, array_intersect, array_length,
46-
array_ndims, array_position, array_positions, array_prepend, array_remove,
47-
array_remove_all, array_remove_n, array_repeat, array_replace, array_replace_all,
48-
array_replace_n, array_slice, array_sort, array_to_string, arrow_typeof, ascii, asin,
49-
asinh, atan, atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil,
50-
character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot,
51-
current_date, current_time, date_bin, date_part, date_trunc, decode, degrees, digest,
52-
encode, exp,
44+
abs, acos, acosh, array, array_append, array_concat, array_dims, array_distinct,
45+
array_element, array_except, array_has, array_has_all, array_has_any,
46+
array_intersect, array_length, array_ndims, array_position, array_positions,
47+
array_prepend, array_remove, array_remove_all, array_remove_n, array_repeat,
48+
array_replace, array_replace_all, array_replace_n, array_slice, array_sort,
49+
array_to_string, arrow_typeof, ascii, asin, asinh, atan, atan2, atanh, bit_length,
50+
btrim, cardinality, cbrt, ceil, character_length, chr, coalesce, concat_expr,
51+
concat_ws_expr, cos, cosh, cot, current_date, current_time, date_bin, date_part,
52+
date_trunc, decode, degrees, digest, encode, exp,
5353
expr::{self, InList, Sort, WindowFunction},
5454
factorial, find_in_set, flatten, floor, from_unixtime, gcd, gen_range, isnan, iszero,
5555
lcm, left, levenshtein, ln, log, log10, log2,
@@ -484,6 +484,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
484484
ScalarFunction::ArrayHasAny => Self::ArrayHasAny,
485485
ScalarFunction::ArrayHas => Self::ArrayHas,
486486
ScalarFunction::ArrayDims => Self::ArrayDims,
487+
ScalarFunction::ArrayDistinct => Self::ArrayDistinct,
487488
ScalarFunction::ArrayElement => Self::ArrayElement,
488489
ScalarFunction::Flatten => Self::Flatten,
489490
ScalarFunction::ArrayLength => Self::ArrayLength,
@@ -1467,6 +1468,9 @@ pub fn parse_expr(
14671468
ScalarFunction::ArrayDims => {
14681469
Ok(array_dims(parse_expr(&args[0], registry)?))
14691470
}
1471+
ScalarFunction::ArrayDistinct => {
1472+
Ok(array_distinct(parse_expr(&args[0], registry)?))
1473+
}
14701474
ScalarFunction::ArrayElement => Ok(array_element(
14711475
parse_expr(&args[0], registry)?,
14721476
parse_expr(&args[1], registry)?,

datafusion/proto/src/logical_plan/to_proto.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1512,6 +1512,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
15121512
BuiltinScalarFunction::ArrayHasAny => Self::ArrayHasAny,
15131513
BuiltinScalarFunction::ArrayHas => Self::ArrayHas,
15141514
BuiltinScalarFunction::ArrayDims => Self::ArrayDims,
1515+
BuiltinScalarFunction::ArrayDistinct => Self::ArrayDistinct,
15151516
BuiltinScalarFunction::ArrayElement => Self::ArrayElement,
15161517
BuiltinScalarFunction::Flatten => Self::Flatten,
15171518
BuiltinScalarFunction::ArrayLength => Self::ArrayLength,

datafusion/sqllogictest/test_files/array.slt

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,38 @@ AS VALUES
182182
(make_array([[1], [2]], [[2], [3]]), make_array([1], [2]))
183183
;
184184

185+
statement ok
186+
CREATE TABLE array_distinct_table_1D
187+
AS VALUES
188+
(make_array(1, 1, 2, 2, 3)),
189+
(make_array(1, 2, 3, 4, 5)),
190+
(make_array(3, 5, 3, 3, 3))
191+
;
192+
193+
statement ok
194+
CREATE TABLE array_distinct_table_1D_UTF8
195+
AS VALUES
196+
(make_array('a', 'a', 'bc', 'bc', 'def')),
197+
(make_array('a', 'bc', 'def', 'defg', 'defg')),
198+
(make_array('defg', 'defg', 'defg', 'defg', 'defg'))
199+
;
200+
201+
statement ok
202+
CREATE TABLE array_distinct_table_2D
203+
AS VALUES
204+
(make_array([1,2], [1,2], [3,4], [3,4], [5,6])),
205+
(make_array([1,2], [3,4], [5,6], [7,8], [9,10])),
206+
(make_array([5,6], [5,6], NULL))
207+
;
208+
209+
statement ok
210+
CREATE TABLE array_distinct_table_1D_large
211+
AS VALUES
212+
(arrow_cast(make_array(1, 1, 2, 2, 3), 'LargeList(Int64)')),
213+
(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)')),
214+
(arrow_cast(make_array(3, 5, 3, 3, 3), 'LargeList(Int64)'))
215+
;
216+
185217
statement ok
186218
CREATE TABLE array_intersect_table_1D
187219
AS VALUES
@@ -2864,6 +2896,73 @@ select array_has_all(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), arrow_ca
28642896
----
28652897
true false true false false false true true false false true false true
28662898

2899+
query BBBBBBBBBBBBB
2900+
select array_has_all(arrow_cast(make_array(1,2,3), 'LargeList(Int64)'), arrow_cast(make_array(1,3), 'LargeList(Int64)')),
2901+
array_has_all(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(1,4), 'LargeList(Int64)')),
2902+
array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,2]), 'LargeList(List(Int64))')),
2903+
array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,3]), 'LargeList(List(Int64))')),
2904+
array_has_all(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,2], [3,4], [5,6]), 'LargeList(List(Int64))')),
2905+
array_has_all(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1]]), 'LargeList(List(List(Int64)))')),
2906+
array_has_all(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))')),
2907+
array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(1,10,100), 'LargeList(Int64)')),
2908+
array_has_any(arrow_cast(make_array(1,2,3),'LargeList(Int64)'), arrow_cast(make_array(10,100),'LargeList(Int64)')),
2909+
array_has_any(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([1,10], [10,4]), 'LargeList(List(Int64))')),
2910+
array_has_any(arrow_cast(make_array([1,2], [3,4]), 'LargeList(List(Int64))'), arrow_cast(make_array([10,20], [3,4]), 'LargeList(List(Int64))')),
2911+
array_has_any(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3], [4,5,6]]), 'LargeList(List(List(Int64)))')),
2912+
array_has_any(arrow_cast(make_array([[1,2,3]]), 'LargeList(List(List(Int64)))'), arrow_cast(make_array([[1,2,3]], [[4,5,6]]), 'LargeList(List(List(Int64)))'))
2913+
;
2914+
----
2915+
true false true false false false true true false false true false true
2916+
2917+
## array_distinct
2918+
2919+
query ?
2920+
select array_distinct(null);
2921+
----
2922+
NULL
2923+
2924+
query ?
2925+
select array_distinct([]);
2926+
----
2927+
[]
2928+
2929+
query ?
2930+
select array_distinct([[], []]);
2931+
----
2932+
[[]]
2933+
2934+
query ?
2935+
select array_distinct(column1)
2936+
from array_distinct_table_1D;
2937+
----
2938+
[1, 2, 3]
2939+
[1, 2, 3, 4, 5]
2940+
[3, 5]
2941+
2942+
query ?
2943+
select array_distinct(column1)
2944+
from array_distinct_table_1D_UTF8;
2945+
----
2946+
[a, bc, def]
2947+
[a, bc, def, defg]
2948+
[defg]
2949+
2950+
query ?
2951+
select array_distinct(column1)
2952+
from array_distinct_table_2D;
2953+
----
2954+
[[1, 2], [3, 4], [5, 6]]
2955+
[[1, 2], [3, 4], [5, 6], [7, 8], [9, 10]]
2956+
[, [5, 6]]
2957+
2958+
query ?
2959+
select array_distinct(column1)
2960+
from array_distinct_table_1D_large;
2961+
----
2962+
[1, 2, 3]
2963+
[1, 2, 3, 4, 5]
2964+
[3, 5]
2965+
28672966
query ???
28682967
select array_intersect(column1, column2),
28692968
array_intersect(column3, column4),

0 commit comments

Comments
 (0)