Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support LargeList in array_position #8714

Merged
merged 1 commit into from
Jan 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1367,8 +1367,14 @@ pub fn array_position(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() < 2 || args.len() > 3 {
return exec_err!("array_position expects two or three arguments");
}

let list_array = as_list_array(&args[0])?;
match &args[0].data_type() {
DataType::List(_) => general_position_dispatch::<i32>(args),
DataType::LargeList(_) => general_position_dispatch::<i64>(args),
array_type => exec_err!("array_position does not support type '{array_type:?}'."),
}
}
fn general_position_dispatch<O: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
let list_array = as_generic_list_array::<O>(&args[0])?;
let element_array = &args[1];

check_datatypes("array_position", &[list_array.values(), element_array])?;
Expand All @@ -1395,10 +1401,10 @@ pub fn array_position(args: &[ArrayRef]) -> Result<ArrayRef> {
}
}

general_position::<i32>(list_array, element_array, arr_from)
generic_position::<O>(list_array, element_array, arr_from)
}

fn general_position<OffsetSize: OffsetSizeTrait>(
fn generic_position<OffsetSize: OffsetSizeTrait>(
list_array: &GenericListArray<OffsetSize>,
element_array: &ArrayRef,
arr_from: Vec<i64>, // 0-indexed
Expand Down
71 changes: 71 additions & 0 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,17 @@ AS VALUES
(make_array(31, 32, 33, 34, 35, 26, 37, 38, 39, 40), 34, 4, 'ok', [8,9])
;

statement ok
CREATE TABLE large_arrays_values_without_nulls
AS SELECT
arrow_cast(column1, 'LargeList(Int64)') AS column1,
column2,
column3,
column4,
arrow_cast(column5, 'LargeList(Int64)') AS column5
FROM arrays_values_without_nulls
;

statement ok
CREATE TABLE arrays_range
AS VALUES
Expand Down Expand Up @@ -2054,12 +2065,22 @@ select array_position(['h', 'e', 'l', 'l', 'o'], 'l'), array_position([1, 2, 3,
----
3 5 1

query III
select array_position(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l'), array_position(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5), array_position(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1);
----
3 5 1

# array_position scalar function #2 (with optional argument)
query III
select array_position(['h', 'e', 'l', 'l', 'o'], 'l', 4), array_position([1, 2, 5, 4, 5], 5, 4), array_position([1, 1, 1], 1, 2);
----
4 5 2

query III
select array_position(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l', 4), array_position(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5, 4), array_position(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1, 2);
----
4 5 2

# array_position scalar function #3 (element is list)
query II
select array_position(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), [4, 5, 6]), array_position(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), [2, 3, 4]);
Expand All @@ -2072,24 +2093,44 @@ select array_position(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7,
----
4 3

query II
select array_position(arrow_cast(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), 'LargeList(List(Int64))'), [4, 5, 6]), array_position(arrow_cast(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), 'LargeList(List(Int64))'), [2, 3, 4]);
----
2 2

# list_position scalar function #5 (function alias `array_position`)
query III
select list_position(['h', 'e', 'l', 'l', 'o'], 'l'), list_position([1, 2, 3, 4, 5], 5), list_position([1, 1, 1], 1);
----
3 5 1

query III
select list_position(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l'), list_position(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5), list_position(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1);
----
3 5 1

# array_indexof scalar function #6 (function alias `array_position`)
query III
select array_indexof(['h', 'e', 'l', 'l', 'o'], 'l'), array_indexof([1, 2, 3, 4, 5], 5), array_indexof([1, 1, 1], 1);
----
3 5 1

query III
select array_indexof(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l'), array_indexof(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5), array_indexof(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1);
----
3 5 1

# list_indexof scalar function #7 (function alias `array_position`)
query III
select list_indexof(['h', 'e', 'l', 'l', 'o'], 'l'), list_indexof([1, 2, 3, 4, 5], 5), list_indexof([1, 1, 1], 1);
----
3 5 1

query III
select list_indexof(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l'), list_indexof(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5), list_indexof(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1);
----
3 5 1

# array_position with columns #1
query II
select array_position(column1, column2), array_position(column1, column2, column3) from arrays_values_without_nulls;
Expand All @@ -2099,13 +2140,28 @@ select array_position(column1, column2), array_position(column1, column2, column
3 3
4 4

query II
select array_position(column1, column2), array_position(column1, column2, column3) from large_arrays_values_without_nulls;
----
1 1
2 2
3 3
4 4

# array_position with columns #2 (element is list)
query II
select array_position(column1, column2), array_position(column1, column2, column3) from nested_arrays;
----
3 3
2 5

#TODO: add this test when #8305 is fixed
#query II
#select array_position(column1, column2), array_position(column1, column2, column3) from nested_arrays;
#----
#3 3
#2 5

# array_position with columns and scalars #1
query III
select array_position(make_array(1, 2, 3, 4, 5), column2), array_position(column1, 3), array_position(column1, 3, 5) from arrays_values_without_nulls;
Expand All @@ -2115,13 +2171,28 @@ NULL NULL NULL
NULL NULL NULL
NULL NULL NULL

query III
select array_position(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), column2), array_position(column1, 3), array_position(column1, 3, 5) from large_arrays_values_without_nulls;
----
1 3 NULL
NULL NULL NULL
NULL NULL NULL
NULL NULL NULL

# array_position with columns and scalars #2 (element is list)
query III
select array_position(make_array([1, 2, 3], [4, 5, 6], [11, 12, 13]), column2), array_position(column1, make_array(4, 5, 6)), array_position(column1, make_array(1, 2, 3), 2) from nested_arrays;
----
NULL 6 4
NULL 1 NULL

#TODO: add this test when #8305 is fixed
#query III
#select array_position(arrow_cast(make_array([1, 2, 3], [4, 5, 6], [11, 12, 13]), 'LargeList(List(Int64))'), column2), array_position(column1, make_array(4, 5, 6)), array_position(column1, make_array(1, 2, 3), 2) from large_nested_arrays;
#----
#NULL 6 4
#NULL 1 NULL

## array_positions (aliases: `list_positions`)

# array_positions scalar function #1
Expand Down