Skip to content

Commit 8012c4d

Browse files
authored
Column support for array concat (#6879)
* first draft Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * use old concat func Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * merge main and add tests Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * support nulls Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * add tests Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * cleanup Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * add more failed tests Signed-off-by: jayzhan211 <jayzhan211@gmail.com> * update tests Signed-off-by: jayzhan211 <jayzhan211@gmail.com> --------- Signed-off-by: jayzhan211 <jayzhan211@gmail.com>
1 parent 8a1e526 commit 8012c4d

File tree

2 files changed

+216
-28
lines changed

2 files changed

+216
-28
lines changed

datafusion/core/tests/sqllogictests/test_files/array.slt

Lines changed: 164 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,17 @@ AS VALUES
6868
(make_array(61, 62, 63, 64, 65, 66, 67, 68, 69, 70), 66, 7, NULL)
6969
;
7070

71+
statement ok
72+
CREATE TABLE arrays_values_v2
73+
AS VALUES
74+
(make_array(NULL, 2, 3), make_array(4, 5, NULL), 12, make_array([30, 40, 50])),
75+
(NULL, make_array(7, NULL, 8), 13, make_array(make_array(NULL,NULL,60))),
76+
(make_array(9, NULL, 10), NULL, 14, make_array(make_array(70,NULL,NULL))),
77+
(make_array(NULL, 1), make_array(NULL, 21), NULL, NULL),
78+
(make_array(11, 12), NULL, NULL, NULL),
79+
(NULL, NULL, NULL, NULL)
80+
;
81+
7182
statement ok
7283
CREATE TABLE arrays_values_without_nulls
7384
AS VALUES
@@ -116,6 +127,16 @@ NULL 44 5 @
116127
[51, 52, , 54, 55, 56, 57, 58, 59, 60] 55 NULL ^
117128
[61, 62, 63, 64, 65, 66, 67, 68, 69, 70] 66 7 NULL
118129

130+
query ??I?
131+
select column1, column2, column3, column4 from arrays_values_v2;
132+
----
133+
[, 2, 3] [4, 5, ] 12 [[30, 40, 50]]
134+
NULL [7, , 8] 13 [[, , 60]]
135+
[9, , 10] NULL 14 [[70, , ]]
136+
[, 1] [, 21] NULL NULL
137+
[11, 12] NULL NULL NULL
138+
NULL NULL NULL NULL
139+
119140
# arrays_values_without_nulls table
120141
query ?II
121142
select column1, column2, column3 from arrays_values_without_nulls;
@@ -423,6 +444,148 @@ select array_concat(make_array(10, 20), make_array([30, 40]), make_array([[50, 6
423444
----
424445
[[[10, 20]], [[30, 40]], [[50, 60]]]
425446

447+
# array_concat column-wise #1
448+
query ?
449+
select array_concat(column1, make_array(0)) from arrays_values_without_nulls;
450+
----
451+
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0]
452+
[11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 0]
453+
[21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 0]
454+
[31, 32, 33, 34, 35, 26, 37, 38, 39, 40, 0]
455+
456+
# array_concat column-wise #2
457+
query ?
458+
select array_concat(column1, column1) from arrays_values_without_nulls;
459+
----
460+
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
461+
[11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
462+
[21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30]
463+
[31, 32, 33, 34, 35, 26, 37, 38, 39, 40, 31, 32, 33, 34, 35, 26, 37, 38, 39, 40]
464+
465+
# array_concat column-wise #3
466+
query ?
467+
select array_concat(make_array(column2), make_array(column3)) from arrays_values_without_nulls;
468+
----
469+
[1, 1]
470+
[12, 2]
471+
[23, 3]
472+
[34, 4]
473+
474+
# array_concat column-wise #4
475+
query ?
476+
select array_concat(column1, column2) from arrays_values;
477+
----
478+
[, 2, 3, 4, 5, 6, 7, 8, 9, 10, 1]
479+
[11, 12, 13, 14, 15, 16, 17, 18, , 20, 12]
480+
[21, 22, 23, , 25, 26, 27, 28, 29, 30, 23]
481+
[31, 32, 33, 34, 35, , 37, 38, 39, 40, 34]
482+
[44]
483+
[41, 42, 43, 44, 45, 46, 47, 48, 49, 50, ]
484+
[51, 52, , 54, 55, 56, 57, 58, 59, 60, 55]
485+
[61, 62, 63, 64, 65, 66, 67, 68, 69, 70, 66]
486+
487+
# array_concat column-wise #5
488+
query ?
489+
select array_concat(make_array(column2), make_array(0)) from arrays_values;
490+
----
491+
[1, 0]
492+
[12, 0]
493+
[23, 0]
494+
[34, 0]
495+
[44, 0]
496+
[, 0]
497+
[55, 0]
498+
[66, 0]
499+
500+
# array_concat column-wise #6
501+
query ???
502+
select array_concat(column1, column1), array_concat(column2, column2), array_concat(column3, column3) from arrays;
503+
----
504+
[[, 2], [3, ], [, 2], [3, ]] [1.1, 2.2, 3.3, 1.1, 2.2, 3.3] [L, o, r, e, m, L, o, r, e, m]
505+
[[3, 4], [5, 6], [3, 4], [5, 6]] [, 5.5, 6.6, , 5.5, 6.6] [i, p, , u, m, i, p, , u, m]
506+
[[5, 6], [7, 8], [5, 6], [7, 8]] [7.7, 8.8, 9.9, 7.7, 8.8, 9.9] [d, , l, o, r, d, , l, o, r]
507+
[[7, ], [9, 10], [7, ], [9, 10]] [10.1, , 12.2, 10.1, , 12.2] [s, i, t, s, i, t]
508+
NULL [13.3, 14.4, 15.5, 13.3, 14.4, 15.5] [a, m, e, t, a, m, e, t]
509+
[[11, 12], [13, 14], [11, 12], [13, 14]] NULL [,, ,]
510+
[[15, 16], [, 18], [15, 16], [, 18]] [16.6, 17.7, 18.8, 16.6, 17.7, 18.8] NULL
511+
512+
# array_concat column-wise #7
513+
query ??
514+
select array_concat(column1, make_array(make_array(1, 2), make_array(3, 4))), array_concat(column2, make_array(1.1, 2.2, 3.3)) from arrays;
515+
----
516+
[[, 2], [3, ], [1, 2], [3, 4]] [1.1, 2.2, 3.3, 1.1, 2.2, 3.3]
517+
[[3, 4], [5, 6], [1, 2], [3, 4]] [, 5.5, 6.6, 1.1, 2.2, 3.3]
518+
[[5, 6], [7, 8], [1, 2], [3, 4]] [7.7, 8.8, 9.9, 1.1, 2.2, 3.3]
519+
[[7, ], [9, 10], [1, 2], [3, 4]] [10.1, , 12.2, 1.1, 2.2, 3.3]
520+
[[1, 2], [3, 4]] [13.3, 14.4, 15.5, 1.1, 2.2, 3.3]
521+
[[11, 12], [13, 14], [1, 2], [3, 4]] [1.1, 2.2, 3.3]
522+
[[15, 16], [, 18], [1, 2], [3, 4]] [16.6, 17.7, 18.8, 1.1, 2.2, 3.3]
523+
524+
# array_concat column-wise #8
525+
query ?
526+
select array_concat(column3, make_array('.', '.', '.')) from arrays;
527+
----
528+
[L, o, r, e, m, ., ., .]
529+
[i, p, , u, m, ., ., .]
530+
[d, , l, o, r, ., ., .]
531+
[s, i, t, ., ., .]
532+
[a, m, e, t, ., ., .]
533+
[,, ., ., .]
534+
[., ., .]
535+
536+
# query ??I?
537+
# select column1, column2, column3, column4 from arrays_values_v2;
538+
# ----
539+
# [, 2, 3] [4, 5, ] 12 [[30, 40, 50]]
540+
# NULL [7, , 8] 13 [[, , 60]]
541+
# [9, , 10] NULL 14 [[70, , ]]
542+
# [, 1] [, 21] NULL NULL
543+
# [11, 12] NULL NULL NULL
544+
# NULL NULL NULL NULL
545+
546+
# array_concat column-wise #9 (1D + 1D)
547+
query ?
548+
select array_concat(column1, column2) from arrays_values_v2;
549+
----
550+
[, 2, 3, 4, 5, ]
551+
[7, , 8]
552+
[9, , 10]
553+
[, 1, , 21]
554+
[11, 12]
555+
NULL
556+
557+
# TODO: Concat columns with different dimensions fails
558+
# array_concat column-wise #10 (1D + 2D)
559+
# query error DataFusion error: Arrow error: Invalid argument error: column types must match schema types, expected List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) but found List\(Field \{ name: "item", data_type: List\(Field \{ name: "item", data_type: Int64, nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\), nullable: true, dict_id: 0, dict_is_ordered: false, metadata: \{\} \}\) at column index 0
560+
# select array_concat(make_array(column3), column4) from arrays_values_v2;
561+
562+
# array_concat column-wise #11 (1D + Integers)
563+
query ?
564+
select array_concat(column2, column3) from arrays_values_v2;
565+
----
566+
[4, 5, , 12]
567+
[7, , 8, 13]
568+
[14]
569+
[, 21, ]
570+
[]
571+
[]
572+
573+
# TODO: Panic at 'range end index 3 out of range for slice of length 2'
574+
# array_concat column-wise #12 (2D + 1D)
575+
# query
576+
# select array_concat(column4, column1) from arrays_values_v2;
577+
578+
# array_concat column-wise #13 (1D + 1D + 1D)
579+
query ?
580+
select array_concat(make_array(column3), column1, column2) from arrays_values_v2;
581+
----
582+
[12, , 2, 3, 4, 5, ]
583+
[13, 7, , 8]
584+
[14, 9, , 10]
585+
[, , 1, , 21]
586+
[, 11, 12]
587+
[]
588+
426589
## array_position
427590

428591
# array_position scalar function #1
@@ -835,6 +998,7 @@ select make_array(f0) from fixed_size_list_array
835998
[[1, 2], [3, 4]]
836999

8371000

1001+
8381002
### Delete tables
8391003

8401004

datafusion/physical-expr/src/array_expressions.rs

Lines changed: 52 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ use arrow::array::*;
2121
use arrow::buffer::{Buffer, OffsetBuffer};
2222
use arrow::compute;
2323
use arrow::datatypes::{DataType, Field, UInt64Type};
24+
use arrow_buffer::NullBuffer;
2425
use core::any::type_name;
2526
use datafusion_common::cast::{as_generic_string_array, as_int64_array, as_list_array};
2627
use datafusion_common::ScalarValue;
@@ -554,42 +555,65 @@ fn align_array_dimensions(args: Vec<ArrayRef>) -> Result<Vec<ArrayRef>> {
554555
aligned_args
555556
}
556557

557-
/// Array_concat/Array_cat SQL function
558-
pub fn array_concat(args: &[ArrayRef]) -> Result<ArrayRef> {
559-
match args[0].data_type() {
560-
DataType::List(field) => match field.data_type() {
561-
DataType::Null => array_concat(&args[1..]),
562-
_ => {
563-
let args = align_array_dimensions(args.to_vec())?;
558+
fn concat_internal(args: &[ArrayRef]) -> Result<ArrayRef> {
559+
let args = align_array_dimensions(args.to_vec())?;
564560

565-
let list_arrays = downcast_vec!(args, ListArray)
566-
.collect::<Result<Vec<&ListArray>>>()?;
561+
let list_arrays =
562+
downcast_vec!(args, ListArray).collect::<Result<Vec<&ListArray>>>()?;
567563

568-
let len: usize = list_arrays.iter().map(|a| a.values().len()).sum();
564+
// Assume number of rows is the same for all arrays
565+
let row_count = list_arrays[0].len();
566+
let capacity = Capacities::Array(list_arrays.iter().map(|a| a.len()).sum());
567+
let array_data: Vec<_> = list_arrays.iter().map(|a| a.to_data()).collect::<Vec<_>>();
568+
let array_data: Vec<&ArrayData> = array_data.iter().collect();
569569

570-
let capacity =
571-
Capacities::Array(list_arrays.iter().map(|a| a.len()).sum());
572-
let array_data: Vec<_> =
573-
list_arrays.iter().map(|a| a.to_data()).collect::<Vec<_>>();
570+
let mut mutable = MutableArrayData::with_capacities(array_data, true, capacity);
574571

575-
let array_data = array_data.iter().collect();
572+
let mut array_lens = vec![0; row_count];
573+
let mut null_bit_map: Vec<bool> = vec![true; row_count];
576574

577-
let mut mutable =
578-
MutableArrayData::with_capacities(array_data, false, capacity);
575+
for (i, array_len) in array_lens.iter_mut().enumerate().take(row_count) {
576+
let null_count = mutable.null_count();
577+
for (j, a) in list_arrays.iter().enumerate() {
578+
mutable.extend(j, i, i + 1);
579+
*array_len += a.value_length(i);
580+
}
579581

580-
for (i, a) in list_arrays.iter().enumerate() {
581-
mutable.extend(i, 0, a.len())
582-
}
582+
// This means all arrays are null
583+
if mutable.null_count() == null_count + list_arrays.len() {
584+
null_bit_map[i] = false;
585+
}
586+
}
583587

584-
let builder = mutable.into_builder();
585-
let list = builder
586-
.len(1)
587-
.buffers(vec![Buffer::from_slice_ref([0, len as i32])])
588-
.build()
589-
.unwrap();
588+
let mut buffer = BooleanBufferBuilder::new(row_count);
589+
buffer.append_slice(null_bit_map.as_slice());
590+
let nulls = Some(NullBuffer::from(buffer.finish()));
590591

591-
return Ok(Arc::new(arrow::array::make_array(list)));
592-
}
592+
let offsets: Vec<i32> = std::iter::once(0)
593+
.chain(array_lens.iter().scan(0, |state, &x| {
594+
*state += x;
595+
Some(*state)
596+
}))
597+
.collect();
598+
599+
let builder = mutable.into_builder();
600+
601+
let list = builder
602+
.len(row_count)
603+
.buffers(vec![Buffer::from_vec(offsets)])
604+
.nulls(nulls)
605+
.build()?;
606+
607+
let list = arrow::array::make_array(list);
608+
Ok(Arc::new(list))
609+
}
610+
611+
/// Array_concat/Array_cat SQL function
612+
pub fn array_concat(args: &[ArrayRef]) -> Result<ArrayRef> {
613+
match args[0].data_type() {
614+
DataType::List(field) => match field.data_type() {
615+
DataType::Null => array_concat(&args[1..]),
616+
_ => concat_internal(args),
593617
},
594618
data_type => Err(DataFusionError::NotImplemented(format!(
595619
"Array is not type '{data_type:?}'."

0 commit comments

Comments
 (0)