From d3ef6c88d886fe1d3662c316d614b9c9c5c2e4f1 Mon Sep 17 00:00:00 2001 From: parkma99 <84610851+parkma99@users.noreply.github.com> Date: Mon, 12 Jun 2023 09:17:10 +0800 Subject: [PATCH] feat: make_array support empty arguments (#6593) * feat: make_array support empty arguments * fix fmt error * fix error * array_append support empty array * array_prepend support empty make_array * array_concat support empty make_array * fix clippy * update * fix * rename `array_make` --> `make_array` --------- Co-authored-by: Andrew Lamb --- .../tests/sqllogictests/test_files/array.slt | 102 +++++++++++++++++- datafusion/expr/src/built_in_function.rs | 14 ++- .../physical-expr/src/array_expressions.rs | 82 +++++++++----- datafusion/physical-expr/src/functions.rs | 2 +- 4 files changed, 166 insertions(+), 34 deletions(-) diff --git a/datafusion/core/tests/sqllogictests/test_files/array.slt b/datafusion/core/tests/sqllogictests/test_files/array.slt index df9edce0b1df..183522138044 100644 --- a/datafusion/core/tests/sqllogictests/test_files/array.slt +++ b/datafusion/core/tests/sqllogictests/test_files/array.slt @@ -49,13 +49,49 @@ select make_array(make_array(make_array(make_array(1, 2, 3), make_array(4, 5, 6) ---- [[[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]] -# array_append scalar function +# array scalar function #6 +query ? rowsort +select make_array() +---- +[] + +# array scalar function #7 +query ?? rowsort +select make_array(make_array()), make_array(make_array(make_array())) +---- +[[]] [[[]]] + +# array_append scalar function #1 +query ? rowsort +select array_append(make_array(), 4); +---- +[4] + +# array_append scalar function #2 +query ?? rowsort +select array_append(make_array(), make_array()), array_append(make_array(), make_array(4)); +---- +[[]] [[4]] + +# array_append scalar function #3 query ??? rowsort select array_append(make_array(1, 2, 3), 4), array_append(make_array(1.0, 2.0, 3.0), 4.0), array_append(make_array('h', 'e', 'l', 'l'), 'o'); ---- [1, 2, 3, 4] [1.0, 2.0, 3.0, 4.0] [h, e, l, l, o] -# array_prepend scalar function +# array_prepend scalar function #1 +query ? rowsort +select array_prepend(4, make_array()); +---- +[4] + +# array_prepend scalar function #2 +query ?? rowsort +select array_prepend(make_array(), make_array()), array_prepend(make_array(4), make_array()); +---- +[[]] [[4]] + +# array_prepend scalar function #3 query ??? rowsort select array_prepend(1, make_array(2, 3, 4)), array_prepend(1.0, make_array(2.0, 3.0, 4.0)), array_prepend('h', make_array('e', 'l', 'l', 'o')); ---- @@ -73,6 +109,12 @@ select array_fill(1, make_array(1, 1, 1)), array_fill(2, make_array(2, 2, 2, 2, ---- [[[1]]] [[[[[2, 2], [2, 2]], [[2, 2], [2, 2]]], [[[2, 2], [2, 2]], [[2, 2], [2, 2]]]], [[[[2, 2], [2, 2]], [[2, 2], [2, 2]]], [[[2, 2], [2, 2]], [[2, 2], [2, 2]]]]] +# array_fill scalar function #3 +query ? +select array_fill(1, make_array()) +---- +[] + # array_concat scalar function #1 query ?? rowsort select array_concat(make_array(1, 2, 3), make_array(4, 5, 6), make_array(7, 8, 9)), array_concat(make_array([1], [2]), make_array([3], [4])); @@ -97,6 +139,18 @@ select array_concat(make_array([[1]]), make_array([[2]])); ---- [[[1]], [[2]]] +# array_concat scalar function #5 +query ? rowsort +select array_concat(make_array(2, 3), make_array()); +---- +[2, 3] + +# array_concat scalar function #6 +query ? rowsort +select array_concat(make_array(), make_array(2, 3)); +---- +[2, 3] + # array_position scalar function #1 query III select array_position(['h', 'e', 'l', 'l', 'o'], 'l'), array_position([1, 2, 3, 4, 5], 5), array_position([1, 1, 1], 1); @@ -133,6 +187,12 @@ select array_to_string([1, 1, 1], '1'), array_to_string([[1, 2], [3, 4], [5, 6]] ---- 11111 1+2+3+4+5+6 3/\3/\3/\3/\3/\3/\3/\3/\3/\3/\3/\3 +# array_to_string scalar function #3 +query ? +select array_to_string(make_array(), ',') +---- +(empty) + # cardinality scalar function query III select cardinality(make_array(1, 2, 3, 4, 5)), cardinality([1, 3, 5]), cardinality(make_array('h', 'e', 'l', 'l', 'o')); @@ -145,7 +205,13 @@ select cardinality(make_array([1, 2], [3, 4], [5, 6])), cardinality(array_fill(3 ---- 6 18 -# trim_array scalar function +# cardinality scalar function #3 +query II +select cardinality(make_array()), cardinality(make_array(make_array())) +---- +0 0 + +# trim_array scalar function #1 query ??? select trim_array(make_array(1, 2, 3, 4, 5), 2), trim_array(['h', 'e', 'l', 'l', 'o'], 3), trim_array([1.0, 2.0, 3.0], 2); ---- @@ -157,6 +223,18 @@ select trim_array([[1, 2], [3, 4], [5, 6]], 2), trim_array(array_fill(4, [3, 4, ---- [[1, 2]] [[[4, 4], [4, 4], [4, 4], [4, 4]]] +# trim_array scalar function #3 +query ? +select array_concat(trim_array(make_array(1, 2, 3), 3), make_array(4, 5), make_array()); +---- +[4, 5] + +# trim_array scalar function #4 +query ?? +select trim_array(make_array(), 0), trim_array(make_array(), 1) +---- +[] [] + # array_length scalar function query III rowsort select array_length(make_array(1, 2, 3, 4, 5)), array_length(make_array(1, 2, 3)), array_length(make_array([1, 2], [3, 4], [5, 6])); @@ -181,6 +259,12 @@ select array_length(array_fill(3, [3, 2, 5]), 1), array_length(array_fill(3, [3, ---- 3 2 5 NULL +# array_length scalar function #5 +query III rowsort +select array_length(make_array()), array_length(make_array(), 1), array_length(make_array(), 2) +---- +0 0 NULL + # array_dims scalar function query III rowsort select array_dims(make_array(1, 2, 3)), array_dims(make_array([1, 2], [3, 4])), array_dims(make_array([[[[1], [2]]]])); @@ -193,6 +277,12 @@ select array_dims(array_fill(2, [1, 2, 3])), array_dims(array_fill(3, [2, 5, 4]) ---- [1, 2, 3] [2, 5, 4] +# array_dims scalar function #3 +query II rowsort +select array_dims(make_array()), array_dims(make_array(make_array())) +---- +[0] [1, 0] + # array_ndims scalar function query III rowsort select array_ndims(make_array(1, 2, 3)), array_ndims(make_array([1, 2], [3, 4])), array_ndims(make_array([[[[1], [2]]]])); @@ -204,3 +294,9 @@ query II rowsort select array_ndims(array_fill(1, [1, 2, 3])), array_ndims([[[[[[[[[[[[[[[[[[[[[1]]]]]]]]]]]]]]]]]]]]]); ---- 3 21 + +# array_ndims scalar function #3 +query II rowsort +select array_ndims(make_array()), array_ndims(make_array(make_array())) +---- +1 2 diff --git a/datafusion/expr/src/built_in_function.rs b/datafusion/expr/src/built_in_function.rs index 0f229059d05c..b2cf075fa57f 100644 --- a/datafusion/expr/src/built_in_function.rs +++ b/datafusion/expr/src/built_in_function.rs @@ -277,6 +277,7 @@ impl BuiltinScalarFunction { | BuiltinScalarFunction::CurrentDate | BuiltinScalarFunction::CurrentTime | BuiltinScalarFunction::Uuid + | BuiltinScalarFunction::MakeArray ) } /// Returns the [Volatility] of the builtin function. @@ -510,11 +511,14 @@ impl BuiltinScalarFunction { ))), }, BuiltinScalarFunction::Cardinality => Ok(UInt64), - BuiltinScalarFunction::MakeArray => Ok(List(Arc::new(Field::new( - "item", - input_expr_types[0].clone(), - true, - )))), + BuiltinScalarFunction::MakeArray => match input_expr_types.len() { + 0 => Ok(List(Arc::new(Field::new("item", Null, true)))), + _ => Ok(List(Arc::new(Field::new( + "item", + input_expr_types[0].clone(), + true, + )))), + }, BuiltinScalarFunction::TrimArray => match &input_expr_types[0] { List(field) => Ok(List(Arc::new(Field::new( "item", diff --git a/datafusion/physical-expr/src/array_expressions.rs b/datafusion/physical-expr/src/array_expressions.rs index 631ca376fc05..44b747082a5c 100644 --- a/datafusion/physical-expr/src/array_expressions.rs +++ b/datafusion/physical-expr/src/array_expressions.rs @@ -150,6 +150,15 @@ pub fn array(values: &[ColumnarValue]) -> Result { Ok(ColumnarValue::Array(array_array(arrays.as_slice())?)) } +/// `make_array` SQL function +pub fn make_array(values: &[ColumnarValue]) -> Result { + match values[0].data_type() { + DataType::Null => Ok(datafusion_expr::ColumnarValue::Scalar( + ScalarValue::new_list(Some(vec![]), DataType::Null), + )), + _ => array(values), + } +} macro_rules! downcast_arg { ($ARG:expr, $ARRAY_TYPE:ident) => {{ $ARG.as_any().downcast_ref::<$ARRAY_TYPE>().ok_or_else(|| { @@ -217,6 +226,7 @@ pub fn array_append(args: &[ColumnarValue]) -> Result { (DataType::UInt16, DataType::UInt16) => append!(arr, element, UInt16Array), (DataType::UInt32, DataType::UInt32) => append!(arr, element, UInt32Array), (DataType::UInt64, DataType::UInt64) => append!(arr, element, UInt64Array), + (DataType::Null, _) => return array(&args[1..]), (array_data_type, element_data_type) => { return Err(DataFusionError::NotImplemented(format!( "Array_append is not implemented for types '{array_data_type:?}' and '{element_data_type:?}'." @@ -290,6 +300,7 @@ pub fn array_prepend(args: &[ColumnarValue]) -> Result { (DataType::UInt16, DataType::UInt16) => prepend!(arr, element, UInt16Array), (DataType::UInt32, DataType::UInt32) => prepend!(arr, element, UInt32Array), (DataType::UInt64, DataType::UInt64) => prepend!(arr, element, UInt64Array), + (DataType::Null, _) => return array(&args[..1]), (array_data_type, element_data_type) => { return Err(DataFusionError::NotImplemented(format!( "Array_prepend is not implemented for types '{array_data_type:?}' and '{element_data_type:?}'." @@ -318,30 +329,36 @@ pub fn array_concat(args: &[ColumnarValue]) -> Result { .collect(); let data_type = arrays[0].data_type(); match data_type { - DataType::List(..) => { - let list_arrays = - downcast_vec!(arrays, ListArray).collect::>>()?; - let len: usize = list_arrays.iter().map(|a| a.values().len()).sum(); - let capacity = Capacities::Array(list_arrays.iter().map(|a| a.len()).sum()); - let array_data: Vec<_> = - list_arrays.iter().map(|a| a.to_data()).collect::>(); - let array_data = array_data.iter().collect(); - let mut mutable = - MutableArrayData::with_capacities(array_data, false, capacity); - - for (i, a) in list_arrays.iter().enumerate() { - mutable.extend(i, 0, a.len()) - } + DataType::List(field) => match field.data_type() { + DataType::Null => array_concat(&args[1..]), + _ => { + let list_arrays = downcast_vec!(arrays, ListArray) + .collect::>>()?; + let len: usize = list_arrays.iter().map(|a| a.values().len()).sum(); + let capacity = + Capacities::Array(list_arrays.iter().map(|a| a.len()).sum()); + let array_data: Vec<_> = + list_arrays.iter().map(|a| a.to_data()).collect::>(); + let array_data = array_data.iter().collect(); + let mut mutable = + MutableArrayData::with_capacities(array_data, false, capacity); + + for (i, a) in list_arrays.iter().enumerate() { + mutable.extend(i, 0, a.len()) + } - let builder = mutable.into_builder(); - let list = builder - .len(1) - .buffers(vec![Buffer::from_slice_ref([0, len as i32])]) - .build() - .unwrap(); + let builder = mutable.into_builder(); + let list = builder + .len(1) + .buffers(vec![Buffer::from_slice_ref([0, len as i32])]) + .build() + .unwrap(); - return Ok(ColumnarValue::Array(Arc::new(make_array(list)))); - } + return Ok(ColumnarValue::Array(Arc::new(arrow::array::make_array( + list, + )))); + } + }, _ => Err(DataFusionError::NotImplemented(format!( "Array is not type '{data_type:?}'." ))), @@ -410,6 +427,11 @@ pub fn array_fill(args: &[ColumnarValue]) -> Result { DataType::UInt16 => fill!(array_values, element, UInt16Array), DataType::UInt32 => fill!(array_values, element, UInt32Array), DataType::UInt64 => fill!(array_values, element, UInt64Array), + DataType::Null => { + return Ok(datafusion_expr::ColumnarValue::Scalar( + ScalarValue::new_list(Some(vec![]), DataType::Null), + )) + } data_type => { return Err(DataFusionError::Internal(format!( "Array_fill is not implemented for type '{data_type:?}'." @@ -823,6 +845,7 @@ pub fn array_to_string(args: &[ColumnarValue]) -> Result { DataType::UInt16 => to_string!(arg, arr, &delimeter, UInt16Array), DataType::UInt32 => to_string!(arg, arr, &delimeter, UInt32Array), DataType::UInt64 => to_string!(arg, arr, &delimeter, UInt64Array), + DataType::Null => Ok(arg), data_type => Err(DataFusionError::NotImplemented(format!( "Array is not implemented for type '{data_type:?}'." ))), @@ -831,8 +854,13 @@ pub fn array_to_string(args: &[ColumnarValue]) -> Result { let mut arg = String::from(""); let mut res = compute_array_to_string(&mut arg, arr, delimeter.clone())?.clone(); - res.truncate(res.len() - delimeter.len()); - Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(res)))) + match res.as_str() { + "" => Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(res)))), + _ => { + res.truncate(res.len() - delimeter.len()); + Ok(ColumnarValue::Scalar(ScalarValue::Utf8(Some(res)))) + } + } } /// Trim_array SQL function @@ -871,8 +899,12 @@ pub fn trim_array(args: &[ColumnarValue]) -> Result { let list_array = downcast_arg!(arr, ListArray); let values = list_array.value(0); + if values.len() <= n { + return Ok(datafusion_expr::ColumnarValue::Scalar( + ScalarValue::new_list(Some(vec![]), DataType::Null), + )); + } let res = values.slice(0, values.len() - n); - let mut scalars = vec![]; for i in 0..res.len() { scalars.push(ColumnarValue::Scalar(ScalarValue::try_from_array(&res, i)?)); diff --git a/datafusion/physical-expr/src/functions.rs b/datafusion/physical-expr/src/functions.rs index 5a5bdf4702e0..c45986eb8a74 100644 --- a/datafusion/physical-expr/src/functions.rs +++ b/datafusion/physical-expr/src/functions.rs @@ -404,7 +404,7 @@ pub fn create_physical_fun( Arc::new(array_expressions::array_to_string) } BuiltinScalarFunction::Cardinality => Arc::new(array_expressions::cardinality), - BuiltinScalarFunction::MakeArray => Arc::new(array_expressions::array), + BuiltinScalarFunction::MakeArray => Arc::new(array_expressions::make_array), BuiltinScalarFunction::TrimArray => Arc::new(array_expressions::trim_array), // string functions