Skip to content

Commit

Permalink
support LargeList in array_position (#8714)
Browse files Browse the repository at this point in the history
  • Loading branch information
Weijun-H authored Jan 2, 2024
1 parent 94aff55 commit d4b96a8
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 4 deletions.
14 changes: 10 additions & 4 deletions datafusion/physical-expr/src/array_expressions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1367,8 +1367,14 @@ pub fn array_position(args: &[ArrayRef]) -> Result<ArrayRef> {
if args.len() < 2 || args.len() > 3 {
return exec_err!("array_position expects two or three arguments");
}

let list_array = as_list_array(&args[0])?;
match &args[0].data_type() {
DataType::List(_) => general_position_dispatch::<i32>(args),
DataType::LargeList(_) => general_position_dispatch::<i64>(args),
array_type => exec_err!("array_position does not support type '{array_type:?}'."),
}
}
fn general_position_dispatch<O: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
let list_array = as_generic_list_array::<O>(&args[0])?;
let element_array = &args[1];

check_datatypes("array_position", &[list_array.values(), element_array])?;
Expand All @@ -1395,10 +1401,10 @@ pub fn array_position(args: &[ArrayRef]) -> Result<ArrayRef> {
}
}

general_position::<i32>(list_array, element_array, arr_from)
generic_position::<O>(list_array, element_array, arr_from)
}

fn general_position<OffsetSize: OffsetSizeTrait>(
fn generic_position<OffsetSize: OffsetSizeTrait>(
list_array: &GenericListArray<OffsetSize>,
element_array: &ArrayRef,
arr_from: Vec<i64>, // 0-indexed
Expand Down
71 changes: 71 additions & 0 deletions datafusion/sqllogictest/test_files/array.slt
Original file line number Diff line number Diff line change
Expand Up @@ -363,6 +363,17 @@ AS VALUES
(make_array(31, 32, 33, 34, 35, 26, 37, 38, 39, 40), 34, 4, 'ok', [8,9])
;

statement ok
CREATE TABLE large_arrays_values_without_nulls
AS SELECT
arrow_cast(column1, 'LargeList(Int64)') AS column1,
column2,
column3,
column4,
arrow_cast(column5, 'LargeList(Int64)') AS column5
FROM arrays_values_without_nulls
;

statement ok
CREATE TABLE arrays_range
AS VALUES
Expand Down Expand Up @@ -2054,12 +2065,22 @@ select array_position(['h', 'e', 'l', 'l', 'o'], 'l'), array_position([1, 2, 3,
----
3 5 1

query III
select array_position(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l'), array_position(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5), array_position(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1);
----
3 5 1

# array_position scalar function #2 (with optional argument)
query III
select array_position(['h', 'e', 'l', 'l', 'o'], 'l', 4), array_position([1, 2, 5, 4, 5], 5, 4), array_position([1, 1, 1], 1, 2);
----
4 5 2

query III
select array_position(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l', 4), array_position(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5, 4), array_position(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1, 2);
----
4 5 2

# array_position scalar function #3 (element is list)
query II
select array_position(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), [4, 5, 6]), array_position(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), [2, 3, 4]);
Expand All @@ -2072,24 +2093,44 @@ select array_position(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7,
----
4 3

query II
select array_position(arrow_cast(make_array([1, 2, 3], [4, 5, 6], [5, 5, 5], [4, 5, 6], [7, 8, 9]), 'LargeList(List(Int64))'), [4, 5, 6]), array_position(arrow_cast(make_array([1, 3, 2], [2, 3, 4], [2, 3, 4], [5, 3, 1], [1, 3, 2]), 'LargeList(List(Int64))'), [2, 3, 4]);
----
2 2

# list_position scalar function #5 (function alias `array_position`)
query III
select list_position(['h', 'e', 'l', 'l', 'o'], 'l'), list_position([1, 2, 3, 4, 5], 5), list_position([1, 1, 1], 1);
----
3 5 1

query III
select list_position(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l'), list_position(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5), list_position(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1);
----
3 5 1

# array_indexof scalar function #6 (function alias `array_position`)
query III
select array_indexof(['h', 'e', 'l', 'l', 'o'], 'l'), array_indexof([1, 2, 3, 4, 5], 5), array_indexof([1, 1, 1], 1);
----
3 5 1

query III
select array_indexof(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l'), array_indexof(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5), array_indexof(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1);
----
3 5 1

# list_indexof scalar function #7 (function alias `array_position`)
query III
select list_indexof(['h', 'e', 'l', 'l', 'o'], 'l'), list_indexof([1, 2, 3, 4, 5], 5), list_indexof([1, 1, 1], 1);
----
3 5 1

query III
select list_indexof(arrow_cast(['h', 'e', 'l', 'l', 'o'], 'LargeList(Utf8)'), 'l'), list_indexof(arrow_cast([1, 2, 3, 4, 5], 'LargeList(Int64)'), 5), list_indexof(arrow_cast([1, 1, 1], 'LargeList(Int64)'), 1);
----
3 5 1

# array_position with columns #1
query II
select array_position(column1, column2), array_position(column1, column2, column3) from arrays_values_without_nulls;
Expand All @@ -2099,13 +2140,28 @@ select array_position(column1, column2), array_position(column1, column2, column
3 3
4 4

query II
select array_position(column1, column2), array_position(column1, column2, column3) from large_arrays_values_without_nulls;
----
1 1
2 2
3 3
4 4

# array_position with columns #2 (element is list)
query II
select array_position(column1, column2), array_position(column1, column2, column3) from nested_arrays;
----
3 3
2 5

#TODO: add this test when #8305 is fixed
#query II
#select array_position(column1, column2), array_position(column1, column2, column3) from nested_arrays;
#----
#3 3
#2 5

# array_position with columns and scalars #1
query III
select array_position(make_array(1, 2, 3, 4, 5), column2), array_position(column1, 3), array_position(column1, 3, 5) from arrays_values_without_nulls;
Expand All @@ -2115,13 +2171,28 @@ NULL NULL NULL
NULL NULL NULL
NULL NULL NULL

query III
select array_position(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), column2), array_position(column1, 3), array_position(column1, 3, 5) from large_arrays_values_without_nulls;
----
1 3 NULL
NULL NULL NULL
NULL NULL NULL
NULL NULL NULL

# array_position with columns and scalars #2 (element is list)
query III
select array_position(make_array([1, 2, 3], [4, 5, 6], [11, 12, 13]), column2), array_position(column1, make_array(4, 5, 6)), array_position(column1, make_array(1, 2, 3), 2) from nested_arrays;
----
NULL 6 4
NULL 1 NULL

#TODO: add this test when #8305 is fixed
#query III
#select array_position(arrow_cast(make_array([1, 2, 3], [4, 5, 6], [11, 12, 13]), 'LargeList(List(Int64))'), column2), array_position(column1, make_array(4, 5, 6)), array_position(column1, make_array(1, 2, 3), 2) from large_nested_arrays;
#----
#NULL 6 4
#NULL 1 NULL

## array_positions (aliases: `list_positions`)

# array_positions scalar function #1
Expand Down

0 comments on commit d4b96a8

Please sign in to comment.