Skip to content

Commit

Permalink
fix columns and add sqllogictest
Browse files Browse the repository at this point in the history
Signed-off-by: veeupup <code@tanweime.com>
  • Loading branch information
Veeupup committed Nov 8, 2023
1 parent 36b1cc9 commit 31dbcb1
Show file tree
Hide file tree
Showing 2 changed files with 167 additions and 68 deletions.
136 changes: 133 additions & 3 deletions datafusion/core/tests/sqllogictests/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,55 @@ AS VALUES
(make_array([[1], [2]], [[2], [3]]), make_array([1], [2]))
;

statement ok
CREATE TABLE array_intersect_table_1D
AS VALUES
(make_array(1, 2), make_array(1), make_array(1,2,3), make_array(1,3), make_array(1,3,5), make_array(2,4,6,8,1,3)),
(make_array(11, 22), make_array(11), make_array(11,22,33), make_array(11,33), make_array(11,33,55), make_array(22,44,66,88,11,33))
;

statement ok
CREATE TABLE array_intersect_table_1D_Float
AS VALUES
(make_array(1.0, 2.0), make_array(1.0), make_array(1.0,2.0,3.0), make_array(1.0,3.0), make_array(1.11), make_array(2.22, 3.33)),
(make_array(3.0, 4.0, 5.0), make_array(2.0), make_array(1.0,2.0,3.0,4.0), make_array(2.0,5.0), make_array(2.22, 1.11), make_array(1.11, 3.33))
;

statement ok
CREATE TABLE array_intersect_table_1D_Boolean
AS VALUES
(make_array(true, true, true), make_array(false), make_array(true, true, false, true, false), make_array(true, false, true), make_array(false), make_array(true, false)),
(make_array(false, false, false), make_array(false), make_array(true, false, true), make_array(true, true), make_array(true, true), make_array(false,false,true))
;

statement ok
CREATE TABLE array_intersect_table_1D_UTF8
AS VALUES
(make_array('a', 'bc', 'def'), make_array('bc'), make_array('datafusion', 'rust', 'arrow'), make_array('rust', 'arrow'), make_array('rust', 'arrow', 'python'), make_array('data')),
(make_array('a', 'bc', 'def'), make_array('defg'), make_array('datafusion', 'rust', 'arrow'), make_array('datafusion', 'rust', 'arrow', 'python'), make_array('rust', 'arrow'), make_array('datafusion', 'rust', 'arrow'))
;

statement ok
CREATE TABLE array_intersect_table_2D
AS VALUES
(make_array([1,2]), make_array([1,3]), make_array([1,2,3], [4,5], [6,7]), make_array([4,5], [6,7])),
(make_array([3,4], [5]), make_array([3,4]), make_array([1,2,3,4], [5,6,7], [8,9,10]), make_array([1,2,3], [5,6,7], [8,9,10]))
;

statement ok
CREATE TABLE array_intersect_table_2D_float
AS VALUES
(make_array([1.0, 2.0, 3.0], [1.1, 2.2], [3.3]), make_array([1.1, 2.2], [3.3])),
(make_array([1.0, 2.0, 3.0], [1.1, 2.2], [3.3]), make_array([1.0], [1.1, 2.2], [3.3]))
;

statement ok
CREATE TABLE array_intersect_table_3D
AS VALUES
(make_array([[1,2]]), make_array([[1]])),
(make_array([[1,2]]), make_array([[1,2]]))
;

statement ok
CREATE TABLE arrays_values_without_nulls
AS VALUES
Expand Down Expand Up @@ -1695,14 +1744,74 @@ select array_has_all(make_array(1,2,3), make_array(1,3)),
----
true false true false false false true true false false true false true

query ????
query ???
select array_intersect(column1, column2),
array_intersect(column3, column4),
array_intersect(column5, column6)
from array_intersect_table_1D;
----
[1] [1, 3] [1, 3]
[11] [11, 33] [11, 33]

query ???
select array_intersect(column1, column2),
array_intersect(column3, column4),
array_intersect(column5, column6)
from array_intersect_table_1D_Float;
----
[1.0] [1.0, 3.0] []
[] [2.0] [1.11]

query ???
select array_intersect(column1, column2),
array_intersect(column3, column4),
array_intersect(column5, column6)
from array_intersect_table_1D_Boolean;
----
[] [false, true] [false]
[false] [true] [true]

query ???
select array_intersect(column1, column2),
array_intersect(column3, column4),
array_intersect(column5, column6)
from array_intersect_table_1D_UTF8;
----
[bc] [arrow, rust] []
[] [arrow, datafusion, rust] [arrow, rust]

query ??
select array_intersect(column1, column2),
array_intersect(column3, column4)
from array_intersect_table_2D;
----
[] [[4, 5], [6, 7]]
[[3, 4]] [[5, 6, 7], [8, 9, 10]]

query ?
select array_intersect(column1, column2)
from array_intersect_table_2D_float;
----
[[1.1, 2.2], [3.3]]
[[1.1, 2.2], [3.3]]

query ?
select array_intersect(column1, column2)
from array_intersect_table_3D;
----
[]
[[[1, 2]]]

query ??????
SELECT array_intersect(make_array(1,2,3), make_array(2,3,4)),
array_intersect(make_array(1,3,5), make_array(2,4,6)),
array_intersect(make_array('aa','bb','cc'), make_array('cc','aa','dd')),
array_intersect(make_array(true, false), make_array(true))
array_intersect(make_array(true, false), make_array(true)),
array_intersect(make_array(1.1, 2.2, 3.3), make_array(2.2, 3.3, 4.4)),
array_intersect(make_array([1, 1], [2, 2], [3, 3]), make_array([2, 2], [3, 3], [4, 4]))
;
----
[2, 3] [] [cc, aa] [true]
[2, 3] [] [aa, cc] [true] [2.2, 3.3] [[2, 2], [3, 3]]

query BBBB
select list_has_all(make_array(1,2,3), make_array(4,5,6)),
Expand Down Expand Up @@ -1843,6 +1952,27 @@ drop table array_has_table_2D_float;
statement ok
drop table array_has_table_3D;

statement ok
drop table array_intersect_table_1D;

statement ok
drop table array_intersect_table_1D_Float;

statement ok
drop table array_intersect_table_1D_Boolean;

statement ok
drop table array_intersect_table_1D_UTF8;

statement ok
drop table array_intersect_table_2D;

statement ok
drop table array_intersect_table_2D_float;

statement ok
drop table array_intersect_table_3D;

statement ok
drop table arrays_values_without_nulls;

Expand Down
99 changes: 34 additions & 65 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1829,83 +1829,52 @@ pub fn array_intersect(args: &[ArrayRef]) -> Result<ArrayRef> {
let first_array = as_list_array(&args[0])?;
let second_array = as_list_array(&args[1])?;

let dt = match (first_array.value_type(), second_array.value_type()) {
// (DataType::List(_), DataType::List(_)) => concat_internal(args)?,
(DataType::Utf8, DataType::Utf8) => DataType::Utf8,
(DataType::LargeUtf8, DataType::LargeUtf8) => DataType::LargeUtf8,
(DataType::Boolean, DataType::Boolean) => DataType::Boolean,
(DataType::Float32, DataType::Float32) => DataType::Float32,
(DataType::Float64, DataType::Float64) => DataType::Float64,
(DataType::Int8, DataType::Int8) => DataType::Int8,
(DataType::Int16, DataType::Int16) => DataType::Int16,
(DataType::Int32, DataType::Int32) => DataType::Int32,
(DataType::Int64, DataType::Int64) => DataType::Int64,
(DataType::UInt8, DataType::UInt8) => DataType::UInt8,
(DataType::UInt16, DataType::UInt16) => DataType::UInt16,
(DataType::UInt32, DataType::UInt32) => DataType::UInt32,
(DataType::UInt64, DataType::UInt64) => DataType::UInt64,
// (DataType::Null, _) => return Ok(array(&[ColumnarValue::Array(args[1].clone())])?.into_array(1)),
(first_value_dt, second_value_dt) =>
return Err(DataFusionError::NotImplemented(format!(
"array_intersect is not implemented for '{first_value_dt:?}' and '{second_value_dt:?}'",
)))
};
if first_array.value_type() != second_array.value_type() {
return Err(DataFusionError::NotImplemented(format!(
"array_intersect is not implemented for '{first_array:?}' and '{second_array:?}'",
)));
}
let dt = first_array.value_type().clone();

let mut offsets = vec![0];

let mut tmp_values = vec![];

let mut converter = RowConverter::new(vec![SortField::new(dt.clone())])?;
for (first_arr, second_arr) in first_array.iter().zip(second_array.iter()) {
match (first_arr, second_arr) {
(Some(first_arr), Some(second_arr)) => {
let l_values = converter.convert_columns(&[first_arr])?;
let r_values = converter.convert_columns(&[second_arr])?;

let mut values_set = HashSet::new();

for (l_w, r_w) in first_array
.offsets()
.windows(2)
.zip(second_array.offsets().windows(2))
{
let l_slice = l_w[0]..l_w[1];
let r_slice = r_w[0]..r_w[1];

l_slice.for_each(|i| {
values_set.insert(l_values.row(i as usize));
});

let mut rows = vec![];
for i in r_slice {
let idx = i as usize;
if values_set.contains(&r_values.row(idx)) {
rows.push(r_values.row(idx));
}
}
if let (Some(first_arr), Some(second_arr)) = (first_arr, second_arr) {
let l_values = converter.convert_columns(&[first_arr])?;
let r_values = converter.convert_columns(&[second_arr])?;

offsets.push(rows.len() as i32);
let tmp_value = converter.convert_rows(rows)?;
tmp_values.push(
tmp_value
.get(0)
.ok_or_else(|| {
DataFusionError::Internal(format!(
"array_intersect: failed to get value from rows"
))
})?
.clone(),
);
values_set.clear();
}
let mut values_set = HashSet::with_capacity(l_values.num_rows());
for l_val in l_values.iter() {
values_set.insert(l_val);
}
_ => {
todo!()
let mut rows = Vec::with_capacity(r_values.num_rows());
for r_val in r_values.iter().sorted().dedup() {
if values_set.contains(&r_val) {
rows.push(r_val);
}
}

let last_offset: i32 = offsets.last().copied().ok_or_else(|| {
DataFusionError::Internal(format!("offsets should not be empty"))
})?;
offsets.push(last_offset + rows.len() as i32);
let tmp_value = converter.convert_rows(rows)?;
tmp_values.push(
tmp_value
.get(0)
.ok_or_else(|| {
DataFusionError::Internal(format!(
"array_intersect: failed to get value from rows"
))
})?
.clone(),
);
}
}

let field = Arc::new(Field::new("item_list", dt, true));
let field = Arc::new(Field::new("item", dt, true));
let offsets = OffsetBuffer::new(offsets.into());
let tmp_values_ref = tmp_values.iter().map(|v| v.as_ref()).collect::<Vec<_>>();
let values = concat(&tmp_values_ref)?;
Expand Down

0 comments on commit 31dbcb1

Please sign in to comment.